In this guide, we will implement a simple generative AI model in Python using PyTorch. We will build a basic Generative Adversarial Network (GAN) to generate synthetic images. For this example, we will use the popular MNIST dataset, which consists of 28x28 grayscale images of handwritten digits.

1. Setting Up the Environment

Before we start coding, ensure you have the necessary libraries installed. You can install them using pip:


pip install torch torchvision matplotlib

2. Importing Libraries

We will import the required libraries for building our GAN model:


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

3. Defining the Generator and Discriminator

Next, we will define the architecture for the generator and discriminator networks. The generator will take random noise as input and produce an image, while the discriminator will evaluate whether the image is real or fake.

Generator


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Maps a 100-dimensional noise vector to a flattened 28x28 image (784 values)
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 784),
            nn.Tanh()  # Output range [-1, 1]
        )

    def forward(self, z):
        return self.model(z)

Discriminator


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Maps a flattened 28x28 image (784 values) to a single probability of being real
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()  # Output range [0, 1]
        )

    def forward(self, img):
        return self.model(img)
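
As a quick sanity check of the shapes involved, the sketch below rebuilds the two networks as plain nn.Sequential stacks with the same layer sizes as the classes above and pushes a small batch of noise through them. The names latent_dim, image_dim, G, and D are illustrative only and do not appear in the rest of this guide.

```python
import torch
import torch.nn as nn

latent_dim = 100      # size of the noise vector (matches nn.Linear(100, 256))
image_dim = 28 * 28   # a flattened 28x28 MNIST image has 784 values

# Same layer sizes as the Generator class, written inline for brevity
G = nn.Sequential(
    nn.Linear(latent_dim, 256), nn.ReLU(),
    nn.Linear(256, 512), nn.ReLU(),
    nn.Linear(512, image_dim), nn.Tanh(),
)

# Same layer sizes as the Discriminator class
D = nn.Sequential(
    nn.Linear(image_dim, 512), nn.ReLU(),
    nn.Linear(512, 256), nn.ReLU(),
    nn.Linear(256, 1), nn.Sigmoid(),
)

z = torch.randn(8, latent_dim)   # batch of 8 noise vectors
fake = G(z)                      # one flattened image per noise vector
score = D(fake)                  # one probability per image

print(fake.shape, score.shape)
```

The generator output has one row of 784 values per noise vector, and the discriminator returns one score per image, which is why the training code flattens real images before passing them to the discriminator.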

4. Loading the MNIST Dataset

We will use the MNIST dataset for training our GAN. The dataset will be loaded and transformed to fit the input requirements of our model.


# Load the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
data_loader = torch.utils.data.DataLoader(mnist_dataset, batch_size=64, shuffle=True)
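
To see why this transform yields values in [-1, 1], it may help to work through the arithmetic by hand. ToTensor scales pixel values to [0, 1], and Normalize with mean 0.5 and standard deviation 0.5 computes (x - mean) / std for each value. The helper functions normalize and denormalize below are hypothetical stand-ins written for illustration; they are not torchvision functions.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Per-value arithmetic of Normalize((0.5,), (0.5,)): (x - mean) / std."""
    return (pixel - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Inverse mapping, useful for viewing generator output as [0, 1] pixels."""
    return value * std + mean


# Black, mid-gray, and white pixels after ToTensor
for p in (0.0, 0.5, 1.0):
    print(p, "->", normalize(p))
```

Matching the data range to [-1, 1] matters here because the generator ends in a Tanh layer, so real and generated images reach the discriminator on the same scale.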

5. Training the GAN

Now we will set up the training loop for our GAN. For each batch, we first update the discriminator on real and generated images, then update the generator so that its images are more likely to be classified as real.


# Initialize models, optimizers, and loss function
generator = Generator()
discriminator = Discriminator()
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)

# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    for real_images, _ in data_loader:
        batch_size = real_images.size(0)

        # Train Discriminator
        optimizer_D.zero_grad()
        real_labels = torch.ones(batch_size, 1)  # Real labels
        fake_labels = torch.zeros(batch_size, 1)  # Fake labels

        outputs = discriminator(real_images.view(batch_size, -1))  # Flatten images
        d_loss_real = criterion(outputs, real_labels)

        z = torch.randn(batch_size, 100)  # Random noise
        fake_images = generator(z)
        outputs = discriminator(fake_images.detach())  # detach: no gradients flow to the generator here
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)  # Want to fool the discriminator
        g_loss.backward()
        optimizer_G.step()

    # Print the losses of the last batch every 10 epochs
    if epoch % 10 == 0:
        print(f'Epoch [{epoch}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
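
To make the two loss terms concrete, here is a small worked example of binary cross-entropy on single values, using only the standard library. The discriminator outputs d_real and d_fake are made-up numbers chosen for illustration, and bce is a hypothetical helper, not the PyTorch implementation (nn.BCELoss also averages over the batch by default).

```python
import math


def bce(prediction, target):
    """Binary cross-entropy for a single prediction strictly between 0 and 1."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))


# Hypothetical discriminator outputs, chosen only for illustration
d_real = 0.9   # D(x): fairly sure a real image is real
d_fake = 0.2   # D(G(z)): fairly sure a generated image is fake

d_loss = bce(d_real, 1.0) + bce(d_fake, 0.0)   # discriminator: real -> 1, fake -> 0
g_loss = bce(d_fake, 1.0)                      # generator: wants D(G(z)) -> 1

print(f"d_loss = {d_loss:.4f}, g_loss = {g_loss:.4f}")
# prints: d_loss = 0.3285, g_loss = 1.6094
```

With these numbers the discriminator is doing well, so its loss is small while the generator loss is large. Scoring the fake images against the "real" label is what pushes the generator toward images the discriminator accepts.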

6. Generating Images

After training, we can generate new images using the trained generator. We will sample random noise and pass it through the generator to create synthetic images.


from torchvision.utils import make_grid  # not covered by the imports in section 2

# Generate images
with torch.no_grad():
    z = torch.randn(16, 100)  # Sample random noise
    generated_images = generator(z).view(-1, 1, 28, 28)  # Reshape to image format

# Plot generated images
grid = make_grid(generated_images, nrow=4, normalize=True)
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.title('Generated Images')
plt.axis('off')
plt.show()

7. Conclusion

In this guide, we implemented a simple Generative Adversarial Network (GAN) using Python and PyTorch. We defined the generator and discriminator architectures, loaded the MNIST dataset, and trained the GAN to generate synthetic images of handwritten digits. This basic implementation can be further enhanced by experimenting with different architectures, hyperparameters, and datasets.