【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像

前言

深度学习作为人工智能的重要手段,迎来了爆发,在NLP、CV、物联网、无人机等多个领域都发挥了非常重要的作用。最近几年,各种深度学习算法层出不穷, Generative Adverarial Network(GAN)自2014年提出以来,引起广泛关注,身为深度学习三巨头之一的Yan Lecun对GAN的评价颇高,认为GAN是近年来在深度学习上最大的突破,是近十年来机器学习上最有意思的工作。围绕GAN的论文数量也迅速增多,各种版本的GAN出现,主要在CV领域带来了一些贡献,如下图所示。

我们可以利用GAN生成一些我们需要的图像或者文本,比如二次元头像。

GAN简介

GAN主要的应用是自动生成一些东西,包括图像和文本等,比如随机给一个向量作为输入,通过GAN的Generator生成一张图片,或者生成一串语句。Conditional GAN的应用更多一些,比如数据集是一段文字和图像的数据对,通过训练,GAN可以通过给定一段文字生成对应的图像。

GAN主要可以分为Generator(生成器)和Discriminator(判别器)两个部分,其中Generator其实就是一个神经网络,输入一个向量,可以输出一张图像(即一个高维的向量表示),如下图示。

?Discriminator也是一个神经网络,输入为一张图像,输出为一个数值,输出的数值用于判断输入的图像是否是真的,数值越大,说明图像是真的,数值越小,说明图像为假的,如下图示。

?Generator负责生成图像,Discriminator负责对Generator生成的图像和真实图像去进行对比,区别出真假,Generator需要不断优化来欺骗Discriminator,以假乱真;而Discriminator也不断优化,来提高识别能力,能够识别出Generator的把戏。二者的这种关系可以形象地通过下图展示。

Generator和Discriminator连接起来,形成一个比较大的深层网络,即为GAN网络。

场景描述

深度学习的各种算法在PAI上可以通过PAI-DSW进行实现,在PAI-DSW上进行训练数据,利用GAN自动生成二次元头像。

数据准备

首先需要准备真实的二次元头像作为数据集,这里从网上找到一些共享的资源,存储在了钉钉钉盘中,钉盘地址 ,提取密码: c2pz,数据集如下图示,约5万多张:

算法实践

利用PAI-DSW进行GAN算法实践,首先需要安装准备好环境。

首先进入到Notebook建模,创建新实例,之后打开实例,进入Terminal,在Terminal下用户可以像在自己本地一样安装相应的依赖包,进行操作。

准备好环境之后,我们可以通过如下图示方法,将基于Tensorflow的DCGAN代码和数据集上传上去。 ?

用于训练的DCGAN代码地址:https://github.com/carpedm20/DCGAN-tensorflow,关于DCGAN的网络框架图如下,详细介绍可以参考论文:https://arxiv.org/abs/1511.06434,这里我们不做详述。

数据集和代码上传成功,如下图示。

其中,data目录下的faces即为数据集,该文件夹下为对应的5万多张真实二次元头像。DCGAN-tensorflow为整个代码路径,其中最主要的两个代码文件是main.py和model.py,其中最主要的核心代码如下。

def main(_):
  pp.pprint(flags.FLAGS.__flags)

  if FLAGS.input_width is None:
    FLAGS.input_width = FLAGS.input_height
  if FLAGS.output_width is None:
    FLAGS.output_width = FLAGS.output_height

  if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)

  #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
  run_config = tf.ConfigProto()
  run_config.gpu_options.allow_growth=True

  with tf.Session(config=run_config) as sess:
    if FLAGS.dataset == ‘mnist‘:
      dcgan = DCGAN(
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          y_dim=10,
          z_dim=FLAGS.generate_test_images,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir,
          data_dir=FLAGS.data_dir)
    else:
      dcgan = DCGAN(
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          z_dim=FLAGS.generate_test_images,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir,
          data_dir=FLAGS.data_dir)

    show_all_variables()

    if FLAGS.train:
      dcgan.train(FLAGS)

        else:
          # Update D network
          _, summary_str = self.sess.run([d_optim, self.d_sum],
            feed_dict={ self.inputs: batch_images, self.z: batch_z })
          self.writer.add_summary(summary_str, counter)

          # Update G network
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z })
          self.writer.add_summary(summary_str, counter)

          # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z })
          self.writer.add_summary(summary_str, counter)

          errD_fake = self.d_loss_fake.eval({ self.z: batch_z })
          errD_real = self.d_loss_real.eval({ self.inputs: batch_images })
          errG = self.g_loss.eval({self.z: batch_z})

一切就绪之后,我们执行命令进行训练,调用命令如下:

?python main.py --input_height 96 --input_width 96 --output_height 48 --output_width 48 --dataset faces --crop --train --epoch 300 --input_fname_pattern "*.jpg"

其中,参数dateset指定数据集的目录,epoch指定循环迭代的次数,input_height、input_width用于指定输入文件的大小,输出文件的大小同样也需要参数设定,代码执行过程如下图示:?

?

我们来看下执行结果,分别看一下epoch为1,30,100的时候生成的二次元头像效果图。

epoch=1

epoch=30

epoch=100?

我们发现,随着不断迭代,生成的二次元头像也越来越逼真。

总结

通过上面的实践,我们领略到了GAN的魅力,GAN的变种有很多,除此之外我们还可以利用GAN做非常多的有意思的事情,比如通过文字生成图像,通过简单文字生成宣传海报等。PAI-DSW像是一个练武场,为我们准备好了深度学习所需要的环境和条件,让我们可以尽情享受大数据和深度学习的乐趣,除了GAN,像比较火热的Bert等模型,我们也都可以试一试。



本文作者:不等_赵振才

原文链接

本文为云栖社区原创内容,未经允许不得转载。

原文地址:https://www.cnblogs.com/zhaowei121/p/10601025.html

时间: 2024-08-09 02:20:46

【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像的相关文章

七色花基本权限系统(5)- 实体配置的使用和利用T4自动生成实体配置

在前面的章节里,用户表的结构非常简单,没有控制如何映射到数据库.通常,需要对字段的长度.是否可为空甚至特定数据类型进行设置,因为EntityFramework的默认映射规则相对而言比较简单和通用.在这篇日志里,将演示如何对数据实体进行映射配置,并利用T4模板自动创建映射配置类文件. 配置方式 EntityFramework的实体映射配置有2种. 第一种是直接在实体类中以特性的方式进行控制,这些特性有部分是EF实现的,也有部分是非EF实现的.也就是说,在数据实体层不引用EF的情况下,只能使用不全的

[06] 利用mybatis-generator自动生成代码

1.mybatis-generator 概述 MyBatis官方提供了逆向工程 mybatis-generator,可以针对数据库表自动生成MyBatis执行所需要的代码(如Mapper.java.Mapper.xml.POJO).mybatis-generator 有三种用法:命令行.eclipse插件.maven插件.而maven插件的方式比较通用,本文也将概述maven插件的使用方式. 2.pom.xml中配置plugin (官方文档:Running MyBatis Generator W

Spring boot入门(三):SpringBoot集成结合AdminLTE(Freemarker),利用generate自动生成代码,利用DataTable和PageHelper进行分页显示

关于SpringBoot和PageHelper,前篇博客已经介绍过Spring boot入门(二):Spring boot集成MySql,Mybatis和PageHelper插件,前篇博客大致讲述了SpringBoot如何集成Mybatis和Pagehelper,但是没有做出实际的范例,本篇博客是连接上一篇写的.通过AdminLTE前端框架,利用DataTable和PageHelper进行分页显示,通过对用户列表的增删改查操作,演示DataTable和PageHelper的使用. (1)Admi

【机器学习PAI实战】—— 玩转人工智能之美食推荐

前言 在生活中,我们经常给朋友推荐一些自己喜欢的东西,也时常接受别人的推荐.怎么能保证推荐的电影或者美食就是朋友喜欢的呢?一般来说,你们两个人经常对同一个电影或者美食感兴趣,那么你喜欢的东西就很大程度上朋友也会比较感兴趣.在大数据的背景下,算法会帮我寻找兴趣相似的那些人,并关注他们喜欢的东西,以此来给我们推荐可能喜欢的事物. 场景描述 某外卖店铺收集了一些用户对本店铺美食的评价和推荐分,并计划为一些新老客户推荐他们未曾尝试的美食. 数据分析 A B C D E F G H I J K 0[0,

利用mybatis-generator自动生成代码

mybatis-generator有三种用法:命令行.eclipse插件.maven插件.个人觉得maven插件最方便,可以在eclipse/intellij idea等ide上可以通用. 下面是从官网上的截图:(不过官网www.mybatis.org 最近一段时间,好象已经挂了) 一.在pom.xml中添加plugin 1 <plugin> 2 <groupId>org.mybatis.generator</groupId> 3 <artifactId>m

利用Python自动生成暴力破解的字典

Python是一款非常强大的语言.用于测试时它非常有效,因此Python越来越受到欢迎. 因此,在此次教程中我将聊一聊如何在Python中生成字典,并将它用于任何你想要的用途. 前提要求 1,Python 2.7(对于Python 3.x的版本基本相同,你只需要做一些微小调整) 2,Peace of mine(作者开的一个玩笑,这是一首歌名) 如果你用virtualenv搭建Python开发环境,请确保已经安装了itertools.因为我们将会用到itertools生成字典.我们将一步一步地演示

【转】利用mybatis-generator自动生成代码

本文转自:http://www.cnblogs.com/yjmyzz/p/4210554.html mybatis-generator有三种用法:命令行.eclipse插件.maven插件.个人觉得maven插件最方便,可以在eclipse/intellij idea等ide上可以通用. 下面是从官网上的截图:(不过官网www.mybatis.org 最近一段时间,好象已经挂了) 一.在pom.xml中添加plugin <plugin> <groupId>org.mybatis.g

利用mybatis-generator自动生成表实例类和映射文件

我们经常用到mybatis来进行程序代码级别对数据库的操作,然而需要编写大量的表实例类与映射文件,现在使用工具mybatis-generator就可实现上述文件的自动生成,下面简要介绍一下其使用方法. 1.创建工程 为了下载jar包比较方便,本人创建一个名为mybatis的maven工程来应用mybatis-generator. 2.修改pom.xml文件,下载依赖的jar包 <project xmlns="http://maven.apache.org/POM/4.0.0" x

mybatis利用maven自动生成mapper、xml、domain

第一种方式: 配置maven插件 在src/main/resources下新建generatorConfig.xml   内容如下: <?xml version="1.0" encoding="UTF-8"?> <!DOCTYPE generatorConfiguration PUBLIC "-//mybatis.org//DTD MyBatis Generator Configuration 1.0//EN" "ht