PyTorch实现简单的生成对抗网络GAN

生成对抗网络(Generative Adversarial Network, GAN)包括生成网络和对抗网络两部分。生成网络像自动编码器的解码器,能够生成数据,比如生成一张图片。对抗网络用来判断数据的真假,比如是真图片还是假图片,真图片是拍摄得到的,假图片是生成网络生成的。

生成对抗网络就是让生成网络和对抗网络相互竞争,通过生成网络来生成假的数据,对抗网络判别该数据是真是假,最后希望生成网络生成的数据以假乱真骗过判别器。

以下程序主要来自廖星宇的《深度学习之PyTorch》的第六章,本文对原代码进行了改进:

import torch
from torch import nn
import torchvision.transforms as tfs
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt

def preprocess_img(x):
    x = tfs.ToTensor()(x)      # x (0., 1.)
    return (x - 0.5) / 0.5     # x (-1., 1.)

def deprocess_img(x):          # x (-1., 1.)
    return (x + 1.0) / 2.0     # x (0., 1.)

def discriminator():
    net = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
        )
    return net

def generator(noise_dim):
    net = nn.Sequential(
        nn.Linear(noise_dim, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 784),
        nn.Tanh(),
    )
    return net

def discriminator_loss(logits_real, logits_fake):   # 判别器的loss
    size = logits_real.shape[0]
    true_labels = torch.ones(size, 1).float()
    false_labels = torch.zeros(size, 1).float()
    bce_loss = nn.BCEWithLogitsLoss()
    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
    return loss

def generator_loss(logits_fake):  # 生成器的 loss
    size = logits_fake.shape[0]
    true_labels = torch.ones(size, 1).float()
    bce_loss = nn.BCEWithLogitsLoss()
    loss = bce_loss(logits_fake, true_labels)
    return loss

# 使用 adam 来进行训练,beta1 是 0.5, beta2 是 0.999
def get_optimizer(net, LearningRate):
    optimizer = torch.optim.Adam(net.parameters(), lr=LearningRate, betas=(0.5, 0.999))
    return optimizer

def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss,
                noise_size, num_epochs, num_img):

    f, a = plt.subplots(num_img, num_img, figsize=(num_img, num_img))
    plt.ion()  # Turn the interactive mode on, continuously plot

    for epoch in range(num_epochs):
        for iteration, (x, _)in enumerate(train_data):
            bs = x.shape[0]

            # 训练判别网络
            real_data = x.view(bs, -1)  # 真实数据
            logits_real = D_net(real_data)  # 判别网络得分

            rand_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5  # -1 ~ 1 的均匀分布
            fake_images = G_net(rand_noise)  # 生成的假的数据
            logits_fake = D_net(fake_images)  # 判别网络得分

            d_total_error = discriminator_loss(logits_real, logits_fake)  # 判别器的 loss
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()  # 优化判别网络

            # 训练生成网络
            rand_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5  # -1 ~ 1 的均匀分布
            fake_images = G_net(rand_noise)  # 生成的假的数据

            gen_logits_fake = D_net(fake_images)
            g_error = generator_loss(gen_logits_fake)  # 生成网络的 loss
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()  # 优化生成网络

            if iteration % 20 == 0:
                print(‘Epoch: {:2d} | Iter: {:<4d} | D: {:.4f} | G:{:.4f}‘.format(epoch,
                                                                                  iteration,
                                                                                  d_total_error.data.numpy(),
                                                                                  g_error.data.numpy()))
                imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
                for i in range(num_img ** 2):
                    a[i // num_img][i % num_img].imshow(np.reshape(imgs_numpy[i], (28, 28)), cmap=‘gray‘)
                    a[i // num_img][i % num_img].set_xticks(())
                    a[i // num_img][i % num_img].set_yticks(())
                plt.suptitle(‘epoch: {} iteration: {}‘.format(epoch, iteration))
                plt.pause(0.01)

    plt.ioff()
    plt.show()

if __name__ == ‘__main__‘:

    EPOCH = 5
    BATCH_SIZE = 128
    LR = 5e-4
    NOISE_DIM = 96
    NUM_IMAGE = 4   # for showing images when training
    train_set = MNIST(root=‘/Users/wangpeng/Desktop/all/CS/Courses/Deep Learning/mofan_PyTorch/mnist/‘,
                      train=True,
                      download=False,
                      transform=preprocess_img)
    train_data = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)

    D = discriminator()
    G = generator(NOISE_DIM)

    D_optim = get_optimizer(D, LR)
    G_optim = get_optimizer(G, LR)

    train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss, NOISE_DIM, EPOCH, NUM_IMAGE)

效果:

原文地址:https://www.cnblogs.com/picassooo/p/12601909.html

时间: 2024-11-07 23:08:42

PyTorch实现简单的生成对抗网络GAN的相关文章

利用tensorflow训练简单的生成对抗网络GAN

对抗网络是14年Goodfellow Ian在论文Generative Adversarial Nets中提出来的. 原理方面,对抗网络可以简单归纳为一个生成器(generator)和一个判断器(discriminator)之间博弈的过程.整个网络训练的过程中, 两个模块的分工 判断器,直观来看就是一个简单的神经网络结构,输入就是一副图像,输出就是一个概率值,用于判断真假使用(概率值大于0.5那就是真,小于0.5那就是假) 生成器,同样也可以看成是一个神经网络模型,输入是一组随机数Z,输出是一个

生成对抗网络GAN

详解一:GAN完整理论推导和实现 详解二:详解生成对抗网络(GAN)原理 原文地址:https://www.cnblogs.com/yunkaiL/p/10952881.html

《生成对抗网络GAN的原理与应用专题》笔记

视频教程的链接:http://campus.swarma.org/gpac=8 一.什么是GAN 框架简述 GAN全称是Generative Adversarial Nets,中文叫做"生成对抗网络". 在GAN中有2个网络,一个网络用于生成数据,叫做"生成器".另一个网络用于判别生成数据是否接近于真实,叫做"判别器". 下图展示了最简单的GAN的结构以及工作原理. 模型中最左侧的随机向量是为了让生成器产生不同的输出,而增加的扰动项.这些扰动决定

生成对抗网络 Generative Adversarial Networks

转自:https://zhuanlan.zhihu.com/p/26499443 生成对抗网络GAN是由蒙特利尔大学Ian Goodfellow教授和他的学生在2014年提出的机器学习架构. 要全面理解生成对抗网络,首先要理解的概念是监督式学习和非监督式学习.监督式学习是指基于大量带有标签的训练集与测试集的机器学习过程,比如监督式图片分类器需要一系列图片和对应的标签("猫","狗"-),而非监督式学习则不需要这么多额外的工作,它们可以自己从错误中进行学习,并降低未来

使用生成对抗网络(GAN)生成手写字

先放结果 这是通过GAN迭代训练30W次,耗时3小时生成的手写字图片效果,大部分的还是能看出来是数字的. 实现原理 简单说下原理,生成对抗网络需要训练两个任务,一个叫生成器,一个叫判别器,如字面意思,一个负责生成图片,一个负责判别图片,生成器不断生成新的图片,然后判别器去判断哪儿哪儿不行,生成器再不断去改进,不断的像真实的图片靠近. 这就如同一个造假团伙一样,A负责生产,B负责就鉴定,刚开始的时候,两个人都是菜鸟,A随便画了一幅画拿给B看,B说你这不行,然后A再改进,当然需要改进的不止A,随着A

生成对抗网络浅析(GAN)

生成对抗网络 ? 顾名思义,生成对抗网络由两个部分构成, 生成器(Generator)和判别器(Discriminator), 两个部件相互博弈,最终达到平衡状态. 基本原理 下面以生成图片为例. G: 生成器 接受一个随机的噪声 z,通过噪声产生目标G(z) D:判别器 判别目标是否是"真实的".输入参数是 x,输出为D(x), 表示是否为真实的概率. ? 训练的过程中, G的目的就是尽量生成真实的图片欺骗 D.而 D的目标就是尽量将 G 生成的图片和真实的图片分离开.这样就是一个博

知物由学 | AI网络安全实战:生成对抗网络

"知物由学"是网易云易盾打造的一个品牌栏目,词语出自汉·王充<论衡·实知>.人,能力有高下之分,学习才知道事物的道理,而后才有智慧,不去求问就不会知道."知物由学"希望通过一篇篇技术干货.趋势解读.人物思考和沉淀给你带来收获的同时,也希望打开你的眼界,成就不一样的你. 以下是正文: 作者:Brad Harris,安全研究员,Brad曾在公共和私营部门的网络和计算机安全领域工作过.他已经完成了从渗透测试到逆向工程到应用研究的所有工作,目前他是IBMX-Fo

生成式对抗网络GAN 的研究进展与展望

生成式对抗网络GAN的研究进展与展望.pdf 摘要: 生成式对抗网络GAN (Generative adversarial networks) 目前已经成为人工智能学界一个热门的研究方向. GAN的基本思想源自博弈论的二人零和博弈, 由一个生成器和一个判别器构成, 通过对抗学习的方式来训练. 目的是估测数据样本的潜在分布并生成新的数据样本. 在图像和视觉计算.语音和语言处理.信息安全.棋类比赛等领域, GAN 正在被广泛研究,具有巨大的应用前景. 本文概括了GAN 的研究进展, 并进行展望. 在

正在涌现的新型神经网络模型:优于生成对抗网络

http://www.17bianji.com/lsqh/35130.html 是以,它会让人想起残差前馈收集(residual feed-forward network),但在实际中,强迫这些收集向前传播误差并不克不及让它们在更高财揭捉习到有效的层次表征.是以,它们不克不及基于更上层的表征来竽暌剐效地履行其它义务,例如分类.瓜分.动作辨认.要明白这些限制,还须要更多的实验. 新一代深度神经收集正在出现.它们演变自前馈模型,之前我们曾作过具体分析,参阅机械之心文┞仿 <重磅 | 神经收集架构演