学习笔记GAN004:DCGAN main.py

Scipy 高端科学计算:http://blog.chinaunix.net/uid-21633169-id-4437868.html

import os #引用操作系统函数文件
import scipy.misc #引用scipy包misc模块 图像形式存取数组
import numpy as np #引用numpy包 矩阵计算
from model import DCGAN #引用model文件DCGAN类
from utils import pp, visualize, to_json, show_all_variables #引用utils文件pp对象,visualize, to_json, show_all_variables方法
import tensorflow as tf #引用tensorflow
flags = tf.app.flags #接受命令行传递参数,相当于接受argv。第一个是参数名称,第二个参数是默认值,第三个是参数描述
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]") #训练轮数 25
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]") #adam优化器 学习速率 0.0002
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]") #adam优化器 动量(参数移动平均数) 0.5
flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]") #训练画像尺寸,默认无限大正数
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]") #图像批大小 64
flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]") #输入图像高度 108 均衡的缩放图像(保持图像原始比例),使图片的两个坐标(宽、高)都大于等于 相应的视图坐标(负的内边距)。图像则位于视图的中央。
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]") #输入图像宽度,None与高度相同
flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]") #输出图像高度 64
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]") #输出图像宽度,None与高度相同
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]") #数据集名称 celebA mnist lsun
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]") #图片文件名的搜索扩展名
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]") #检查点目录名
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]") #图片样本保存目录名
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]") #训练流程开关
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]") #训练流程开关
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]") #可视化开关
FLAGS = flags.FLAGS
def main(_): #主程序
  pp.pprint(flags.FLAGS.__flags) #打印命令行参数
  if FLAGS.input_width is None: #如果没有配置输入图像宽度
    FLAGS.input_width = FLAGS.input_height #把输入图像高度作为宽度
  if FLAGS.output_width is None: #如果没有配置输出图像宽度
    FLAGS.output_width = FLAGS.output_height #把输出图像高度作为宽度
  if not os.path.exists(FLAGS.checkpoint_dir): #如果检查点目录不存在
    os.makedirs(FLAGS.checkpoint_dir) #创建检查点目录
  if not os.path.exists(FLAGS.sample_dir): #如果样本目录不存在
    os.makedirs(FLAGS.sample_dir) #创建样本目录
  #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333) #设置GPU显存占用比例
  run_config = tf.ConfigProto() #获取配置对象
  run_config.gpu_options.allow_growth = True #GPU显存占用按需增加
  with tf.Session(config=run_config) as sess: #指定配置构建会话
    if FLAGS.dataset == ‘mnist‘: #如果指定数据集为mnist
      dcgan = DCGAN( #构建DCGAN
          sess, #提定会话
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          y_dim=10, #标签维度为10
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir)
    else:
      dcgan = DCGAN( #构建DCGAN,不指定标签维度
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir)
    show_all_variables() #显示所有参数
    if FLAGS.train: #如果是训练
      dcgan.train(FLAGS) #指定参数执行构建DCGAN 训练方法
    else: #如果是测试
      if not dcgan.load(FLAGS.checkpoint_dir)[0]: #在检查点目录没有检查点文件,即没有已训练好的模型
        raise Exception("[!] Train a model first, then run test mode") #抛出异常:请先训练模型再执行测试
      
    # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0], #JSON格式化:w,b,gbn
    #                 [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
    #                 [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
    #                 [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
    #                 [dcgan.h4_w, dcgan.h4_b, None])
    # Below is codes for visualization
    OPTION = 1
    visualize(sess, dcgan, FLAGS, OPTION) #执行可视化方法,传入会话、DCGAN、配置参数,选项
if __name__ == ‘__main__‘: #如果直接执行本脚本文件,运行以下代码,一般作调试用。如果作为其它脚本模块引入,则不执行以下代码
  tf.app.run() #运行APP.run 解析FLAGS,执行main方法

欢迎付费咨询(150元每小时),我的微信:qingxingfengzi

我创建GAN日报群,以每天各报各的进度为主。把正在研究GAN的人聚在一起,互相鼓励,一起前进。加我微信拉群,请注明:加入GAN日报群。

时间: 2024-10-12 20:26:22

学习笔记GAN004:DCGAN main.py的相关文章

Hadoop源码学习笔记(2) ——进入main函数打印包信息

Hadoop源码学习笔记(2) ——进入main函数打印包信息 找到了main函数,也建立了快速启动的方法,然后我们就进去看一看. 进入NameNode和DataNode的主函数后,发现形式差不多: public static void main(String args[]) {     try {       StringUtils.startupShutdownMessage(DataNode.class, args, LOG);       DataNode datanode = crea

学习笔记GAN002:DCGAN

Ian J. Goodfellow 论文:https://arxiv.org/abs/1406.2661 两个网络:G(Generator),生成网络,接收随机噪声Z,通过噪声生成样本,G(z).D(Dicriminator),判别网络,判别样本是否真实,输入样本x,输出D(x)代表x真实概率,如果1,100%真实样本,如果0,代表不可能是真实样本. 训练过程,生成网络G尽量生成真实样本欺骗判别网络D,判别网络D尽量把G生成样本和真实样本分别开.理想状态下,G生成样本G(z),使D难以判断真假,

Python学习笔记七:web.py

安装pip: 到github上下载pip:https://github.com/pypa/pip 解压后,在解压出来的文件夹中打开命令行,输入 python setup.py install 安装完毕后,配置系统环境变量:在Path后,添加 python安装目录\Scripts 然后在命令行环境下,输入pip list测试是否安装成功. 最后使用pip安装模块,比如web.py: pip install web.py

Hadoop源码学习笔记(1) ——第二季开始——找到Main函数及读一读Configure类

Hadoop源码学习笔记(1) ——找到Main函数及读一读Configure类 前面在第一季中,我们简单地研究了下Hadoop是什么,怎么用.在这开源的大牛作品的诱惑下,接下来我们要研究一下它是如何实现的. 提前申明,本人是一直搞.net的,对java略为生疏,所以在学习该作品时,会时不时插入对java的学习,到时也会摆一些上来,包括一下设计模式之类的.欢迎高手指正. 整个学习过程,我们主要通过eclipse来学习,之前已经讲过如何在eclipse中搭建调试环境,这里就不多述了. 在之前源码初

《LINUX内核设计的艺术》第一章从开机家电到执行main函数之前的过程 学习笔记之一

从开机加电到实行main函数之前的过程 分为三步,目的是实现从启动盘加载操作系统程序,完成实现main函数的准备工作 启动BLOS,准备是模式下的中断向量表和中断服务程序 从启动盘加载操作系统程序到内存.加载操作系统程序就是靠第一步实现的 为实现32位的main函数做过度工作 1.1启动blos,准备实模式下的中断向量表和中断服务程序 由blos来加载软件操作系统的任务 1.1.1         BLOS的启动原理 0XFFFF0 由硬件来启动,CPU硬件设计逻辑设计为加电瞬间就强行将CS的值

cocos2dx游戏开发学习笔记2-从helloworld开始

一.新建工程 具体安装和新建工程的方法在cocos2dx目录下的README.md文件中已经有详细说明,这里只做简单介绍. 1.上官网下载cocos2dx-3.0的源码,http://www.cocos2d-x.org/ 2.安装python2.7 3.运行setup.py安装 4.执行cocos new helloworld -p helloworld -l cpp,生成新工程 二.新建工程中包含的东西 -Classes AppDelegate.cpp      -----游戏真正开始执行的地

cocos2dx学习笔记(2)

昨天尝试了cocos2dx在win下的开发环境配置,并且运行了cocos的helloword程序,晚上想要尝试一下android开发环境配置,顺便学习cocos在eclipse下的JNI机制,按照cocoa中文论坛的android环境配置弄了NDK,并配置了环境变量,由于想要学习cocos的luabind机制(这个我们公司游戏的引擎用的很多,确实比较有兴趣),一切搞定不明就里的用eclipse导入了cocos3.0rc中的tests目录下的cpp-tests工程(这算android开发久了的毛病

[学习笔记] Python标准库简明教程 [转]

1 操作系统接口 os 模块提供了一系列与系统交互的模块: >>> os.getcwd() # Return the current working directory '/home/minix/Documents/Note/Programming/python/lib1' >>> os.chdir('~/python') # Change current working directory Traceback (most recent call last): File

C51学习笔记

转自:http://blog.csdn.net/gongyuan073/article/details/7856878 单片机C51学习笔记 一,   C51内存结构深度剖析 二,   reg51.头文件剖析 三,   浅淡变量类型及其作用域 四,   C51常用头文件 五,   浅谈中断 六,   C51编译器的限制 七,                        小淡C51指针 八,                        预处理命令