总结一些常用的训练 GANs 的方法

众所周知,GANs 的训练尤其困难,笔者自从跳入了 GANs 这个领域(坑),就一直在跟如何训练 GANs 做「对抗训练」,受启发于 ganhacks,并结合自己的经验记录总结了一些常用的训练 GANs 的方法,以备后用。

什么是 GANs?

GANs(Generative Adversarial Networks)可以说是一种强大的「万能」数据分布拟合器,主要由一个生成器(generator)和判别器(discriminator)组成。生成器主要从一个低维度的数据分布中不断拟合真实的高维数据分布,而判别器主要是为了区分数据是来源于真实数据还是生成器生成的数据,他们之间相互对抗,不断学习,最终达到Nash均衡,即任何一方的改进都不会导致总体的收益增加,这个时候判别器再也无法区分是生成器生成的数据还是真实数据。

GANs 最初由 Ian Goodfellow [1] 于 2014 年提出,目前已经在图像、语音、文字等方面得到广泛研究和应用,特别是在图像生成方面,可谓是遍地开花,例如图像风格迁移(style transfer)、图像修复(image inpainting)、超分辨率(super resolution)等。

GANs 出了什么问题?

GANs 通常被定义为一个 minimax 的过程:

其中 P_r 是真实数据分布,P_z 是随机噪声分布。乍一看这个目标函数,感觉有点相互矛盾,其实这就是 GANs 的精髓所在—— 对抗训练。

在原始的 GANs 中,判别器要不断的提高判别是非的能力,即尽可能的将真实样本分类为正例,将生成样本分类为负例,所以判别器需要优化如下损失函数:

作为对抗训练,生成器需要不断将生成数据分布拉到真实数据分布,Ian Goodfellow 首先提出了如下式的生成器损失函数:

由于在训练初期阶段,生成器的能力比较弱,判别器这时候也比较弱,但仍然可以足够精准的区分生成样本和真实样本,这样 D(x) 就非常接近1,导致 log(1-D(x)) 达到饱和,后续网络就很难再调整过来。为了解决训练初期阶段饱和问题,作者提出了另外一个损失函数,即:

以上面这个两个生成器目标函数为例,简单地分析一下GAN模型存在的几个问题:

Ian Goodfellow 论文里面已经给出,固定 G 的参数,我们得到最优的 D^*:

也就是说,只有当 P_r=P_g 时候,不管是真实样本和生成样本,判别器给出的概率都是 0.5,这个时候就无法区分样本到底是来自于真实样本还是来自于生成样本,这是最理想的情况。

1. 对于第一种目标函数

在最优判别器下 D^* 下,我们给损失函数加上一个与 G 无关的项,(3) 式变成:

注意,该式子其实就是判别器的损失函数的相反数。

把最优判别器 D^* 带入,可以得到:

到这里,我们就可以看清楚我们到底在优化什么东西了,在最优判别器的情况下,其实我们在优化两个分布的 JS 散度。当然在训练过程中,判别器一开始不是最优的,但是随着训练的进行,我们优化的目标也逐渐接近JS散度,而问题恰恰就出现在这个 JS 散度上面。一个直观的解释就是只要两个分布之间的没有重叠或者重叠部分可以忽略不计,那么大概率上我们优化的目标就变成了一个常数 -2log2,这种情况通过判别器传递给生成器的梯度就是零,也就是说,生成器不可能从判别器那里学到任何有用的东西,这也就导致了无法继续学习。

Arjovsky [2] 以其精湛的数学技巧提供一个更严谨的一个数学推导(手动截图原论文了)。

在 Theorm2.4 成立的情况下:

抛开上面这些文绉绉的数学表述,其实上面讲的核心内容就是当两个分布的支撑集是没有交集的或者说是支撑集是低维的流形空间,随着训练的进行,判别器不断接近最优判别器,会导致生成器的梯度处处都是为0。

2. 对于第二种目标函数

同样在最优判别器下,优化 (4) 式等价优化如下

仔细盯着上面式子几秒钟,不难发现我们优化的目标是相互悖论的,因为 KL 散度和 JS 散度的符号相反,优化 KL 是把两个分布拉近,但是优化 -JS 是把两个分布推远,这「一推一拉」就会导致梯度更新非常不稳定。此外,我们知道 KL 不是对称的,对于生成器无法生成真实样本的情况,KL 对 loss 的贡献非常大,而对于生成器生成的样本多样性不足的时候,KL 对 loss 的贡献非常小。

而 JS 是对称的,不会改变 KL 的这种不公平的行为。这就解释了我们经常在训练阶段经常看见两种情况,一个是训练 loss 抖动非常大,训练不稳定;另外一个是即使达到了稳定训练,生成器也大概率上只生成一些安全保险的样本,这样就会导致模型缺乏多样性。

此外,在有监督的机器学习里面,经常会出现一些过拟合的情况,然而 GANs 也不例外。当生成器训练得越来越好时候,生成的数据越接近于有限样本集合里面的数据。特别是当训练集里面包含有错误数据时候,判别器会过拟合到这些错误的数据,对于那些未见的数据,判别器就不能很好的指导生成器去生成可信的数据。这样就会导致 GANs 的泛化能力比较差。

综上所述,原始的 GANs 在训练稳定性、模式多样性以及模型泛化性能方面存在着或多或少的问题,后续学术上的工作大多也是基于此进行改进(填坑)。

训练 GAN 的常用策略

上一节都是基于一些简单的数学或者经验的分析,但是根本原因目前没有一个很好的理论来解释;尽管理论上的缺陷,我们仍然可以从一些经验中发现一些实用的 tricks,让你的 GANs 不再难训。这里列举的一些 tricks 可能跟 ganhacks 里面的有些重复,更多的是补充,但是为了完整起见,部分也添加在这里。

1. model choice

如果你不知道选择什么样的模型,那就选择 DCGAN[3] 或者 ResNet[4] 作为 base model。

2. input layer

假如你的输入是一张图片,将图片数值归一化到 [-1, 1];假如你的输入是一个随机噪声的向量,最好是从 N(0, 1) 的正态分布里面采样,不要从 U(0,1) 的均匀分布里采样。

3. output layer

使用输出通道为 3 的卷积作为最后一层,可以采用 1x1 或者 3x3 的 filters,有的论文也使用 9x9 的 filters。(注:ganhacks 推荐使用 tanh)

4. transposed convolution layer

在做 decode 的时候,尽量使用 upsample+conv2d 组合代替 transposed_conv2d,可以减少 checkerboard 的产生 [5];

在做超分辨率等任务上,可以采用 pixelshuffle [6]。在 tensorflow 里,可以用 tf.depth_to_sapce 来实现 pixelshuffle 操作。

5. convolution layer

由于笔者经常做图像修复方向相关的工作,推荐使用 gated-conv2d [7]。

6. normalization

虽然在 resnet 里的标配是 BN,在分类任务上表现很好,但是图像生成方面,推荐使用其他 normlization 方法,例如 parameterized 方法有 instance normalization [8]、layer normalization [9] 等,non-parameterized 方法推荐使用 pixel normalization [10]。假如你有选择困难症,那就选择大杂烩的 normalization 方法——switchable normalization [11]。

7. discriminator

想要生成更高清的图像,推荐 multi-stage discriminator [10]。简单的做法就是对于输入图片,把它下采样(maxpooling)到不同 scale 的大小,输入三个不同参数但结构相同的 discriminator。

8. minibatch discriminator

由于判别器是单独处理每张图片,没有一个机制能告诉 discriminator 每张图片之间要尽可能的不相似,这样就会导致判别器会将所有图片都 push 到一个看起来真实的点,缺乏多样性。minibatch discriminator [22] 就是这样这个机制,显式地告诉 discriminator 每张图片应该要不相似。在 tensorflow 中,一种实现 minibatch discriminator 方式如下:

上面是通过一个可学习的网络来显示度量每个样本之间的相似度,PGGAN 里提出了一个更廉价的不需要学习的版本,即通过统计每个样本特征每个像素点的标准差,然后取他们的平均,把这个平均值复制到与当前 feature map 一样空间大小单通道,作为一个额外的 feature maps 拼接到原来的 feature maps 里,一个简单的 tensorflow 实现如下:

9. GAN loss

除了第二节提到的原始 GANs 中提出的两种 loss,还可以选择 wgan loss [12]、hinge loss、lsgan loss [13]等。wgan loss 使用 Wasserstein 距离(推土机距离)来度量两个分布之间的差异,lsgan 采用类似最小二乘法的思路设计损失函数,最后演变成用皮尔森卡方散度代替了原始 GAN 中的 JS 散度,hinge loss 是迁移了 SVM 里面的思想,在 SAGAN [14] 和 BigGAN [15] 等都是采用该损失函数。

ps: 我自己经常使用没有 relu 的 hinge loss 版本。

10. other loss

  • perceptual loss [17]
  • style loss [18]
  • total variation loss [17]
  • l1 reconstruction loss

通常情况下,GAN loss 配合上面几种 loss,效果会更好。

11. gradient penalty

Gradient penalty 首次在 wgan-gp 里面提出来的,记为 1-gp,目的是为了让 discriminator 满足 1-lipchitchz 连续,后续 Mescheder, Lars M. et al [19] 又提出了只针对正样本或者负样本进行梯度惩罚,记为 0-gp-sample。Thanh-Tung, Hoang et al [20] 提出了 0-gp,具有更好的训练稳定性。三者的对比如下:

12. Spectral normalization [21]

谱归一化是另外一个让判别器满足 1-lipchitchz 连续的利器,建议在判别器和生成器里同时使用。

ps: 在个人实践中,它比梯度惩罚更有效。

13. one-size label smoothing [22]

平滑正样本的 label,例如 label 1 变成 0.9-1.1 之间的随机数,保持负样本 label 仍然为 0。个人经验表明这个 trick 能够有效缓解训练不稳定的现象,但是不能根本解决问题,假如模型不够好的话,随着训练的进行,后期 loss 会飞。

14. add supervised labels

  • add labels
  • conditional batch normalization

15. instance noise (decay over time)

在原始 GAN 中,我们其实在优化两个分布的 JS 散度,前面的推理表明在两个分布的支撑集没有交集或者支撑集是低维的流形空间,他们之间的 JS 散度大概率上是 0;而加入 instance noise 就是强行让两个分布的支撑集之间产生交集,这样 JS 散度就不会为 0。新的 JS 散度变为:

16. TTUR [23]

在优化 G 的时候,我们默认是假定我们的 D 的判别能力是比当前的 G 的生成能力要好的,这样 D 才能指导 G 朝更好的方向学习。通常的做法是先更新 D 的参数一次或者多次,然后再更新 G 的参数,TTUR 提出了一个更简单的更新策略,即分别为 D 和 G 设置不同的学习率,让 D 收敛速度更快。

17. training strategy

  • PGGAN [10]

PGGAN 是一个渐进式的训练技巧,因为要生成高清(eg, 1024x1024)的图片,直接从一个随机噪声生成这么高维度的数据是比较难的;既然没法一蹴而就,那就循序渐进,首先从简单的低纬度的开始生成,例如 4x4,然后 16x16,直至我们所需要的图片大小。在 PGGAN 里,首次实现了高清图片的生成,并且可以做到以假乱真,可见其威力。此外,由于我们大部分的操作都是在比较低的维度上进行的,训练速度也不比其他模型逊色多少。

  • coarse-to-refine

coarse-to-refine 可以说是 PGGAN 的一个特例,它的做法就是先用一个简单的模型,加上一个 l1 loss,训练一个模糊的效果,然后再把这个模糊的照片送到后面的 refine 模型里,辅助对抗 loss 等其他 loss,训练一个更加清晰的效果。这个在图片生成里面广泛应用。

18. Exponential Moving Average [24]

EMA主要是对历史的参数进行一个指数平滑,可以有效减少训练的抖动。强烈推荐!!!

总结

训练 GAN 是一个精(折)细(磨)的活,一不小心你的 GAN 可能就是一部惊悚大片。笔者结合自己的经验以及看过的一些文献资料,列出了常用的 tricks,在此抛砖引玉,由于笔者能力和视野有限,有些不正确之处或者没补全的 tricks,还望斧正。

原文地址:https://www.cnblogs.com/alan-blog-TsingHua/p/10952890.html

时间: 2024-10-01 22:48:04

总结一些常用的训练 GANs 的方法的相关文章

Numpy中数据的常用的保存与读取方法

Numpy中数据的常用的保存与读取方法 小书匠 深度学习 文章目录: 1.保存为二进制文件(.npy/.npz) numpy.save numpy.savez numpy.savez_compressed 2.保存到文本文件 numpy.savetxt numpy.loadtxt 在经常性读取大量的数值文件时(比如深度学习训练数据),可以考虑现将数据存储为Numpy格式,然后直接使用Numpy去读取,速度相比为转化前快很多. 下面就常用的保存数据到二进制文件和保存数据到文本文件进行介绍: 1.保

常用的操作正则表达式的方法+正则表达式基本元字符使用实例

常用的操作正则表达式的方法: 下面学习一下位于System.Text.RegularExpressions下的Regex类的一些静态方法和委托(只要有一段匹配就会返回true) 1,静态方法IsMatch (返回值是一个布尔类型,用于判断指定的字符串是否与正则表达式字符串匹配,它有三个重载方法) bool IsMatch(string input, string pattern); 参数: input: 要搜索匹配项的字符串. pattern: 要匹配的正则表达式模式. 返回结果: 如果正则表达

Android中常用的三种存储方法浅析

Android中常用的三种存储方法浅析 Android中数据存储有5种方式: [1]使用SharedPreferences存储数据 [2]文件存储数据 [3]SQLite数据库存储数据 [4]使用ContentProvider存储数据 [5]网络存储数据 在这里我只总结了三种我用到过的或即将可能用到的三种存储方法. 一.使用SharedPreferences存储数据 SharedPreferences是Android平台上一个轻量级的存储类,主要是保存一些常用的配置信息比如窗口状态,它的本质是基

mysql 的常用命令及常见问题解决方法

运行sql C:\Users\Martin>mysql -uroot -pyang cdm_db <d:/cdm_db.sql 运行sql mysql>source /tmp/terminal.sql; mysql忘记密码: mysqladmin -uroot flush-privileges password "newpassword" mysql的select into file命令 SELECT a,b,a+b INTO OUTFILE '/tmp/result

[C/C++基础] C语言常用函数strlen的使用方法

函数声明:extern unsigned int strlen(char *s); 所属函数库:<string.h> 功能:返回s所指的字符串的长度,其中字符串必须以'\0'结尾 参数:s为字符串的初始地址 使用举例: 代码如下 编译运行结果 说明: 函数strlen比较容易理解,其功能和sizeof很容易混淆.其中sizeof指的是字符串声明后占用的内存长度,它就是一个操作符,不是函数:而strlen则是一个函数,它从第一个字节开始往后数,直到遇见了'\0',则停止. [C/C++基础] C

js常用内置对象及方法

在js中万物皆对象:字符串,数组,数值,函数...... 内置对象都有自己的属性和方法,访问方法如下: 对象名.属性名称: 对象名.方法名称 1.Array数组对象 unshift( )    数组开头增加 功能:给数组开头增加一个或多个 参数:一个或多个 返回值:数组的长度 原数组发生改变 shift( )        数组开头删除一项 功能:给数组开头删除一个 参数:无 返回值:被删除的内容 原数组发生改变 push( )       数组末尾增加 功能:给数组末尾增加一项或多项 参数:一

常用的生成客户端脚本方法

常用的生成客户端脚本方法: RegisterArraryDeclaration -- 添加javascript数组     RegisterClientScriptBlock-- 在 Web 窗体的开始处(紧接着 <form runat="server"> 标识之后)    RegiserStartScript-- ------- 在</form>前添加script代码块 RegisterStartupScript-- 在 Web 窗体的结尾处    Regis

[C/C++基础] C语言常用函数memset的使用方法

函数声明:void *memset(void *s, int ch, size_t n); 用途:为一段内存的每一个字节都赋予ch所代表的值,该值采用ASCII编码. 所属函数库:<memory.h> 或者 <string.h> 参数:(1)s,开始内存的地址:(2)ch和n,从地址s开始,在之后的n字节长度内,把每一个字节的值都赋值为n. 使用举例: 代码如下 编译运行结果 说明: 该函数最常用的用途就是将一段新分配的内存初始化为0.例如我们代码的第9-10行. 需要注意的是,函

转载-常用的清除浮动的方法有以下三种:

常用的清除浮动的方法有以下三种: 此为未清除浮动源代码,运行代码无法查看到父级元素浅黄色背景. <style type=”text/css”> <!– *{margin:0;padding:0;} body{font:36px bold; color:#F00; text-align:center;} #layout{background:#FF9;} #left{float:left;width:20%;height:200px;background:#DDD;line-height: