以mnist数据训练为例,学习DCGAN(deep convolutional generative adversarial networks)的网络结构。
代码下载地址https://github.com/carpedm20/DCGAN-tensorflow
注1:发现代码中以mnist为训练集的网络和以无标签数据集(以下简称unlabeled_dataset)为训练集的网络不同,结构有别。以下笔记主要针对前者(Generator=3个ReLU+1个Sigmoid,Discriminator=)。
注2:事实上,以unlabeled_dataset为训练集的网络也同原网页中所画的Generator有一些不同(事实上,区别是每个conv层的filter个数被减半了。原因可能为了减少网络参数利于测试者训练?)。除此以外,结构相同。Discriminator的结构与Generator的结构正好对称。
原代码中对于有label的数据集,Generator和Discriminator的网络结构均只有两个卷积层。
预备:batch_size=64, mnist图片大小28*28*c_dim,其中c_dim=1为颜色维数,类别数10;随机输入维数z_dim=100
Generator 部分
step1(获得输入z):服从均匀分布的输入样本z(shape=64,100)与具有one-hot形式的标签y (shape=64,10)级联,整体作为Generator的输入z (shape=64,110)
step2(获得第一个非线性层的输出h0): 通过线性变换将z变换为维数为gfc_dim=1024的数据,对其块归一化之后进行非线性ReLU变换,得到h0 (shape=64,1024);将h0与y级联,整体作为下一层的输入h0 (shape=64,1034)
step3(获得第二个非线性层的输出h1):通过线性变换将h0变换为维数为128*7*7=6272的数据,对其块归一化之后进行非线性ReLU变换,得到h1 (shape=64,6272);将h1进行reshape操作得到h1(shape=64,7,7,128);将h1与yb(yb为y的reshape形式,即最后一维为label维,yb的shape=64,1,1,10)级联,整体作为下一层的输入h1 (shape=64,7,7,139)
step4(获得第三个非线性层的输出h2):通过deconv2d操作,用128个filter,将h1变换为维数为64*14*14*128的数据,对其块归一化之后进行非线性ReLU变换,得到h2 (shape=64,14,14,128);将h2与yb级联,整体作为下一层的输入h2 (shape=64,14,14,139)
step5(获得最终的生成图像generated_image):通过deconv2d操作,用c_dim个filter,将h1变换为维数为64*28*28*1的数据,不做块归一化,进行非线性Sigmoid变换,得到generated_image(shape=64,28,28,1)
Discriminator 部分
step1(获得输入x):真实/生成图像image(shape=64,28,28,1)和yb(shape=64,1,1,10)级联,整体作为Discriminator的输入x(shape=64,28,28,11)
step2(获得第一个卷积层的输出h0):用c_dim+y_dim=1+10=11个大小为5*5*11的filter对输入x进行二维卷积操作,随后进行非线性LeakyReLU变换,得到h0 (shape=64,14,14,11);将h0与标签yb级联,整体作为下一层的输入h0(shape=64,14,14,21)
step3(获得第二个卷积层的输出h1):用df_dim+y_dim=64+10=74个大小为5*5*21的filter对h0进行二维卷积操作,块归一化(这层有)之后进行非线性LeakReLU变换,得到h1(shape=64,7,7,74);对h1进行reshape拉成每个样本对应一维数据,并与标签y (shape=64,10)级联,整体作为下一层的输入h1(shape=64,7*7*74+10=64,3636)
step4(获得第三个卷积层的输出h2):对h1进行线性变换,输出维数为dfc_dim=1024的数据,块归一化(这层有)操作之后进行非线性LeakyReLU变换,得到h2 (shape=64,1024);将h2与y进行级联,整体作为下一层的输入h2(shape=64,1034)
step5(获得最终输出h3):对h2进行线性变换,输出维数为1(用来判断真假)的数据,非线性Sigmoid变换之后得到最终输出h3 (shape=64,1) (注:实际代码中将线性变换之后的结果也进行了输出,用以计算loss)
Loss部分
真实图像和生成图像这两种图像都需要输入Discriminator得到对应的loss,整体作为Discriminator的loss;
而Generator的loss只包含有关生成图像部分;
用Adam训练,每训练两次Generator,才对Discriminator进行一次训练,防止Discriminator的loss的导数为0(无法更新)。