An Overview of GAN - Part 2

Learn about the generator and discriminator models and the overall GAN architecture.
type: insight
level: medium

In the previous article, we introduced the two main models in the GAN architecture: the generator model and the discriminator model. To understand them better, this article describes each of them in more detail.


I. Generator

A generative model is a type of model in Machine Learning and Deep Learning that is used to generate new data with properties similar to the original training data. Its goal is to learn the structure and characteristics of the training data and then generate new data samples based on that learned knowledge.

This kind of model is not only used to create new images, sounds, and text, but can also be applied in many other fields such as creative arts, design, and research. Advances in generative modeling continue to open up many opportunities and challenges for AI and Deep Learning.


import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Map a 100-dimensional noise vector to a 784-dimensional (28x28)
        # image; Tanh squashes the outputs to the range [-1, 1]
        self.model = nn.Sequential(
            nn.Linear(100, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.model(x)
        return x
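As a quick sanity check (a minimal sketch; the batch size of 16 is arbitrary), we can sample noise vectors and verify the shape of the generated output:

# Sample 16 noise vectors and generate flat 784-dimensional images
generator = Generator()
z = torch.randn(16, 100)
fake = generator(z)                 # shape: (16, 784)
images = fake.view(-1, 1, 28, 28)   # reshape into 28x28 MNIST-like images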

II. Discriminator

The Discriminator is an important component of the GAN (Generative Adversarial Networks) model. Its main task is to distinguish real data from the fake data produced by the Generator network.

When training the GAN model, the Discriminator receives both real data and fake data as input. Its task is to output the probability that a given sample is real. The Discriminator tries to maximize its accuracy in distinguishing the two types of data, while the Generator tries to produce fake data that the Discriminator cannot tell apart from real data.


GAN training proceeds by iteratively updating the Discriminator and the Generator. The Discriminator is updated to improve its ability to distinguish real data from fake data, while the Generator is updated to produce better fake data that can fool the Discriminator.

Because these two components compete with each other, the GAN model learns to generate high-quality fake data that is close to the real data. This makes it increasingly difficult for the Discriminator to tell real from fake, and forces the Generator to capture the characteristics and distribution of the original data.

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Map a 784-dimensional (28x28) image to a single probability
        # that the input is real, via a Sigmoid output
        self.model = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.model(x)
        return x
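A quick smoke test (a minimal sketch reusing the classes defined above) confirms that the Discriminator outputs probabilities in [0, 1]:

# Score a batch of fake images with untrained networks
generator = Generator()
discriminator = Discriminator()
z = torch.randn(16, 100)
fake = generator(z)              # shape: (16, 784)
scores = discriminator(fake)     # shape: (16, 1), each value in [0, 1]
print(scores.min().item(), scores.max().item())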

III. Loss Function

The loss function of the GAN model is a single function that combines the Discriminator's objective and the Generator's objective.

$\underset{G}{\min}\,\underset{D}{\max}\,V(D,G) = E_{\textbf{x}\sim p_{data}(\textbf{x})}[\log D(\textbf{x})] + E_{\textbf{z}\sim p_z(\textbf{z})}[\log(1 - D(G(\textbf{z})))]$

Let's analyze this complex loss function together:

  • The Generator network is denoted $G$, and the Discriminator network is denoted $D$.
  • $G(z)$ is the image generated by the Generator from the noise vector $z$.
  • $D(x)$ is the Discriminator's predicted probability that the image $x$ is real.
  • $D(G(z))$ is the predicted probability that an image generated by the Generator is real.
  • $E$ is the expectation, simply understood as the average over the data; for example, maximizing $E_{\textbf{x}\sim p_{data}(\textbf{x})}[\log D(\textbf{x})]$ means maximizing $D(x)$ on average over images $x$ in the training set.

From the loss function above, we can see that the Generator and the Discriminator are trained with opposing objectives: $D$ tries to maximize the loss while $G$ tries to minimize it. GAN training ends when the model reaches an equilibrium between the two networks, called a Nash equilibrium.
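One practical detail, following Goodfellow et al.'s original paper: early in training, $\log(1 - D(G(\textbf{z})))$ saturates because the Discriminator easily rejects the fakes, giving the Generator weak gradients. Implementations therefore usually train $G$ to maximize the "non-saturating" objective instead:

$\underset{G}{\max}\, E_{\textbf{z}\sim p_z(\textbf{z})}[\log D(G(\textbf{z}))]$

The training code below implements exactly this by applying the BCE loss to fake samples with real labels.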

import torch
import torch.nn as nn
import torch.optim as optim

# Initialize the Generator network and the Discriminator network
generator = Generator()
discriminator = Discriminator()

# Define the loss function and optimizers
criterion = nn.BCELoss()
optimizer_generator = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=0.0002)

# Train the GAN model (train_loader is assumed to be a DataLoader over MNIST)
num_epochs = 100
for epoch in range(num_epochs):
    for batch_idx, (real_images, _) in enumerate(train_loader):
        # Determine batch size and prepare training data
        batch_size = real_images.size(0)
        real_images = real_images.view(-1, 784)
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # Train the Discriminator on real and fake images
        discriminator.zero_grad()
        outputs_real = discriminator(real_images)
        loss_real = criterion(outputs_real, real_labels)
        z = torch.randn(batch_size, 100)
        fake_images = generator(z)
        # detach() so the Discriminator update does not backpropagate
        # into the Generator's parameters
        outputs_fake = discriminator(fake_images.detach())
        loss_fake = criterion(outputs_fake, fake_labels)
        loss_discriminator = loss_real + loss_fake
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Train the Generator: its fake images should be classified as real
        generator.zero_grad()
        z = torch.randn(batch_size, 100)
        fake_images = generator(z)
        outputs = discriminator(fake_images)
        loss_generator = criterion(outputs, real_labels)
        loss_generator.backward()
        optimizer_generator.step()
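After training, the Generator can be used on its own to produce new samples. Here is a minimal sketch (torchvision.utils.save_image is a standard utility; the file name is arbitrary):

import torchvision.utils

# Generate a grid of 64 samples and save them for visual inspection
with torch.no_grad():
    z = torch.randn(64, 100)
    samples = generator(z).view(-1, 1, 28, 28)
    # Rescale from [-1, 1] (Tanh output) to [0, 1] for saving
    torchvision.utils.save_image((samples + 1) / 2, "gan_samples.png", nrow=8)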

IV. Evaluation and Benchmarking

Evaluating a Generative Adversarial Network (GAN) is an important step in assessing the performance and quality of the trained model. Here are some common evaluation methods for GAN models:

  1. Evaluating image quality: For image-generation applications, a common approach is to use image quality metrics such as SSIM (Structural Similarity Index), PSNR (Peak Signal-to-Noise Ratio), or FID (Fréchet Inception Distance). These metrics measure the similarity between generated images and real images; for SSIM and PSNR higher values indicate better quality, while for FID lower values are better.
  2. Evaluating new data generation: A good GAN model should be able to generate new and diverse data. To assess this, one can count the number of distinct samples generated, or measure sample diversity through metrics such as the Inception Score (see the sketch after this list).
  3. Evaluating learning ability and stability: A good GAN model should learn quickly and stably. This can be evaluated by monitoring the evolution of the Generator and Discriminator losses during training, ensuring that they converge to stable values and reach a good balance between the two components.
  4. Evaluating the interaction between Generator and Discriminator: A good GAN model should exhibit an effective interplay between the two networks. This can be assessed by examining the Discriminator's accuracy on fake samples produced by the Generator, and checking that the Generator is able to fool the Discriminator.
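As an illustration of point 2, here is a minimal sketch of the Inception Score computation. It assumes probs is an (N, C) tensor of softmax class probabilities obtained from a pretrained classifier (for MNIST, a classifier trained on MNIST is more appropriate than Inception V3):

import torch

def inception_score(probs, eps=1e-8):
    # probs: (N, C) softmax class probabilities for N generated samples
    p_y = probs.mean(dim=0, keepdim=True)  # marginal distribution p(y)
    # KL(p(y|x) || p(y)) for each sample, then exp of the mean
    kl = (probs * (torch.log(probs + eps) - torch.log(p_y + eps))).sum(dim=1)
    return kl.mean().exp().item()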

Here is an example of evaluating a GAN model by calculating the FID (Fréchet Inception Distance), using the Inception V3 model to compute feature vectors from real and fake images:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from scipy.linalg import sqrtm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def calculate_fid(real_images, fake_images, batch_size, device):
    # Inception V3 expects 3-channel 299x299 inputs, so grayscale MNIST
    # images must be resized and repeated across channels
    inception_model = torchvision.models.inception_v3(pretrained=True, transform_input=False).to(device)
    inception_model.fc = nn.Identity()  # expose 2048-dim pooled features instead of class logits
    inception_model.eval()

    def extract_features(batch):
        batch = batch.to(device)
        batch = F.interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False)
        batch = batch.repeat(1, 3, 1, 1)  # 1 channel -> 3 channels
        return inception_model(batch).view(batch.size(0), -1)

    # Calculate feature vectors from real and fake images
    with torch.no_grad():
        real_features = []
        fake_features = []
        for i in range(0, len(real_images), batch_size):
            real_features.append(extract_features(real_images[i:i+batch_size]))
            fake_features.append(extract_features(fake_images[i:i+batch_size]))
        real_features = torch.cat(real_features, dim=0).cpu()
        fake_features = torch.cat(fake_features, dim=0).cpu()

    # Calculate mean and covariance matrix of the feature vectors
    real_mu, real_sigma = real_features.mean(dim=0), torch.cov(real_features.T)
    fake_mu, fake_sigma = fake_features.mean(dim=0), torch.cov(fake_features.T)

    # FID = ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2*sqrt(S_r @ S_f))
    covmean = sqrtm((real_sigma @ fake_sigma).numpy())
    covmean = torch.from_numpy(covmean.real.copy()).float()  # drop tiny imaginary parts
    fid = torch.norm(real_mu - fake_mu)**2 + torch.trace(real_sigma + fake_sigma - 2*covmean)
    return fid.item()

# Load MNIST data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=True)

# Initialize the Generator and Discriminator models
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Model evaluation
num_samples = 1000
noise = torch.randn(num_samples, 100).to(device)
fake_images = generator(noise).detach().view(-1, 1, 28, 28)
# Scale the raw MNIST pixels to [-1, 1] to match the Generator's Tanh output range
real_images = (test_dataset.data[:num_samples].float().unsqueeze(1) / 255.0 - 0.5) / 0.5
fid = calculate_fid(real_images, fake_images, 100, device)
print(f"FID: {fid:.4f}")

We compute the mean and the covariance matrix of the feature vectors to obtain the FID measure. The smaller the FID, the better the quality of the fake images.

References

  1. Generative Adversarial Networks, Ian J. Goodfellow et al.
  2. Deep Learning Lectures, Generative Adversarial Networks
  3. GAN model, machinelearningmastery
  4. GAN model trained on MNIST, machinelearningmastery