DualGAN implementation (using TensorFlow)

Full implementation: https://github.com/codehxj/DualGAN

Below is the code, adapted into a notebook version.

import sys
sys.path.append("/home/hxj/anaconda3/lib/python3.6/site-packages")  # machine-specific path; adjust or remove for your environment
from __future__ import division
import math
import json
import random
import pprint
import scipy.misc
import numpy as np
import os
from time import gmtime, strftime

#pp = pprint.PrettyPrinter()

#get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])

def load_data(image_path, flip=False, is_test=False, image_size = 128):
    img = load_image(image_path)
    img = preprocess_img(img, img_size=image_size, flip=flip, is_test=is_test)

    img = img/127.5 - 1.  # scale pixels from [0, 255] to [-1, 1], matching the tanh generator output
    if len(img.shape)<3:
        img = np.expand_dims(img, axis=2)
    return img

def load_image(image_path):
    img = imread(image_path)
    return img

def preprocess_img(img, img_size=128, flip=False, is_test=False):
    img = scipy.misc.imresize(img, [img_size, img_size])
    if (not is_test) and flip and np.random.random() > 0.5:
        img = np.fliplr(img)
    return img

def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale=False):
    # unused helper; note that center_crop (called via transform) is not defined in this notebook
    return transform(imread(image_path, is_grayscale), image_size, is_crop, resize_w)

def save_images(images, size, image_path):
    img_dir = os.path.dirname(image_path)
    if not os.path.exists(img_dir):
        os.makedirs(img_dir)
    return imsave(inverse_transform(images), size, image_path)

def imread(path, is_grayscale = False):
    if (is_grayscale):
        return scipy.misc.imread(path, flatten = True)#.astype(np.float)
    else:
        return scipy.misc.imread(path)#.astype(np.float)

def merge_images(images, size):
    return inverse_transform(images)

def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    if len(images.shape) < 4:
        img = np.zeros((h * size[0], w * size[1], 1))
        images = np.expand_dims(images, axis = 3)
    else:
        img = np.zeros((h * size[0], w * size[1], images.shape[3]))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[j*h:j*h+h, i*w:i*w+w, :] = image
    if images.shape[3] ==1:
        return np.concatenate([img,img,img],axis=2)
    else:
        return img.astype(np.uint8)

def imsave(images, size, path):
    return scipy.misc.imsave(path, merge(images, size))

def transform(image, npx=64, is_crop=True, resize_w=64):
    # npx : # of pixels width/height of image
    if is_crop:
        cropped_image = center_crop(image, npx, resize_w=resize_w)
    else:
        cropped_image = image
    return np.array(cropped_image)/127.5 - 1.

def inverse_transform(images):
    return ((images+1.)*127.5)
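Note: scipy.misc.imread, imresize, and imsave were deprecated in SciPy 1.0 and removed in SciPy 1.3, so the helpers above fail on a modern install. A minimal drop-in sketch using Pillow (the pil_* names are illustrative, not part of the original code) that can be swapped in for the scipy.misc calls:

from PIL import Image
import numpy as np

def pil_imread(path, flatten=False):
    img = Image.open(path)
    if flatten:
        img = img.convert('F')  # 32-bit float grayscale, like scipy's flatten=True
    return np.asarray(img)

def pil_imresize(img, size):
    # size is [height, width]; PIL's resize expects (width, height)
    return np.asarray(Image.fromarray(np.uint8(img)).resize((size[1], size[0])))

def pil_imsave(path, img):
    Image.fromarray(np.uint8(np.clip(img, 0, 255))).save(path)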
import tensorflow as tf
from tensorflow.python.framework import ops
#from utils import *  # not needed in the notebook; the utils functions are defined above
def batch_norm(x,  name="batch_norm"):
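    # Despite the name, this computes per-sample moments over the spatial axes
    # (axes [1, 2]), i.e. instance normalization rather than true batch norm.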
    eps = 1e-6
    with tf.variable_scope(name):
        nchannels = x.get_shape()[3]
        scale = tf.get_variable("scale", [nchannels], initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32))
        center = tf.get_variable("center", [nchannels], initializer=tf.constant_initializer(0.0, dtype = tf.float32))
        ave, dev = tf.nn.moments(x, axes=[1,2], keep_dims=True)
        inv_dev = tf.rsqrt(dev + eps)
        normalized = (x-ave)*inv_dev * scale + center
        return normalized

def conv2d(input_, output_dim,
           k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
           name="conv2d"):
    with tf.variable_scope(name):
        w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
                            initializer=tf.truncated_normal_initializer(stddev=stddev))
        conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')

        biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())

        return conv

def deconv2d(input_, output_shape,
             k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
             name="deconv2d", with_w=False):
    with tf.variable_scope(name):
        # filter : [height, width, output_channels, in_channels]
        w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
                            initializer=tf.random_normal_initializer(stddev=stddev))
        try:
            deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
                                strides=[1, d_h, d_w, 1])

        # Support for versions of TensorFlow before 0.7.0
        except AttributeError:
            deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
                                strides=[1, d_h, d_w, 1])

        biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
        deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())

        if with_w:
            return deconv, w, biases
        else:
            return deconv

def lrelu(x, leak=0.2, name="lrelu"):
    return tf.maximum(x, leak*x)

def celoss(logits, labels):
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))
       
from __future__ import division
import os
import time
from glob import glob
import tensorflow as tf
import numpy as np
from six.moves import xrange

#from ops import *
#from utils import *

class DualNet(object):
    def __init__(self, sess, image_size=256, batch_size=1, fcn_filter_dim=64,
                 A_channels=3, B_channels=3, dataset_name='',
                 checkpoint_dir=None, lambda_A=20., lambda_B=20.,
                 sample_dir=None, loss_metric='L1', flip=False):
        self.df_dim = fcn_filter_dim
        self.flip = flip
        self.lambda_A = lambda_A
        self.lambda_B = lambda_B

        self.sess = sess
        self.is_grayscale_A = (A_channels == 1)
        self.is_grayscale_B = (B_channels == 1)
        self.batch_size = batch_size
        self.image_size = image_size
        self.fcn_filter_dim = fcn_filter_dim
        self.A_channels = A_channels
        self.B_channels = B_channels
        self.loss_metric = loss_metric

        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir

        #directory name for output and logs saving
        self.dir_name = "%s-img_sz_%s-fltr_dim_%d-%s-lambda_AB_%s_%s" % (
                    self.dataset_name,
                    self.image_size,
                    self.fcn_filter_dim,
                    self.loss_metric,
                    self.lambda_A,
                    self.lambda_B
                )
        self.build_model()

    def build_model(self):
    ###    define place holders
        self.real_A = tf.placeholder(tf.float32, [self.batch_size, self.image_size, self.image_size,
                                         self.A_channels], name='real_A')
        self.real_B = tf.placeholder(tf.float32, [self.batch_size, self.image_size, self.image_size,
                                         self.B_channels], name='real_B')

    ###  define graphs
        self.A2B = self.A_g_net(self.real_A, reuse = False)
        self.B2A = self.B_g_net(self.real_B, reuse = False)
        self.A2B2A = self.B_g_net(self.A2B, reuse = True)
        self.B2A2B = self.A_g_net(self.B2A, reuse = True)

        if self.loss_metric == 'L1':
            self.A_loss = tf.reduce_mean(tf.abs(self.A2B2A - self.real_A))
            self.B_loss = tf.reduce_mean(tf.abs(self.B2A2B - self.real_B))
        elif self.loss_metric == 'L2':
            self.A_loss = tf.reduce_mean(tf.square(self.A2B2A - self.real_A))
            self.B_loss = tf.reduce_mean(tf.square(self.B2A2B - self.real_B))

        self.Ad_logits_fake = self.A_d_net(self.A2B, reuse = False)
        self.Ad_logits_real = self.A_d_net(self.real_B, reuse = True)
        self.Ad_loss_real = celoss(self.Ad_logits_real, tf.ones_like(self.Ad_logits_real))
        self.Ad_loss_fake = celoss(self.Ad_logits_fake, tf.zeros_like(self.Ad_logits_fake))
        self.Ad_loss = self.Ad_loss_fake + self.Ad_loss_real
        self.Ag_loss = celoss(self.Ad_logits_fake, labels=tf.ones_like(self.Ad_logits_fake))+self.lambda_B * (self.B_loss )

        self.Bd_logits_fake = self.B_d_net(self.B2A, reuse = False)
        self.Bd_logits_real = self.B_d_net(self.real_A, reuse = True)
        self.Bd_loss_real = celoss(self.Bd_logits_real, tf.ones_like(self.Bd_logits_real))
        self.Bd_loss_fake = celoss(self.Bd_logits_fake, tf.zeros_like(self.Bd_logits_fake))
        self.Bd_loss = self.Bd_loss_fake + self.Bd_loss_real
        self.Bg_loss = celoss(self.Bd_logits_fake, tf.ones_like(self.Bd_logits_fake))+self.lambda_A * (self.A_loss)

        self.d_loss = self.Ad_loss + self.Bd_loss
        self.g_loss = self.Ag_loss + self.Bg_loss
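        # Together these form the DualGAN objective: each generator minimizes an
        # adversarial term plus a weighted round-trip reconstruction loss
        # (real_A -> A2B -> A2B2A and real_B -> B2A -> B2A2B).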
        ## define trainable variables
        t_vars = tf.trainable_variables()
        self.A_d_vars = [var for var in t_vars if 'A_d_' in var.name]
        self.B_d_vars = [var for var in t_vars if 'B_d_' in var.name]
        self.A_g_vars = [var for var in t_vars if 'A_g_' in var.name]
        self.B_g_vars = [var for var in t_vars if 'B_g_' in var.name]
        self.d_vars = self.A_d_vars + self.B_d_vars
        self.g_vars = self.A_g_vars + self.B_g_vars
        self.saver = tf.train.Saver()

    def clip_trainable_vars(self, var_list):
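        # Note: self.c (the clip bound) is never set in __init__; this WGAN-style
        # weight-clipping helper is unused by train() as written.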
        for var in var_list:
            self.sess.run(var.assign(tf.clip_by_value(var, -self.c, self.c)))

    def load_random_samples(self):
        sample_files = np.random.choice(glob('./datasets/{}/val/A/*.jpg'.format(self.dataset_name)), self.batch_size)
        sample_A_imgs = [load_data(f, image_size=self.image_size, flip=False) for f in sample_files]

        sample_files = np.random.choice(glob('./datasets/{}/val/B/*.jpg'.format(self.dataset_name)), self.batch_size)
        sample_B_imgs = [load_data(f, image_size=self.image_size, flip=False) for f in sample_files]
        sample_B_imgs = [load_data(f, image_size =self.image_size, flip = False) for f in sample_files]

        sample_A_imgs = np.reshape(np.array(sample_A_imgs).astype(np.float32),(self.batch_size,self.image_size, self.image_size,-1))
        sample_B_imgs = np.reshape(np.array(sample_B_imgs).astype(np.float32),(self.batch_size,self.image_size, self.image_size,-1))
        return sample_A_imgs, sample_B_imgs

    def sample_shortcut(self, sample_dir, epoch_idx, batch_idx):
        sample_A_imgs,sample_B_imgs = self.load_random_samples()

        Ag, A2B2A_imgs, A2B_imgs = self.sess.run([self.A_loss, self.A2B2A, self.A2B], feed_dict={self.real_A: sample_A_imgs, self.real_B: sample_B_imgs})
        Bg, B2A2B_imgs, B2A_imgs = self.sess.run([self.B_loss, self.B2A2B, self.B2A], feed_dict={self.real_A: sample_A_imgs, self.real_B: sample_B_imgs})

        save_images(A2B_imgs, [self.batch_size, 1], './{}/{}/{:06d}_{:04d}_A2B.jpg'.format(sample_dir, self.dir_name, epoch_idx, batch_idx))
        save_images(A2B2A_imgs, [self.batch_size, 1], './{}/{}/{:06d}_{:04d}_A2B2A.jpg'.format(sample_dir, self.dir_name, epoch_idx, batch_idx))

        save_images(B2A_imgs, [self.batch_size, 1], './{}/{}/{:06d}_{:04d}_B2A.jpg'.format(sample_dir, self.dir_name, epoch_idx, batch_idx))
        save_images(B2A2B_imgs, [self.batch_size, 1], './{}/{}/{:06d}_{:04d}_B2A2B.jpg'.format(sample_dir, self.dir_name, epoch_idx, batch_idx))

        print("[Sample] A_loss: {:.8f}, B_loss: {:.8f}".format(Ag, Bg))

    def train(self, args):
        """Train Dual GAN"""
        decay = 0.9
        self.d_optim = tf.train.RMSPropOptimizer(args.lr, decay=decay) \
                           .minimize(self.d_loss, var_list=self.d_vars)

        self.g_optim = tf.train.RMSPropOptimizer(args.lr, decay=decay) \
                           .minimize(self.g_loss, var_list=self.g_vars)
        tf.global_variables_initializer().run()

        self.writer = tf.summary.FileWriter("./logs/"+self.dir_name, self.sess.graph)

        step = 1
        start_time = time.time()

        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" Load failed...ignored...")
            print(" start training...")

        for epoch_idx in xrange(args.epoch):
            data_A = glob('./datasets/{}/train/A/*.jpg'.format(self.dataset_name))
            data_B = glob('./datasets/{}/train/B/*.jpg'.format(self.dataset_name))
            np.random.shuffle(data_A)
            np.random.shuffle(data_B)
            epoch_size = min(len(data_A), len(data_B)) // self.batch_size
            print('[*] training data loaded successfully')
            print("#data_A: %d  #data_B: %d" % (len(data_A), len(data_B)))
            print('[*] running optimizer...')

            for batch_idx in xrange(0, epoch_size):
                imgA_batch = self.load_training_imgs(data_A, batch_idx)
                imgB_batch = self.load_training_imgs(data_B, batch_idx)

                print("Epoch: [%2d] [%4d/%4d]"%(epoch_idx, batch_idx, epoch_size))
                step = step + 1
                self.run_optim(imgA_batch, imgB_batch, step, start_time)

                if np.mod(step, 100) == 1:
                    self.sample_shortcut(args.sample_dir, epoch_idx, batch_idx)

                if np.mod(step, args.save_freq) == 2:
                    self.save(args.checkpoint_dir, step)

    def load_training_imgs(self, files, idx):
        batch_files = files[idx*self.batch_size:(idx+1)*self.batch_size]
        batch_imgs = [load_data(f, image_size =self.image_size, flip = self.flip) for f in batch_files]

        batch_imgs = np.reshape(np.array(batch_imgs).astype(np.float32),(self.batch_size,self.image_size, self.image_size,-1))

        return batch_imgs

    def run_optim(self,batch_A_imgs, batch_B_imgs,  counter, start_time):
        _, Adfake,Adreal,Bdfake,Bdreal, Ad, Bd = self.sess.run(
            [self.d_optim, self.Ad_loss_fake, self.Ad_loss_real, self.Bd_loss_fake, self.Bd_loss_real, self.Ad_loss, self.Bd_loss],
            feed_dict = {self.real_A: batch_A_imgs, self.real_B: batch_B_imgs})
        _, Ag, Bg, Aloss, Bloss = self.sess.run(
            [self.g_optim, self.Ag_loss, self.Bg_loss, self.A_loss, self.B_loss],
            feed_dict={ self.real_A: batch_A_imgs, self.real_B: batch_B_imgs})

        # The generator step is run a second time for every discriminator step,
        # a common balancing trick so the discriminator does not dominate.
        _, Ag, Bg, Aloss, Bloss = self.sess.run(
            [self.g_optim, self.Ag_loss, self.Bg_loss, self.A_loss, self.B_loss],
            feed_dict={self.real_A: batch_A_imgs, self.real_B: batch_B_imgs})

        print("time: %4.4f, Ad: %.2f, Ag: %.2f, Bd: %.2f, Bg: %.2f,  U_diff: %.5f, V_diff: %.5f"                     % (time.time() - start_time, Ad,Ag,Bd,Bg, Aloss, Bloss))
        print("Ad_fake: %.2f, Ad_real: %.2f, Bd_fake: %.2f, Bg_real: %.2f" % (Adfake,Adreal,Bdfake,Bdreal))

    def A_d_net(self, imgs, y=None, reuse=False):
        return self.discriminator(imgs, prefix='A_d_', reuse=reuse)

    def B_d_net(self, imgs, y=None, reuse=False):
        return self.discriminator(imgs, prefix='B_d_', reuse=reuse)

    def discriminator(self, image, y=None, prefix='A_d_', reuse=False):
        # image is image_size x image_size x channels
        with tf.variable_scope(tf.get_variable_scope()) as scope:
            if reuse:
                scope.reuse_variables()
            else:
                assert scope.reuse == False

            h0 = lrelu(conv2d(image, self.df_dim, name=prefix+'h0_conv'))
            # h0 is (128 x 128 x self.df_dim)
            h1 = lrelu(batch_norm(conv2d(h0, self.df_dim*2, name=prefix+'h1_conv'), name=prefix+'bn1'))
            # h1 is (64 x 64 x self.df_dim*2)
            h2 = lrelu(batch_norm(conv2d(h1, self.df_dim*4, name=prefix+'h2_conv'), name=prefix+'bn2'))
            # h2 is (32 x 32 x self.df_dim*4)
            h3 = lrelu(batch_norm(conv2d(h2, self.df_dim*8, d_h=1, d_w=1, name=prefix+'h3_conv'), name=prefix+'bn3'))
            # h3 is (32 x 32 x self.df_dim*8)
            h4 = conv2d(h3, 1, d_h=1, d_w=1, name=prefix+'h4')
            return h4

    def A_g_net(self, imgs, reuse=False):
        return self.fcn(imgs, prefix='A_g_', reuse=reuse)

    def B_g_net(self, imgs, reuse=False):
        return self.fcn(imgs, prefix='B_g_', reuse=reuse)

    def fcn(self, imgs, prefix=None, reuse = False):
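        # U-Net style encoder-decoder with skip connections; the shape
        # comments below assume image_size = 256.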
        with tf.variable_scope(tf.get_variable_scope()) as scope:
            if reuse:
                scope.reuse_variables()
            else:
                assert scope.reuse == False

            s = self.image_size
            s2, s4, s8, s16, s32, s64, s128 = int(s/2), int(s/4), int(s/8), int(s/16), int(s/32), int(s/64), int(s/128)

            # imgs is (256 x 256 x input_c_dim)
            e1 = conv2d(imgs, self.fcn_filter_dim, name=prefix+'e1_conv')
            # e1 is (128 x 128 x self.fcn_filter_dim)
            e2 = batch_norm(conv2d(lrelu(e1), self.fcn_filter_dim*2, name=prefix+'e2_conv'), name=prefix+'bn_e2')
            # e2 is (64 x 64 x self.fcn_filter_dim*2)
            e3 = batch_norm(conv2d(lrelu(e2), self.fcn_filter_dim*4, name=prefix+'e3_conv'), name=prefix+'bn_e3')
            # e3 is (32 x 32 x self.fcn_filter_dim*4)
            e4 = batch_norm(conv2d(lrelu(e3), self.fcn_filter_dim*8, name=prefix+'e4_conv'), name=prefix+'bn_e4')
            # e4 is (16 x 16 x self.fcn_filter_dim*8)
            e5 = batch_norm(conv2d(lrelu(e4), self.fcn_filter_dim*8, name=prefix+'e5_conv'), name=prefix+'bn_e5')
            # e5 is (8 x 8 x self.fcn_filter_dim*8)
            e6 = batch_norm(conv2d(lrelu(e5), self.fcn_filter_dim*8, name=prefix+'e6_conv'), name=prefix+'bn_e6')
            # e6 is (4 x 4 x self.fcn_filter_dim*8)
            e7 = batch_norm(conv2d(lrelu(e6), self.fcn_filter_dim*8, name=prefix+'e7_conv'), name=prefix+'bn_e7')
            # e7 is (2 x 2 x self.fcn_filter_dim*8)
            e8 = batch_norm(conv2d(lrelu(e7), self.fcn_filter_dim*8, name=prefix+'e8_conv'), name=prefix+'bn_e8')
            # e8 is (1 x 1 x self.fcn_filter_dim*8)

            self.d1, self.d1_w, self.d1_b = deconv2d(tf.nn.relu(e8),
                [self.batch_size, s128, s128, self.fcn_filter_dim*8], name=prefix+'d1', with_w=True)
            d1 = tf.nn.dropout(batch_norm(self.d1, name=prefix+'bn_d1'), 0.5)
            d1 = tf.concat([d1, e7], 3)
            # d1 is (2 x 2 x self.fcn_filter_dim*8*2)

            self.d2, self.d2_w, self.d2_b = deconv2d(tf.nn.relu(d1),
                [self.batch_size, s64, s64, self.fcn_filter_dim*8], name=prefix+'d2', with_w=True)
            d2 = tf.nn.dropout(batch_norm(self.d2, name=prefix+'bn_d2'), 0.5)
            d2 = tf.concat([d2, e6], 3)
            # d2 is (4 x 4 x self.fcn_filter_dim*8*2)

            self.d3, self.d3_w, self.d3_b = deconv2d(tf.nn.relu(d2),
                [self.batch_size, s32, s32, self.fcn_filter_dim*8], name=prefix+'d3', with_w=True)
            d3 = tf.nn.dropout(batch_norm(self.d3, name=prefix+'bn_d3'), 0.5)
            d3 = tf.concat([d3, e5], 3)
            # d3 is (8 x 8 x self.fcn_filter_dim*8*2)

            self.d4, self.d4_w, self.d4_b = deconv2d(tf.nn.relu(d3),
                [self.batch_size, s16, s16, self.fcn_filter_dim*8], name=prefix+'d4', with_w=True)
            d4 = batch_norm(self.d4, name=prefix+'bn_d4')
            d4 = tf.concat([d4, e4], 3)
            # d4 is (16 x 16 x self.fcn_filter_dim*8*2)

            self.d5, self.d5_w, self.d5_b = deconv2d(tf.nn.relu(d4),
                [self.batch_size, s8, s8, self.fcn_filter_dim*4], name=prefix+'d5', with_w=True)
            d5 = batch_norm(self.d5, name=prefix+'bn_d5')
            d5 = tf.concat([d5, e3], 3)
            # d5 is (32 x 32 x self.fcn_filter_dim*4*2)

            self.d6, self.d6_w, self.d6_b = deconv2d(tf.nn.relu(d5),
                [self.batch_size, s4, s4, self.fcn_filter_dim*2], name=prefix+'d6', with_w=True)
            d6 = batch_norm(self.d6, name=prefix+'bn_d6')
            d6 = tf.concat([d6, e2], 3)
            # d6 is (64 x 64 x self.fcn_filter_dim*2*2)

            self.d7, self.d7_w, self.d7_b = deconv2d(tf.nn.relu(d6),
                [self.batch_size, s2, s2, self.fcn_filter_dim], name=prefix+'d7', with_w=True)
            d7 = batch_norm(self.d7, name=prefix+'bn_d7')
            d7 = tf.concat([d7, e1], 3)
            # d7 is (128 x 128 x self.fcn_filter_dim*1*2)

            if prefix == 'B_g_':
                self.d8, self.d8_w, self.d8_b = deconv2d(tf.nn.relu(d7),
                    [self.batch_size, s, s, self.A_channels], name=prefix+'d8', with_w=True)
            elif prefix == 'A_g_':
                self.d8, self.d8_w, self.d8_b = deconv2d(tf.nn.relu(d7),
                    [self.batch_size, s, s, self.B_channels], name=prefix+'d8', with_w=True)
            # d8 is (256 x 256 x output_c_dim)
            return tf.nn.tanh(self.d8)

    def save(self, checkpoint_dir, step):
        model_name = "DualNet.model"
        model_dir = self.dir_name
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

    def load(self, checkpoint_dir):
        print(" [*] Reading checkpoint...")

        model_dir =  self.dir_name
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            return True
        else:
            return False

    def test(self, args):
        """Test DualNet"""
        start_time = time.time()
        tf.global_variables_initializer().run()
        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
            test_dir = './{}/{}'.format(args.test_dir, self.dir_name)
            if not os.path.exists(test_dir):
                os.makedirs(test_dir)
            test_log = open(os.path.join(test_dir, 'evaluation.txt'), 'a')
            test_log.write(self.dir_name)
            self.test_domain(args, test_log, type='A')
            self.test_domain(args, test_log, type='B')
            test_log.close()
        else:
            print(" [!] Load failed, no checkpoint to test")

    def test_domain(self, args, test_log, type='A'):
        test_files = glob('./datasets/{}/val/{}/*.jpg'.format(self.dataset_name, type))
        # load testing input
        print("Loading testing images ...")
        test_imgs = [load_data(f, is_test=True, image_size =self.image_size, flip = args.flip) for f in test_files]
        print("#images loaded: %d"%(len(test_imgs)))
        test_imgs = np.reshape(np.asarray(test_imgs).astype(np.float32),(len(test_files),self.image_size, self.image_size,-1))
        test_imgs = [test_imgs[i*self.batch_size:(i+1)*self.batch_size]
                         for i in xrange(0, len(test_imgs)//self.batch_size)]
        test_imgs = np.asarray(test_imgs)
        test_path = './{}/{}/'.format(args.test_dir, self.dir_name)
        # test input samples
        if type == 'A':
            for i in xrange(0, len(test_files)//self.batch_size):
                filename_o = test_files[i*self.batch_size].split('/')[-1].split('.')[0]
                print(filename_o)
                idx = i+1
                A_imgs = np.reshape(np.array(test_imgs[i]), (self.batch_size,self.image_size, self.image_size,-1))
                print("testing A image %d"%(idx))
                print(A_imgs.shape)
                A2B_imgs, A2B2A_imgs = self.sess.run(
                    [self.A2B, self.A2B2A],
                    feed_dict={self.real_A: A_imgs}
                    )
                save_images(A_imgs, [self.batch_size, 1], test_path+filename_o+'_realA.jpg')
                save_images(A2B_imgs, [self.batch_size, 1], test_path+filename_o+'_A2B.jpg')
                save_images(A2B2A_imgs, [self.batch_size, 1], test_path+filename_o+'_A2B2A.jpg')
        elif type == 'B':
            for i in xrange(0, len(test_files)//self.batch_size):
                filename_o = test_files[i*self.batch_size].split('/')[-1].split('.')[0]
                idx = i+1
                B_imgs = np.reshape(np.array(test_imgs[i]), (self.batch_size,self.image_size, self.image_size,-1))
                print("testing B image %d"%(idx))
                B2A_imgs, B2A2B_imgs = self.sess.run(
                    [self.B2A, self.B2A2B],
                    feed_dict={self.real_B:B_imgs}
                    )
                save_images(B_imgs, [self.batch_size, 1], test_path+filename_o+'_realB.jpg')
                save_images(B2A_imgs, [self.batch_size, 1], test_path+filename_o+'_B2A.jpg')
                save_images(B2A2B_imgs, [self.batch_size, 1], test_path+filename_o+'_B2A2B.jpg')
import argparse
#from model import DualNet
import tensorflow as tf

parser = argparse.ArgumentParser(description='Argument parser')

""" Arguments related to network architecture"""
#parser.add_argument('--network_type', dest='network_type', default='fcn_4', help='fcn_1,fcn_2,fcn_4,fcn_8, fcn_16, fcn_32, fcn_64, fcn_128')
parser.add_argument('--image_size', dest='image_size', type=int, default=256, help='size of input images (applies to both A and B images)')
parser.add_argument('--fcn_filter_dim', dest='fcn_filter_dim', type=int, default=64, help='# of fcn filters in the first conv layer')
parser.add_argument('--A_channels', dest='A_channels', type=int, default=1, help='# of channels of image A')
parser.add_argument('--B_channels', dest='B_channels', type=int, default=1, help='# of channels of image B')

"""Arguments related to run mode"""
parser.add_argument('--phase', dest='phase', default='train', help='train or test')

"""Arguments related to training"""
parser.add_argument('--loss_metric', dest='loss_metric', default='L1', help='L1 or L2')
parser.add_argument('--niter', dest='niter', type=int, default=30, help='# of iterations at the starting learning rate')
parser.add_argument('--lr', dest='lr', type=float, default=0.00005, help='initial learning rate (the model uses RMSProp)')  # 0.0002
parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term (unused with RMSProp)')
parser.add_argument('--flip', dest='flip', type=bool, default=True, help='flip the images for data augmentation')  # note: argparse type=bool treats any non-empty string as True
parser.add_argument('--dataset_name', dest='dataset_name', default='sketch-photo', help='name of the dataset')
parser.add_argument('--epoch', dest='epoch', type=int, default=50, help='# of epochs')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=1, help='# of images in a batch')
parser.add_argument('--lambda_A', dest='lambda_A', type=float, default=20.0, help='weight of the A recovery loss')
parser.add_argument('--lambda_B', dest='lambda_B', type=float, default=20.0, help='weight of the B recovery loss')

"""Arguments related to monitoring and outputs"""
parser.add_argument('--save_freq', dest='save_freq', type=int, default=50, help='save the model every save_freq SGD iterations')
parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', default='./checkpoint', help='models are saved here')
parser.add_argument('--sample_dir', dest='sample_dir', default='./sample', help='samples are saved here')
parser.add_argument('--test_dir', dest='test_dir', default='./test', help='test samples are saved here')

args = parser.parse_args([])  # Key point: in a notebook you must pass []; on the command line use args = parser.parse_args()
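To override settings from inside the notebook, pass the flags as a list; for example (the values here are illustrative):

#args = parser.parse_args(['--dataset_name', 'sketch-photo', '--A_channels', '3', '--B_channels', '3'])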

def main(_):
    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)
    if not os.path.exists(args.sample_dir):
        os.makedirs(args.sample_dir)
    if not os.path.exists(args.test_dir):
        os.makedirs(args.test_dir)

    with tf.Session() as sess:
        model = DualNet(sess, image_size=args.image_size, batch_size=args.batch_size,
                        dataset_name=args.dataset_name, A_channels=args.A_channels,
                        B_channels=args.B_channels, flip=args.flip,  # pass the bool directly; comparing args.flip to the string 'True' was always False
                        checkpoint_dir=args.checkpoint_dir, sample_dir=args.sample_dir,
                        fcn_filter_dim=args.fcn_filter_dim,
                        loss_metric=args.loss_metric, lambda_B=args.lambda_B,
                        lambda_A=args.lambda_A)

        if args.phase == 'train':
            model.train(args)
        else:
            model.test(args)

if __name__ == '__main__':
    tf.app.run()
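Note: in a notebook, __name__ is '__main__', so tf.app.run() executes main() and then calls sys.exit(), which Jupyter reports as a SystemExit error. A minimal workaround is to call main directly instead:

main(None)  # runs train/test according to args.phase, without tf.app.run()'s sys.exit()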

Original post: https://www.cnblogs.com/hxjbc/p/8964208.html
