更快更稳定:这就是Wasserstein GAN

这篇论文介绍了一种名叫 Wasserstein GAN(WGAN)的全新算法,这是一种可替代标准生成对抗网络(GAN)的训练方法。这项研究没有应用传统 GAN 所用的那种 minimax 形式,而是基于一种名为“Wasserstein 距离”的新型距离指标做了某些修改。

这是基于 MLP 生成器的 WGAN(左上图)和 GAN(右上图)生成的样本,很显然,这里 WGAN 的图像质量优于标准 GAN。

简单来说,WGAN 有两个改变。第一个是取出了判别器中的 sigmoid,这是用于计算输出均值之间的差异的。第二个改变是判别器(这篇论文称之为 Critic),这就只是一个函数,其目标是让假数据有较低的预期值,让真实数据有较高的预期值。注意这些输出不再是对数概率,这样这些损失现在就与二元交叉熵无关了。

Wasserstein GAN

近期一些 GAN 论文提出了一些不同的生成对抗训练架构。但是,这些架构的一个共同点是 f-距离(包括 KL-距离、总变差散度(total variation divergence))。f-距离是真实数据分布和生成数据分布之间的密度比 P_r(x)/P_θ(x) 的函数,非常类似于 Jenson-Shannon(JS)距离。

上式是标准 GAN 的目标。在 GAN 的训练过程中,判别器的目标是最大化上述目标(最大值为 0,最小值为负无穷)。GAN 的估计可对应于 JS 距离度量。我们再看看 f-距离。如果两个分布没有显著的重叠,我们又能做什么?如果不能,那么其概率密度比将为零或无穷,而且其对整体概率估计(比如由 (0, z) 点组成的真实数据,其中 z ~ U (0,1))会有巨大的负面影响,于是样本就会从 y=0 到 y=1 沿垂直轴 x=0 均匀分布。但如果该模型生成样本 (θ, z),则其分布根本不会重叠。在这种情况下,会发生梯度消失问题,会使标准 GAN 崩溃。

所以基于这一事实,这篇论文的作者提出使用 Wasserstein 距离,而不是 JS 距离。Wasserstein 距离定义为:

我们可以这样解读这一等式:首先,所有可能的配置都会被选取,假设是 P_r(x) 和 P_g(x)。然后这些点会根据这两个分布来配对。在那之后,它会计算每组配置中配对的平均距离。这里的 inf 可以被视为最小值,这样最后它将从所有可能的配对配置中选择出最小的平均距离。这篇论文提出使用这一距离度量来替代 f-距离,这样它就不再是密度比的函数的。通过这种方式,即使两个分布没有重叠,Wasserstein 距离也仍然可以描述它们相距多远,并且通过这种方式能从根本上解决梯度消失问题。

由于初始的 Wasserstein 距离定义具有难以解决的计算复杂性,所以研究者使用了一种替代定义:

这会导致 Kantorovich-Rubinstein二元性。

值得注意的是,当且仅当 f(x) 的梯度的幅度由 K 在该空间的所有部分设定了上界时,f(x) 是 K-Lipschitz。这篇论文通过将权重限制在一定范围内,使用网络来近似建模 K-Lipschitz。这里的上界可以被视为是一个最大值(二元表达式)。理论上,其目标是寻找到一个 critic 函数,以最大化真实样本均值和伪造样本均值之间余量。

WGAN 算法

上面描述了 Wasserstein 生成对抗网络(WGAN)算法。经过前面的知识介绍之后,这个算法看起来就更简单一些了。总结如下:

  • 更新 Critic n 次迭代,之后更新生成器;
  • 对于 Critic 的每次迭代,基于 Wasserstein 距离更新梯度,然后剪切权重;
  • 使用 RMSProp;
  • 像普通 GAN 那样更新生成器。

下面给出了实现 WGAN 算法的代码示例:

# (1) update Critic Network

for p in netD.parameters():

p.requires_grad = True

netD.zero_grad()

# train with real

real_cpu, _ = data

netD.zero_grad()

batch_size = real_cpu.size(0)

input.data.resize_(real_cpu.size()).copy_(real_cpu)

errD_real = netD(input)

errD_real.backward(one)

# train with fake

noise.data.resize_(batch_size, nz, 1, 1)

noise.data.normal_(0, 1)

fake = netG(noise)

input.data.copy_(fake.data)

errD_fake = netD(input)

errD_fake.backward(mone)

errD = errD_real - errD_fake

optimizerD.step()

# (2) Update G network

for p in netD.parameters():

p.requires_grad = False # to avoid computation

netG.zero_grad()

noise.data.resize_(opt.batchSize, nz, 1, 1)

noise.data.normal_(0, 1)

fake = netG(noise)

errG = netD(fake)

errG.backward(one)

optimizerG.step()

实证实验

研究者使用 Wasserstein GAN 进行了一些定量实验,并且表明相比于标准 GAN,使用 WGAN有显著的实际好处。

他们提到了两个优势:

  • WGAN 的损失表现出了收敛的特性。

如上所示,上图为 WGAN,下图为标准 GAN。对于 WGAN,随着损失快速下降,样本质量也会增长。相比于 WGAN,标准 GAN 算法的误差曲线是不稳定的,甚至会增大。

  • 优化过程的稳定性提升。

上图是使用无批归一化的该算法得到的生成器的结果。左上基于 WGAN 算法,右上基于标准 GAN 算法。标准 GAN 不能学习的地方,WGAN 依然能稳定地生成合理的样本。

分析师简评

这篇论文提出了一种名为 Wasserstein GAN 的新型生成对抗网络。它从理论上向我们说明了已有的 GAN 模型失败的原因以及 WGAN 有效的原因。相比于 DCGAN 等标准 GAN,这篇论文表明即使没有批归一化,WGAN 也能稳定地训练。但也仍然存在一些值得关注的地方。首先,在更新生成器之前他们更新了 critic n 次迭代,这意味着 critic 的迭代次数仍是人工调节的。是否存在优化两者的更好方法呢?第二,WGAN 在非常深度的网络上的泛化情况如何,比如 152 层的残差网络?第三,他们限制了权重的范围以确保 Lipschitz 连续性,但是否存在建模这种情况的方法?最后,生成对抗训练能否用于词预测等 NLP 任务,同时还能保持稳定性?

原文地址:https://www.cnblogs.com/chuangye95/p/10204555.html

时间: 2024-10-13 05:50:13

更快更稳定:这就是Wasserstein GAN的相关文章

让Python跑得更快

点击关注 异步图书,置顶公众号 每天与你分享 IT好书 技术干货 职场知识 Tips 参与文末话题讨论,即有机会获得异步图书一本. Python很容易学.你之所以阅读本文可能是因为你的代码现在能够正确运行,而你希望它能跑得更快.你可以很轻松地修改代码,反复地实现你的想法,你对这一点很满意.但能够轻松实现和代码跑得够快之间的取舍却是一个世人皆知且令人惋惜的现象.而这个问题其实是可以解决的. 有些人想要让顺序执行的过程跑得更快.有些人需要利用多核架构.集群,或者图形处理单元的优势来解决他们的问题.有

《Java程序性能优化:让你的Java程序更快、更稳定》

Java程序性能优化:让你的Java程序更快.更稳定, 卓越网更便宜,不错的书吧

Generative Adversarial Nets[Wasserstein GAN]

本文来自<Wasserstein GAN>,时间线为2017年1月,本文可以算得上是GAN发展的一个里程碑文献了,其解决了以往GAN训练困难,结果不稳定等问题. 1 引言 本文主要思考的是半监督学习.当我们说到学习一个概率分布,人们传统的意思是学习一个概率密度.这通常是通过定义一个参数化概率密度\((P_{\theta})_{\theta\in R^d}\)家族,然后基于收集的数据进行最大似然:如果当前有真实样本\(\{x^{(i)}\}_{i=1}^m\),那么是问题转换成: \[\unde

更快学习 JS 的 6 个简单思维技巧

当人们尝试学习 JavaScript , 或者其他编程技术的时候,常常会遇到同样的挑战: 有些概念容易混淆,特别是当你学习过其他语言的时候. 很难找到学习的时间(有时候是动力). 一旦当你理解了一些东西的时候,却很容易再一次忘记. 可以使用的工具甚多且经常变化,所以不知道从哪里开始入手. 幸运的是,这些挑战最终都可以被战胜.在这篇文章里,我将介绍 6 个思维技巧来帮你更快的学习 JavaScript ,让你成为一个更快乐更多产的程序员. 1.不要让将来的决定阻止你进步 对于很多学习 JavaSc

Dnsmasq安装与配置-搭建本地DNS服务器 更干净更快无广告DNS解析

默认的情况下,我们平时上网用的本地DNS服务器都是使用电信或者联通的,但是这样也导致了不少的问题,首当其冲的就是上网时经常莫名地弹出广告,或者莫名的流量被消耗掉导致网速变慢.其次是部分网站域名不能正常被解析,莫名其妙地打不开,或者时好时坏. 如果碰上不稳定的本地DNS,还可能经常出现无法解析的情况.除了要避免"坏"的DNS的影响,我们还可以利用DNS做些"好"事,例如管理局域网的DNS.给手机App Store加速.纠正错误的DNS解析记录.保证上网更加安全.去掉网

CSS VS JS动画,哪个更快[译]

英文原文:https://davidwalsh.name/css-js-animation 原作者Julian Shapiro是Velocity.js的作者,Velocity.js是一个高效易用的js动画库.在<Javascript网页动画设计>一书中对这个库有很多更具体的剖析,对Velocity及JS动画感兴趣的可以一看. 基于Javascript的动画怎么可能总是和 CSS transition 一样快,甚至更快呢?到底是什么秘密呢?Adobe 和 Google 是怎么做到让他们的富媒体移

【转】实战USB接口手机充电 看3.0/2.0谁更快

原文网址:http://mb.it168.com/a2012/0816/1385/000001385641_all.shtml [IT168 应用]当下,越来越多的电脑都已普及USB 3.0接口,新买的笔记本上,新装的台式机后,你都能发现这个跟过去2.0时代不一样的蓝汪汪的USB接口.那么,同样是给手机充电,USB 3.0和传统的USB 2.0相比,以及不同主板芯片组之间.台式机与笔记本的USB之间.主板I/O面板与主板扩展USB接口,这些USB3.0与2.0接口在充电速度上都有什么不同?这就是

与阿里云整个生态体系共同成长,更快更好的为房地产行业客户提供高价值的服务。

免费开通大数据服务:https://www.aliyun.com/product/odps "最早是新业务要做,但是买服务器来不及,管理员没到位,而且新业务的成本很高,是否能成功也是未知,因此明源决定采用阿里云,等资金和人到位再搬到自己内部.然而就是这种误打误撞,却让明源抓住了一个很好的机会走在了正确的轨道上."--副总裁童继龙"阿里云数加的覆盖面很广,从存储.计算到上层应用,提供了一整套的解决方案,确实起到了马总说的普惠大数据.此外,数加也在不断的迭代,不停的有新产品出现,

CSS vs JS动画:谁更快?

CSS vs JS动画:谁更快? 2016-05-16 前端大全 (点击上方公众号,可快速关注) 英文:Julian Shapiro 译者:MZhou's blog 链接:http://zencode.in/19.CSS-vs-JS动画:谁更快?.html 这篇文章翻译自 Julian Shapiro 的 CSS vs. JS Animation: Which is Faster?.Julian Shapiro 也是 Velocity.js 的创造者.这是一个非常高效.简单易用的JS动画库.他在