Keras用动态数据生成器(DataGenerator)和fitgenerator动态训练模型

 

最近做Kaggle的图像分类比赛:RSNA Intracranial Hemorrhage Detection (https://www.kaggle.com/c/rsna-intracranial-hemorrhage-detection/overview)以及阅读Yolov3

源码的时候接触到深度学习训练时一个有趣的技巧,那就是构造生成器generator 并且用Keras 的fit_generator来批量生成数据,释放内存,该方法适合于大规模数据集的训练。一个DataGenerator是keras的Sequence类的继承类,一般要包含__len__,__getitem__, on_epoch_end等方法,例如下面的批量图片数据生成器:

class DataGenerator(keras.utils.Sequence):

      def __init__(self, list_IDs, labels, batch_size=1, img_size=(512, 512),
                   img_dir, *args, **kwargs):

         """
            self.list_IDs:存放所有需要训练的图片文件名的列表。
            self.labels:记录图片标注的分类信息的pandas.DataFrame数据类型,已经预先给定。
            self.batch_size:每次批量生成,训练的样本大小。
            self.img_size:训练的图片尺寸。
            self.img_dir:图片在电脑中存放的路径。

         """

          self.list_IDs = list_IDs
          self.labels = labels
          self.batch_size = batch_size
          self.img_size = img_size
          self.img_dir = img_dir
          self.on_epoch_end()

      def __len__(self):

          """
             返回生成器的长度,也就是总共分批生成数据的次数。

          """
          return int(ceil(len(self.list_IDs) / self.batch_size))

     def __getitem__(self, index):

         """
            该函数返回每次我们需要的经过处理的数据。
         """

         indices = self.indices[index*self.batch_size:(index+1)*self.batch_size]
         list_IDs_temp = [self.list_IDs[k] for k in indices]
         X, Y = self.__data_generation(list_IDs_temp)
         return X, Y

     def on_epoch_end(self):

         """
            该函数将在训练时每一个epoch结束的时候自动执行,在这里是随机打乱索引次序以方便下一batch运行。

         """
         self.indices = np.arange(len(self.list_IDs))
         np.random.shuffle(self.indices)

     def __data_generation(self, list_IDs_temp):

        """
           给定文件名,生成数据。
        """
        X = np.empty((self.batch_size, *self.img_size, 1))
        Y = np.empty((self.batch_size, 6), dtype=np.float32)

       for i, ID in enumerate(list_IDs_temp):
       X[i,] = mpimg.imread(self.img_dir+ID+".png")
       Y[i,] = self.labels.loc[ID].values

       return X, Y

有了这个生成器,我们就可以用fit_generator 方法进行训练,格式套路如下:

model.fit_generator(generator,

steps_per_epoch=...,

epochs=...,

verbose=...,

callbacks=...,

validation_data=...,

validation_steps=...,

validation_freq=...,

class_weight=None=...,

max_queue_size=...

workers=...,

use_multiprocessing=...,

)

除此以外我们还可以搞批量预测:

model.predict_generator()

原文地址:https://www.cnblogs.com/szqfreiburger/p/11621261.html

时间: 2024-10-03 22:56:23

Keras用动态数据生成器(DataGenerator)和fitgenerator动态训练模型的相关文章

游戏设计一、关于游戏动态数据和静态数据的处理

最近的游戏项目遇到的问题 让我思考了一些东西  比如 游戏开始时会初始化很多数据到世界里面,比如玩家的金钱,玩家一边打怪 一边金钱猛涨,在打怪的时候,金钱的数据应该是直接写到世界的,而不是更新了金钱就写到数据库的,所以这里就有个问题,当玩家查点击 装备的时候 上面会有个金钱的额度 这个数值是通过数据库还是通过世界内存来的? 简单说下 世界内存就是动态数据 静态数据 都是放数据库的,如果要看到及时的额度 那就必须增加动态数值查询的接口

Highcharts 之 【动态数据】

最近项目中需要用到图表,找了几个开源框架,最后选择 Highcharts,原因是 Highcharts 功能强大,稳定,方便,而且开源,社区比较成熟. 首先下载 Highcharts,导入项目. 在 HTML 页面引入相应的 Js 文件.我这个项目是情绪监控相关,所谓情绪也就是热点的意思.大数据团队通过爬虫,先从数据库词典里拿到比较靠前的几个行业名称,然后通过爬虫在网上抓取这几个行业的热度值.每天固定时间抓取,统计一次. <!DOCTYPE HTML> <html> <hea

利用SD_SALESDOCUMENT_CREATE 批导动态数据SO

期初上线时,SO作为动态数据,是批导入系统必须做的一步,好多朋友利用bdc.lsmw.scatt等工具都可以做,下面是项目中利用SD_SALESDOCUMENT_CREATE 进行批导的一些代码,分享一下,希望对用到的朋友有帮助. *&---------------------------------------------------------------------* *& Report  ZSD_BATCH_SO *& *&---------------------

[数据生成器]UVA10054 The Necklace

应吴老师之邀,写了个数据生成器. 目前这个数据生成器可以保证生成的数据都是合法的,且效率也还不错.只是在建立普通连通图的时候zyy偷懒了,直接把所有点串起来从而保证图的连通.如果有大神有更好的方法请不吝指教,zyy不胜感谢~~ 下面是代码: 1 #include<cstdio> 2 #include<ctime> 3 #include<cstring> 4 #include<cstdlib> 5 #include<cmath> 6 #includ

achartengine 实现平行线 动态数据 x轴动态移动

achartengine做平行线的时候经常会遇到: java.lang.IndexOutOfBoundsException: Invalid index 1, size is 1 at java.util.ArrayList.throwIndexOutOfBoundsException(ArrayList.java:251) at org.achartengine.renderer.DefaultRenderer.getSeriesRendererAt(DefaultRenderer.java:

ExtJS4.2学习(20)动态数据表格之前几章总结篇1

本节采用技术:SpringMVC+Jetty+ExtJs4.2+Maven+MySQL5.1以上+SLF4J(前几节学习的大家不知道记住了没,现在来总结复习下,顺便加点新技术) 学习本节前的准备:Eclipse高版本,Jetty插件,Maven插件,JDK1.7 休息了好久没开动教程了,确实最近太累了,大家见谅!先来看下效果,本章节是连续篇,今天是续篇的第一讲,前面都是静态讲解,大家是不是觉得不过瘾啊?咱学Java的嘛,当然得和Java的技术结合讲解咯,前面也说到以后会用动态数据讲解的.一.准备

[CF787D]遗产(Legacy)-线段树-优化Dijkstra(内含数据生成器)

Problem 遗产 题目大意 给出一个带权有向图,有三种操作: 1.u->v添加一条权值为w的边 2.区间[l,r]->v添加权值为w的边 3.v->区间[l,r]添加权值为w的边 求st点到每个点的最短路 Solution 首先我们思考到,若是每次对于l,r区间内的每一个点都执行一次加边操作,不仅耗时还耗空间. 那么我们要想到一个办法去优化它.一看到lr区间,我们就会想到线段树对吧. 没错啦这题就是用线段树去优化它. 首先我们建一棵线段树,然后很容易想到,我们只需要把这一棵线段树当做

c++随机数据生成器

这是随机数据生成器,可以在noi上检测...... #include<iostream> #include<time.h> #include<cstdio> #include<stdlib.h> using namespace std; int main() { freopen("random.txt","w",stdout); long long a,b,n,cou;//cou:数据个数:  cin>>a

Yii Active Record 动态数据表

Active Record(AR)是一种流行的 对象-关系映射(ORM)技术,其映射关系为 AR class:数据表 AR class property:数据表的一列 AR 实例:数据表的一条数据 所以对于常用的数据库操作(CRUD)可以转化成一种面向对象的数据操作形式. 实现一个AR类的的最简代码如下: class Post extends CActiveRecord { public static function model($className=__CLASS__) { return p