GAN01: Introductory guide to Generative Adversarial Networks (GANs) and their promise!

Source: Introductory guide to Generative Adversarial Networks (GANs) and their promise!

What is a GAN?

Let us take an analogy to explain the concept:

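If you want to get better at something, say chess, what would you do? You would probably find an opponent stronger than yourself. After each game you would analyze what you did wrong and what they did right, and think about how to beat them in the next game.

You would repeat this process until you defeat the opponent. The same idea applies to training a good model. Put simply, to get a powerful hero (viz. the generator), we need a more powerful opponent (viz. the discriminator)!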

How do GANs work?

A GAN consists of two parts: a generator neural network and a discriminator neural network.

The generator network $G(z)$ takes random noise as input ($z$ drawn from $p_{z}(z)$) and produces a fake sample ($g$). This sample is then fed to the discriminator network $D(x)$. The discriminator's task is to tell real data from fake data. It also takes inputs $x$ from $p_{data}(x)$, where $p_{data}(x)$ is our real data distribution. $D(x)$ then solves a binary classification problem, using a sigmoid function to give an output in the range 0 to 1.

Now the training of GAN is done (as we saw above) as a fight between generator and discriminator. This can be represented mathematically as:

\begin{equation}
\label{a}
\min_{G} \max_{D} V(D, G) \\
V(D, G) = \mathbb{E}_{x \sim p_{data}(x)} [\log D(x)] + \mathbb{E}_{z \sim p_{z}(z)} [\log(1 - D(G(z)))]
\end{equation}

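Discriminator training stage: from the point of view of the discriminator $D$, the goal is to separate real samples from fake ones as well as possible, i.e. to maximize $V(D, G)$ toward its upper bound of 0. Concretely, it wants $D(x)$ to be as large as possible, i.e. to push $D(x)$ toward 1, and it wants $D(G(z))$ to be as small as possible, i.e. to push $D(G(z))$ toward 0 (so that both $\log D(x)$ and $\log(1 - D(G(z)))$ approach 0).

Generator training stage: from the point of view of the generator $G$, the goal is to pass fakes off as real, i.e. to minimize $V$, which is unbounded below (it tends to $-\infty$). Concretely, it wants $D(G(z))$ to be as large as possible, i.e. to push $D(G(z))$ toward 1 (only the second term of $V$ depends on $G$, so this stage uses only that term).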

Note: This way of training a GAN comes from game theory, where it is called a minimax game.
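To make the value function concrete, here is a minimal sketch in plain Python that estimates $V(D, G)$ from a handful of discriminator outputs. The helper name `value_fn` and the sample probabilities are made up for illustration; they are not part of the original post.

```python
import math

def value_fn(d_real, d_fake):
    """Monte Carlo estimate of V(D, G) from discriminator outputs.

    d_real: values of D(x) on real samples.
    d_fake: values of D(G(z)) on generated samples.
    All values must lie strictly between 0 and 1.
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

# Hypothetical discriminator outputs for three situations.
confused = value_fn([0.5, 0.5], [0.5, 0.5])        # D cannot tell real from fake
confident = value_fn([0.99, 0.98], [0.01, 0.02])   # D separates them well
fooled = value_fn([0.99, 0.98], [0.99, 0.98])      # G fools D on the fake batch

print(f"confused D : {confused:.4f}")   # equals -log(4)
print(f"confident D: {confident:.4f}")  # close to the upper bound 0
print(f"fooled D   : {fooled:.4f}")     # strongly negative
```

A discriminator that separates the two batches drives the estimate up toward 0, while a generator that fools it drags the second term, and therefore $V$, down.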

Parts of training GAN

So broadly a training phase has two main subparts and they are done sequentially:

  • Pass 1: Train the discriminator and freeze the generator (freezing means marking the network as non-trainable: it only does a forward pass and no backpropagation is applied to it)
  • Pass 2: Train generator and freeze discriminator
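One way to express the two passes in PyTorch is to toggle `requires_grad` on each network's parameters. The sketch below is an illustration under assumptions: the tiny network sizes, the random stand-in data and the helper `set_trainable` are hypothetical, and other approaches (such as calling `.detach()` on the generated batch) are equally common.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in networks; the sizes are arbitrary and only for illustration.
G = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
D = nn.Sequential(nn.Linear(2, 8), nn.LeakyReLU(0.2), nn.Linear(8, 1), nn.Sigmoid())

def set_trainable(net, flag):
    """Freeze (flag=False) or unfreeze (flag=True) every parameter of net."""
    for p in net.parameters():
        p.requires_grad_(flag)

bce = nn.BCELoss()
z = torch.randn(16, 4)      # random noise
real = torch.randn(16, 2)   # stand-in for a batch of real data

# Pass 1: train the discriminator, freeze the generator.
set_trainable(G, False)
set_trainable(D, True)
fake = G(z)  # forward pass only: nothing in G requires gradients
d_loss = bce(D(real), torch.ones(16, 1)) + bce(D(fake), torch.zeros(16, 1))
d_loss.backward()
g_has_grad_pass1 = any(p.grad is not None for p in G.parameters())
d_has_grad_pass1 = all(p.grad is not None for p in D.parameters())

# Pass 2: train the generator, freeze the discriminator.
D.zero_grad(set_to_none=True)
set_trainable(G, True)
set_trainable(D, False)
g_loss = bce(D(G(z)), torch.ones(16, 1))  # gradients flow through D into G
g_loss.backward()
g_has_grad_pass2 = all(p.grad is not None for p in G.parameters())
d_has_grad_pass2 = any(p.grad is not None for p in D.parameters())

print(g_has_grad_pass1, d_has_grad_pass1, g_has_grad_pass2, d_has_grad_pass2)
```

In each pass only the network being trained receives gradients; the frozen one still takes part in the forward computation.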

Steps to train a GAN

Step 1: Define the problem. Do you want to generate fake images or fake text? Here you should completely define the problem and collect data for it.

Step 2: Define the architecture of the GAN. Decide what your GAN should look like. Should both your generator and discriminator be multi-layer perceptrons, or convolutional neural networks? This step depends on the problem you are trying to solve.

Step 3: Train the discriminator on real data for n epochs. Take the real data you want to imitate and train the discriminator to correctly predict it as real. Here n can be any positive integer.

Step 4: Generate fake inputs with the generator and train the discriminator on them. Take the generated data and train the discriminator to correctly predict it as fake. (Steps 3 and 4 make up Pass 1.)

Step 5: Train the generator with the output of the discriminator. Once the discriminator is trained, you can use its predictions as the objective for training the generator. Train the generator to fool the discriminator. (This is Pass 2.)

Step 6: Repeat steps 3 to 5 for a few epochs.

Step 7: Check the fake data manually to see whether it looks legitimate. If it does, stop training; otherwise go back to step 3. This is a somewhat manual task, as evaluating the data by hand is the best way to check how fake it looks. When this step is over, you can evaluate whether the GAN is performing well enough.

Challenges with GANs

You may ask: if we know what these beautiful creatures (monsters?) could do, why hasn’t something happened yet? This is because we have barely scratched the surface. There are many roadblocks to building a “good enough” GAN, and we haven’t cleared many of them yet. There is a whole area of research devoted just to “how to train a GAN”.

The most important roadblock while training a GAN is stability. If you start to train a GAN and the discriminator is much more powerful than its generator counterpart, the generator will fail to train effectively. This in turn affects the training of your whole GAN. On the other hand, if the discriminator is too lenient, it will let literally any image be generated, which means your GAN is useless.
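One concrete way to see why an overly strong discriminator stalls the generator is to look at the gradient the generator receives. The sketch below is plain Python and makes one assumption: the discriminator output is $D(G(z)) = \sigma(a)$ for a logit $a$. It compares the generator term of $V$, $\log(1 - D(G(z)))$, with the alternative loss $-\log D(G(z))$ that the PyTorch code in this post uses. The function names are hypothetical.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_saturating(a):
    """d/da of log(1 - sigmoid(a)), the generator term in V(D, G)."""
    return -sigmoid(a)

def grad_non_saturating(a):
    """d/da of -log(sigmoid(a)), the alternative generator loss."""
    return sigmoid(a) - 1.0

# A more negative logit means D is more certain that the sample is fake.
for a in (0.0, -2.0, -6.0):
    print(f"logit={a:5.1f}  D(G(z))={sigmoid(a):.4f}  "
          f"saturating grad={grad_saturating(a):.4f}  "
          f"non-saturating grad={grad_non_saturating(a):.4f}")
```

When the discriminator confidently rejects the fakes, the gradient of the original term shrinks toward zero, so the generator gets almost no learning signal, while the alternative loss keeps a gradient of magnitude close to 1.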

Another way to look at the stability of a GAN is as a holistic convergence problem. The generator and the discriminator fight each other, each trying to get one step ahead. At the same time, they depend on each other for efficient training. If one of them fails, the whole system fails. So you have to make sure they don’t explode.

This is kind of like the shadow in the Prince of Persia game. You have to defend yourself from the shadow, which tries to kill you. If you kill the shadow you die, but if you don’t do anything, you will definitely die!

There are other problems too, which I will list down here. (Reference: http://www.iangoodfellow.com/slides/2016-12-04-NIPS.pdf)

Note: The images referred to below were generated by a GAN trained on the ImageNet dataset.

  • Problems with counting: GANs fail to learn how many of a particular object should occur at a location. In the example images, the generated heads have more eyes than occur naturally.
  • Problems with perspective: GANs fail to adapt to 3D objects. They do not understand perspective, i.e. the difference between front view and back view. In the example images, 3D objects are given a flat (2D) representation.
  • Problems with global structure: As with perspective, GANs do not understand holistic structure. For example, one generated image shows a quadruple cow, i.e. a cow standing on its hind legs and simultaneously on all four legs. That is definitely not possible in real life!

Substantial research is being done to address these problems. Newer types of models have been proposed that give more accurate results than previous techniques, such as DCGAN, Wasserstein GAN, etc.

Implementing a Toy GAN

PyTorch implementation

import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(torch.__version__, device)

# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'

# Create a directory if not exists
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

# Image processing
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5,),   # one channel (MNIST is grayscale); maps [0, 1] to [-1, 1]
                                     std=(0.5,))])  # so real images match the Tanh output range and denorm()

# MNIST dataset
mnist = torchvision.datasets.MNIST(root='H:/Other_DataSets/MNIST/',
                                   train=True,
                                   transform=transform,
                                   download=True)

# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size,
                                          shuffle=True)

# Discriminator
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())

# Generator
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())

# Device setting
D = D.to(device)
G = G.to(device)

# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

# Start training
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        images = images.reshape(batch_size, -1).to(device)

        # Create the labels which are later used as input for the BCE loss
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ================================================================== #
        #                      Train the discriminator                       #
        # ================================================================== #

        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
        # Second term of the loss is always zero since real_labels == 1
        outputs = D(images)  # batch x 1
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # Compute BCELoss using fake images
        # First term of the loss is always zero since fake_labels == 0
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z) # batch x 784
        outputs = D(fake_images) # batch x 1
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #

        # Compute loss with fake images
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z) # batch x 784
        outputs = D(fake_images) # batch x 1

        # We train G to maximize log(D(G(z))) instead of minimizing log(1-D(G(z)))
        # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
        g_loss = criterion(outputs, real_labels)

        # Backprop and optimize
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch+1, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))

    # Save real images
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

    # Save sampled images
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

# Save the model checkpoints
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')
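Once a generator checkpoint has been saved, it can be reloaded to draw new samples. The following is a minimal sketch rather than part of the original code: `build_generator` and `sample_images` are hypothetical helpers that mirror the generator architecture above, and the checkpoint is round-tripped through a temporary file so that the sketch runs without a trained `G.ckpt` (an untrained generator only produces noise).

```python
import os
import tempfile

import torch
import torch.nn as nn

latent_size, hidden_size, image_size = 64, 256, 784

def build_generator():
    """Same layer layout as the generator G defined above."""
    return nn.Sequential(
        nn.Linear(latent_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, image_size),
        nn.Tanh())

def sample_images(G, n):
    """Draw n images from G and map them from [-1, 1] back to [0, 1]."""
    G.eval()
    with torch.no_grad():
        z = torch.randn(n, latent_size)
        fake = G(z).reshape(n, 1, 28, 28)
    return ((fake + 1) / 2).clamp(0, 1)

# Stand-in for a trained checkpoint; in practice point this at your own G.ckpt.
ckpt_path = os.path.join(tempfile.mkdtemp(), 'G.ckpt')
torch.save(build_generator().state_dict(), ckpt_path)

G = build_generator()
G.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
samples = sample_images(G, 16)
print(samples.shape)
```

The returned tensor can be passed straight to `save_image`, as is done for the per-epoch samples in the training loop.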


Applications of GAN

We saw an overview of how these things work and got to know the challenges of training them. We will now look at some cutting-edge research that has been done using GANs.

  • Increasing the resolution of an image: Generate a high-resolution photo from a comparatively low-resolution one.

    Paper: https://arxiv.org/pdf/1609.04802.pdf
    Code: https://github.com/tensorlayer/srgan
  • Interactive image generation: Draw simple strokes and let the GAN draw an impressive picture for you!

    Link: https://github.com/junyanz/iGAN
  • Image-to-image translation: Generate an image from another image. For example, given the labels of a street scene, a GAN can generate a realistic-looking photo; given a simple drawing of a handbag, it can generate a realistic-looking image of a handbag.

    Paper: https://arxiv.org/pdf/1611.07004.pdf
  • Text-to-image generation: Just tell your GAN what you want to see and get a realistic photo of the target.

    Paper: https://arxiv.org/pdf/1605.05396.pdf

Resources

Here are some resources which you might find helpful to get more in-depth on GAN

End Notes

Phew! I hope you are now as excited about the future as I was when I first read about GANs. They are set to change what machines can do for us. Think of it – from preparing new recipes of food to creating drawings. The possibilities are endless.

In this article, I tried to give a general overview of GANs and their applications. GANs are a very exciting area, which is why researchers are so keen on building generative models and why new papers on GANs are coming out more and more frequently.

If you have any questions on GANs, please feel free to share them with me through comments.

Other links

GAN paper reading: the original GAN (basic concepts and theoretical derivation)

Original post: https://www.cnblogs.com/xuanyuyt/p/11935900.html

Time: 2024-08-04 08:24:25
