保存及读取keras模型参数

转自:http://blog.csdn.net/u010159842/article/details/54407745,感谢分享~

你可以使用model.save(filepath)将Keras模型和权重保存在一个HDF5文件中,该文件将包含:

  • 模型的结构,以便重构该模型
  • 模型的权重
  • 训练配置(损失函数,优化器等)
  • 优化器的状态,以便于从上次训练中断的地方开始

使用keras.models.load_model(filepath)来重新实例化你的模型,如果文件中存储了训练配置的话,该函数还会同时完成模型的编译

例子:

from keras.models import load_model

model.save(‘my_model.h5‘)  # creates a HDF5 file ‘my_model.h5‘
del model  # deletes the existing model

# returns a compiled model
# identical to the previous one
model = load_model(‘my_model.h5‘)

如果你只是希望保存模型的结构,而不包含其权重或配置信息,可以使用:

# save as JSON
json_string = model.to_json()

# save as YAML
yaml_string = model.to_yaml()

这项操作将把模型序列化为json或yaml文件,这些文件对人而言也是友好的,如果需要的话你甚至可以手动打开这些文件并进行编辑。

当然,你也可以从保存好的json文件或yaml文件中载入模型:

# model reconstruction from JSON:
from keras.models import model_from_json
model = model_from_json(json_string)

# model reconstruction from YAML
model = model_from_yaml(yaml_string)

如果需要保存模型的权重,可通过下面的代码利用HDF5进行保存。注意,在使用前需要确保你已安装了HDF5和其Python库h5py

model.save_weights(‘my_model_weights.h5‘)

如果你需要在代码中初始化一个完全相同的模型,请使用:

model.load_weights(‘my_model_weights.h5‘)

如果你需要加载权重到不同的网络结构(有些层一样)中,例如fine-tune或transfer-learning,你可以通过层名字来加载模型:

model.load_weights(‘my_model_weights.h5‘, by_name=True)

例如:

"""
假如原模型为:
    model = Sequential()
    model.add(Dense(2, input_dim=3, name="dense_1"))
    model.add(Dense(3, name="dense_2"))
    ...
    model.save_weights(fname)
"""
# new model
model = Sequential()
model.add(Dense(2, input_dim=3, name="dense_1"))  # will be loaded
model.add(Dense(10, name="new_dense"))  # will not be loaded

# load weights from first model; will only affect the first layer, dense_1.
model.load_weights(fname, by_name=True)

 

原文地址:https://www.cnblogs.com/baiting/p/8319321.html

时间: 2024-11-09 04:13:35

保存及读取keras模型参数的相关文章

[深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题

[深度学习] Pytorch(三)-- 多/单GPU.CPU,训练保存.加载预测模型问题 上一篇实践学习中,遇到了在多/单个GPU.GPU与CPU的不同环境下训练保存.加载使用使用模型的问题,如果保存.加载的上述三类环境不同,加载时会出错.就去研究了一下,做了实验,得出以下结论: 多/单GPU训练保存模型参数.CPU加载使用模型 #保存 PATH = 'cifar_net.pth' torch.save(net.module.state_dict(), PATH) #加载 net = Net()

[MISS静IOS开发原创文摘]-AppDelegate存储全局变量和 NSUserDefaults standardUserDefaults 通过模型保存和读取数据,存储自定义的对象

由于app开发的需求,需要从api接口获得json格式数据并保存临时的 app的主题颜色 和 相关url 方案有很多种: 1, 通过AppDelegate保存为全局变量,再获取 2,使用NSUSerDefault 第一种 :通过AppDelegate方法: 定义全局变量 // // AppDelegate.h // // Created by MISSAJJ on 15/5/5. // Copyright (c) 2015年 MISSAJJ. All rights reserved. // #i

tf.train.Saver()-tensorflow中模型的保存及读取

作用:训练网络之后保存训练好的模型,以及在程序中读取已保存好的模型 使用步骤: 实例化一个Saver对象 saver = tf.train.Saver() 在训练过程中,定期调用saver.save方法,像文件夹中写入包含当前模型中所有可训练变量的checkpoint文件 saver.save(sess,FLAGG.train_dir,global_step=step) 之后可以使用saver.restore()方法,重载模型的参数,继续训练或者用于测试数据 saver.restore(sess

Keras模型保存的几个方法和它们的区别

github博客传送门 csdn博客传送门 Keras模型保存简介 model.save() model_save_path = "model_file_path.h5" # 保存模型 model.save(model_save_path) # 删除当前已存在的模型 del model # 加载模型 from keras.models import load_model model = load_model(model_save_path) model.save_weights() m

TensorFlow学习笔记(8)--网络模型的保存和读取【转】

转自:http://blog.csdn.net/lwplwf/article/details/62419087 之前的笔记里实现了softmax回归分类.简单的含有一个隐层的神经网络.卷积神经网络等等,但是这些代码在训练完成之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用.为了让训练结果可以复用,需要将训练好的神经网络模型持久化,这就是这篇笔记里要写的东西. TensorFlow提供了一个非常简单的API,即tf.train.Saver类来保存和还原一个神经网络模型. 下面代码给

iOS 保存、读取与应用状态

固化 对于大多数iOS应用,可以将其功能总结为:提供一套界面,帮助用户管理特定的数据.在这一过程中,不同类型的对象要各司其职:模型对象负责保存数据,视图对象负责显示数据,控制器对象负责在模型对象与视图对象之间同步数据.因此,当某个应用要保存和读取数据时,通常要完成的任务是保存和读取相应的模型对象. 对 JXHmoepwner 应用,用户可以管理的模型对象是 JXItem 对象.目前 JXHomepwner 不嗯给你保存 JXItem 对象,所以,当用户重新运行 JXHomepwner 时,之前创

C# winform用sharpGL(OpenGl)解析读取3D模型obj

原文作者:aircraft 原文链接:https://www.cnblogs.com/DOMLX/p/11783026.html 自己写了个简单的类读取解析obj模型,使用导入类,然后new个对象,在读取obj模型,然后调用显示列表显示就可以了.至于其他什么旋转移动的你们自己加起来应该很容易的,因为我有看过c#下别人写的obj模型解析的代码项目,加了很多东西,我都找不到自己要用的代码在哪里,而我只需要读取解析obj模型这块代码而已,气的我自己写了个类自己解析,所以我怕我代码写多了, 你们反而看起

Keras学习手册(五),Keras 模型-Sequential API

感谢作者分享-http://bjbsair.com/2020-04-07/tech-info/30662.html 在 Keras 中有两类主要的模型:Sequential 顺序模型 和 使用函数式 API 的 Model 类模型. 这些模型有许多共同的方法和属性: model.layers 是包含模型网络层的展平列表. model.inputs 是模型输入张量的列表. model.outputs 是模型输出张量的列表. model.summary() 打印出模型概述信息. 它是 utils.p

基本数据持久性(二) 使用sqlite保存和读取数据

关于基本数据的持久性,写过一篇文章来简述过(基本数据持久性(一) 使用plist保存和读取数据).这篇文章将简述采用数据库sqlite的方式来保存数据,并根据查询结果读取数据. 一.工作原理 sqlite采用表存储的方式,表的第一行(也就是我们常说的表头)在sqilte中被称为“字段”.对于标的每一行(除了字段)的信息,都有一个独一无二的列内容可以将表的每一行内容独立区分开(例如本文所示的案例,存储一个学生的信息——学号.姓名.年龄.班级.那么,学号这一列就可以将表的每一行内容独立区分开,因为每