总体来讲，Keras 这个深度学习框架真的很“简易”，这体现在它可参考的文档写得比较详细。不像 Caffe 装完以后只能依靠技术博客，Keras 有自己的官方文档（不过是英文的），这给初学者提供了很大的学习空间。
在此记录一下代码框架的应用笔记（以 SRGAN 为例）：
class VGGNetwork:
    """Outline of the VGG network that supplies the SRGAN's VGG loss.

    Method bodies are omitted in this outline; each stub documents what the
    original returns.
    """

    def append_vgg_network(self, x_in, true_X_input):
        """Append a VGG network on top of ``x_in``.

        Returns ``x``, the output of the VGG network.
        Body omitted in this outline.
        """
        raise NotImplementedError("append_vgg_network is omitted in this outline")

    def load_vgg_weight(self, model):
        """Presumably load VGG weights into ``model``; returns ``model``.

        Body omitted in this outline.
        """
        raise NotImplementedError("load_vgg_weight is omitted in this outline")


class DiscriminatorNetwork:
    """Outline of the SRGAN discriminator network."""

    def append_gan_network(self, true_X_input):
        """Append the discriminator network and return its output ``x``.

        Body omitted in this outline.
        """
        raise NotImplementedError("append_gan_network is omitted in this outline")


class GenerativeNetwork:
    """Outline of the SRGAN generator (super-resolution) network."""

    def create_sr_model(self, ip):
        """Create the super-resolution model on input ``ip``; returns its output ``x``.

        Body omitted in this outline.
        """
        raise NotImplementedError("create_sr_model is omitted in this outline")

    def get_generator_output(self, input_img, srgan_model):
        """Return the result of ``self.output_func`` applied to ``[input_img]``.

        ``self.output_func`` is not set anywhere in this outline; it must be
        assigned before this method is called. ``srgan_model`` is not used by
        the body shown here.
        """
        return self.output_func([input_img])


class SRGANNetwork:
    """Outline of the network manager that builds and trains the SRGAN.

    Args:
        img_width: Training image width.
        img_height: Training image height.
        batch_size: Training batch size.
    """

    def __init__(self, img_width, img_height, batch_size):
        # Minimal constructor so the call in ``__main__`` below is valid;
        # the original constructor body is not shown in the outline.
        self.img_width = img_width
        self.img_height = img_height
        self.batch_size = batch_size

    def build_srgan_pretrain_model(self):
        """Build the SRGAN pre-training model; returns ``self.srgan_model_``.

        Body omitted in this outline.
        """
        raise NotImplementedError("build_srgan_pretrain_model is omitted in this outline")

    def build_discriminator_pretrain_model(self):
        """Build the discriminator pre-training model; returns ``self.discriminative_model_``.

        Body omitted in this outline.
        """
        raise NotImplementedError(
            "build_discriminator_pretrain_model is omitted in this outline"
        )

    def build_srgan_model(self):
        """Build the full SRGAN model; returns ``self.srgan_model_``.

        Body omitted in this outline.
        """
        raise NotImplementedError("build_srgan_model is omitted in this outline")

    def pre_train_srgan(self, image_dir, nb_images=50000, nb_epochs=1, use_small_srgan=False):
        """Pre-train the generator together with the VGG network.

        Outline::

            for each of nb_epochs epochs:
                for each batch x from datagen.flow_from_directory(...):
                    every 50 iterations (skipping iteration 0): validate, print PSNR
                    train only the generator + VGG network
                    every 1000 iterations (skipping iteration 0): save model weights
        """
        raise NotImplementedError("pre_train_srgan is omitted in this outline")

    def pre_train_discriminator(self, image_dir, nb_images=50000, nb_epochs=1, batch_size=128):
        """Pre-train only the discriminator.

        Outline::

            for each of nb_epochs epochs:
                for each batch x from datagen.flow_from_directory(...):
                    train only the discriminator
                    every 1000 iterations (skipping iteration 0): save model weights
        """
        raise NotImplementedError("pre_train_discriminator is omitted in this outline")

    def train_full_model(self, image_dir, nb_images=50000, nb_epochs=10):
        """Train the full SRGAN with VGG loss and discriminator loss.

        Outline::

            for each of nb_epochs epochs:
                for each batch x from datagen.flow_from_directory(...):
                    every 50 iterations (skipping iteration 0): validate, print PSNR
                    every 1000 iterations (skipping iteration 0): save model weights
                    train only the discriminator (SRGAN training disabled)
                    train only the generator (discriminator training disabled)
        """
        raise NotImplementedError("train_full_model is omitted in this outline")


if __name__ == "__main__":
    # NOTE: keras.utils.visualize_util exists only in Keras 1.x; Keras 2 exposes
    # this helper as keras.utils.plot_model. Only the commented-out plot call
    # below uses it.
    from keras.utils.visualize_util import plot

    # Path to MS COCO dataset
    coco_path = r"D:\Yue\Documents\Dataset\coco2014\train2014"

    # Base Network manager for the SRGAN model.
    # Width / Height = 32 to reduce the memory requirement for the discriminator.
    # Batch size = 1 is slower, but uses the least amount of GPU memory, and also
    # acts as Instance Normalization (batch norm with 1 input image), which
    # speeds up training slightly.
    srgan_network = SRGANNetwork(img_width=32, img_height=32, batch_size=1)
    srgan_network.build_srgan_model()
    # plot(srgan_network.srgan_model_, 'SRGAN.png', show_shapes=True)

    # Pretrain the SRGAN network
    # srgan_network.pre_train_srgan(coco_path, nb_images=80000, nb_epochs=1)

    # Pretrain the discriminator network
    # srgan_network.pre_train_discriminator(coco_path, nb_images=40000, nb_epochs=1, batch_size=16)

    # Fully train the SRGAN with VGG loss and Discriminator loss
    srgan_network.train_full_model(coco_path, nb_images=80000, nb_epochs=5)
原文地址:https://www.cnblogs.com/68xi/p/8590600.html
时间: 2024-11-09 08:54:17