AI - TensorFlow - 示例05:保存和恢复模型

保存和恢复模型(Save and restore models)

官网示例:https://www.tensorflow.org/tutorials/keras/save_and_restore_models

在训练期间保存检查点

在训练期间或训练结束时自动保存检查点。
权重存储在检查点格式的文件集合中,这些文件仅包含经过训练的权重(采用二进制格式)。
可以使用经过训练的模型,而无需重新训练该模型,或从上次暂停的地方继续训练,以防训练过程中断

  • 检查点回调用法:创建检查点回调,训练模型并将ModelCheckpoint回调传递给该模型,得到检查点文件集合,用于分享权重
  • 检查点回调选项:该回调提供了多个选项,用于为生成的检查点提供独一无二的名称,以及调整检查点创建频率。

手动保存权重

使用 Model.save_weights 方法即可手动保存权重

保存整个模型

整个模型可以保存到一个文件中,其中包含权重值、模型配置(架构)、优化器配置。
可以为模型设置检查点,并稍后从完全相同的状态继续训练,而无需访问原始代码。
Keras通过检查架构来保存模型,使用HDF5标准提供基本的保存格式。
特别注意:

  • 目前无法保存TensorFlow优化器(来自tf.train)。
  • 使用此类优化器时,需要在加载模型后对其进行重新编译,使优化器的状态变松散。

MNIST数据集

MNIST(Mixed National Institute of Standards and Technology database)是一个计算机视觉数据集

示例

脚本内容

GitHub:https://github.com/anliven/Hello-AI/blob/master/Google-Learn-and-use-ML/5_save_and_restore_models.py

  1 # coding=utf-8
  2 import tensorflow as tf
  3 from tensorflow import keras
  4 import numpy as np
  5 import pathlib
  6 import os
  7
  8 os.environ[‘TF_CPP_MIN_LOG_LEVEL‘] = ‘2‘
  9 print("# TensorFlow version: {}  - tf.keras version: {}".format(tf.VERSION, tf.keras.__version__))  # 查看版本
 10
 11 # ### 获取示例数据集
 12
 13 ds_path = str(pathlib.Path.cwd()) + "\\datasets\\mnist\\"  # 数据集路径
 14 np_data = np.load(ds_path + "mnist.npz")  # 加载numpy格式数据
 15 print("# np_data keys: ", list(np_data.keys()))  # 查看所有的键
 16
 17 # 加载mnist数据集
 18 (train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data(path=ds_path + "mnist.npz")
 19 train_labels = train_labels[:1000]
 20 test_labels = test_labels[:1000]
 21 train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
 22 test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
 23
 24
 25 # ### 定义模型
 26 def create_model():
 27     model = tf.keras.models.Sequential([
 28         keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
 29         keras.layers.Dropout(0.2),
 30         keras.layers.Dense(10, activation=tf.nn.softmax)
 31     ])  # 构建一个简单的模型
 32     model.compile(optimizer=tf.keras.optimizers.Adam(),
 33                   loss=tf.keras.losses.sparse_categorical_crossentropy,
 34                   metrics=[‘accuracy‘])
 35     return model
 36
 37
 38 mod = create_model()
 39 mod.summary()
 40
 41 # ### 在训练期间保存检查点
 42
 43 # 检查点回调用法
 44 checkpoint_path = "training_1/cp.ckpt"
 45 checkpoint_dir = os.path.dirname(checkpoint_path)  # 检查点存放目录
 46 cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
 47                                                  save_weights_only=True,
 48                                                  verbose=2)  # 创建检查点回调
 49 model1 = create_model()
 50 model1.fit(train_images, train_labels,
 51            epochs=10,
 52            validation_data=(test_images, test_labels),
 53            verbose=0,
 54            callbacks=[cp_callback]  # 将ModelCheckpoint回调传递给该模型
 55            )  # 训练模型,将创建一个TensorFlow检查点文件集合,这些文件在每个周期结束时更新
 56
 57 model2 = create_model()  # 创建一个未经训练的全新模型(与原始模型架构相同,才能分享权重)
 58 loss, acc = model2.evaluate(test_images, test_labels)  # 使用测试集进行评估
 59 print("# Untrained model2, accuracy: {:5.2f}%".format(100 * acc))  # 未训练模型的表现(准确率约为10%)
 60
 61 model2.load_weights(checkpoint_path)  # 从检查点加载权重
 62 loss, acc = model2.evaluate(test_images, test_labels)  # 使用测试集,重新进行评估
 63 print("# Restored model2, accuracy: {:5.2f}%".format(100 * acc))  # 模型表现得到大幅提升
 64
 65 # 检查点回调选项
 66 checkpoint_path2 = "training_2/cp-{epoch:04d}.ckpt"  # 使用“str.format”方式为每个检查点设置唯一名称
 67 checkpoint_dir2 = os.path.dirname(checkpoint_path)
 68 cp_callback2 = tf.keras.callbacks.ModelCheckpoint(checkpoint_path2,
 69                                                   verbose=1,
 70                                                   save_weights_only=True,
 71                                                   period=5  # 每隔5个周期保存一次检查点
 72                                                   )  # 创建检查点回调
 73 model3 = create_model()
 74 model3.fit(train_images, train_labels,
 75            epochs=50,
 76            callbacks=[cp_callback2],  # 将ModelCheckpoint回调传递给该模型
 77            validation_data=(test_images, test_labels),
 78            verbose=0)  # 训练一个新模型,每隔5个周期保存一次检查点并设置唯一名称
 79 latest = tf.train.latest_checkpoint(checkpoint_dir2)
 80 print("# latest checkpoint: {}".format(latest))  # 查看最新的检查点
 81
 82 model4 = create_model()  # 重新创建一个全新的模型
 83 loss, acc = model2.evaluate(test_images, test_labels)  # 使用测试集进行评估
 84 print("# Untrained model4, accuracy: {:5.2f}%".format(100 * acc))  # 未训练模型的表现(准确率约为10%)
 85
 86 model4.load_weights(latest)  # 加载最新的检查点
 87 loss, acc = model4.evaluate(test_images, test_labels)  #
 88 print("# Restored model4, accuracy: {:5.2f}%".format(100 * acc))  # 模型表现得到大幅提升
 89
 90 # ### 手动保存权重
 91 model5 = create_model()
 92 model5.fit(train_images, train_labels,
 93            epochs=10,
 94            validation_data=(test_images, test_labels),
 95            verbose=0)  # 训练模型
 96 model5.save_weights(‘./training_3/my_checkpoint‘)  # 手动保存权重
 97
 98 model6 = create_model()
 99 loss, acc = model6.evaluate(test_images, test_labels)
100 print("# Restored model6, accuracy: {:5.2f}%".format(100 * acc))
101 model6.load_weights(‘./training_3/my_checkpoint‘)
102 loss, acc = model6.evaluate(test_images, test_labels)
103 print("# Restored model6, accuracy: {:5.2f}%".format(100 * acc))
104
105 # ### 保存整个模型
106 model7 = create_model()
107 model7.fit(train_images, train_labels, epochs=5)
108 model7.save(‘my_model.h5‘)  # 保存整个模型到HDF5文件
109
110 model8 = keras.models.load_model(‘my_model.h5‘)  # 重建完全一样的模型,包括权重和优化器
111 model8.summary()
112 loss, acc = model8.evaluate(test_images, test_labels)
113 print("Restored model8, accuracy: {:5.2f}%".format(100 * acc))

运行结果

C:\Users\anliven\AppData\Local\conda\conda\envs\mlcc\python.exe D:/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML/5_save_and_restore_models.py
# TensorFlow version: 1.12.0  - tf.keras version: 2.1.6-tf
# np_data keys:  [‘x_test‘, ‘x_train‘, ‘y_train‘, ‘y_test‘]
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 512)               401920
_________________________________________________________________
dropout (Dropout)            (None, 512)               0
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

Epoch 00001: saving model to training_1/cp.ckpt
Epoch 00002: saving model to training_1/cp.ckpt
Epoch 00003: saving model to training_1/cp.ckpt
Epoch 00004: saving model to training_1/cp.ckpt
Epoch 00005: saving model to training_1/cp.ckpt
Epoch 00006: saving model to training_1/cp.ckpt
Epoch 00007: saving model to training_1/cp.ckpt
Epoch 00008: saving model to training_1/cp.ckpt
Epoch 00009: saving model to training_1/cp.ckpt
Epoch 00010: saving model to training_1/cp.ckpt

  32/1000 [..............................] - ETA: 3s
1000/1000 [==============================] - 0s 140us/step
# Untrained model2, accuracy:  8.20%

  32/1000 [..............................] - ETA: 0s
1000/1000 [==============================] - 0s 40us/step
# Restored model2, accuracy: 86.40%

Epoch 00005: saving model to training_2/cp-0005.ckpt
Epoch 00010: saving model to training_2/cp-0010.ckpt
Epoch 00015: saving model to training_2/cp-0015.ckpt
Epoch 00020: saving model to training_2/cp-0020.ckpt
Epoch 00025: saving model to training_2/cp-0025.ckpt
Epoch 00030: saving model to training_2/cp-0030.ckpt
Epoch 00035: saving model to training_2/cp-0035.ckpt
Epoch 00040: saving model to training_2/cp-0040.ckpt
Epoch 00045: saving model to training_2/cp-0045.ckpt
Epoch 00050: saving model to training_2/cp-0050.ckpt

# latest checkpoint: training_1\cp.ckpt

  32/1000 [..............................] - ETA: 3s
1000/1000 [==============================] - 0s 140us/step
# Untrained model4, accuracy: 86.40%

  32/1000 [..............................] - ETA: 2s
1000/1000 [==============================] - 0s 110us/step
# Restored model4, accuracy: 86.40%

  32/1000 [..............................] - ETA: 5s
1000/1000 [==============================] - 0s 220us/step
# Restored model6, accuracy: 18.20%

  32/1000 [..............................] - ETA: 0s
1000/1000 [==============================] - 0s 40us/step
# Restored model6, accuracy: 87.40%
Epoch 1/5

  32/1000 [..............................] - ETA: 9s - loss: 2.4141 - acc: 0.0625
 320/1000 [========>.....................] - ETA: 0s - loss: 1.8229 - acc: 0.4469
 576/1000 [================>.............] - ETA: 0s - loss: 1.4932 - acc: 0.5694
 864/1000 [========================>.....] - ETA: 0s - loss: 1.2624 - acc: 0.6481
1000/1000 [==============================] - 1s 530us/step - loss: 1.1978 - acc: 0.6620
Epoch 2/5

  32/1000 [..............................] - ETA: 0s - loss: 0.5490 - acc: 0.8750
 320/1000 [========>.....................] - ETA: 0s - loss: 0.4832 - acc: 0.8594
 576/1000 [================>.............] - ETA: 0s - loss: 0.4630 - acc: 0.8715
 864/1000 [========================>.....] - ETA: 0s - loss: 0.4356 - acc: 0.8808
1000/1000 [==============================] - 0s 200us/step - loss: 0.4298 - acc: 0.8790
Epoch 3/5

  32/1000 [..............................] - ETA: 0s - loss: 0.1681 - acc: 0.9688
 320/1000 [========>.....................] - ETA: 0s - loss: 0.2826 - acc: 0.9437
 576/1000 [================>.............] - ETA: 0s - loss: 0.2774 - acc: 0.9340
 832/1000 [=======================>......] - ETA: 0s - loss: 0.2740 - acc: 0.9327
1000/1000 [==============================] - 0s 200us/step - loss: 0.2781 - acc: 0.9280
Epoch 4/5

  32/1000 [..............................] - ETA: 0s - loss: 0.1589 - acc: 0.9688
 288/1000 [=======>......................] - ETA: 0s - loss: 0.2169 - acc: 0.9410
 608/1000 [=================>............] - ETA: 0s - loss: 0.2186 - acc: 0.9457
 864/1000 [========================>.....] - ETA: 0s - loss: 0.2231 - acc: 0.9479
1000/1000 [==============================] - 0s 200us/step - loss: 0.2164 - acc: 0.9480
Epoch 5/5

  32/1000 [..............................] - ETA: 0s - loss: 0.1095 - acc: 1.0000
 352/1000 [=========>....................] - ETA: 0s - loss: 0.1631 - acc: 0.9744
 608/1000 [=================>............] - ETA: 0s - loss: 0.1671 - acc: 0.9638
 864/1000 [========================>.....] - ETA: 0s - loss: 0.1545 - acc: 0.9688
1000/1000 [==============================] - 0s 210us/step - loss: 0.1538 - acc: 0.9670
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_14 (Dense)             (None, 512)               401920
_________________________________________________________________
dropout_7 (Dropout)          (None, 512)               0
_________________________________________________________________
dense_15 (Dense)             (None, 10)                5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

  32/1000 [..............................] - ETA: 3s
1000/1000 [==============================] - 0s 150us/step
Restored model8, accuracy: 86.10%

Process finished with exit code 0

生成的文件

[email protected] MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$ ll training_1
total 1601
-rw-r--r-- 1 anliven 197121      71 5月   5 23:36 checkpoint
-rw-r--r-- 1 anliven 197121 1631508 5月   5 23:36 cp.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 5月   5 23:36 cp.ckpt.index

[email protected] MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$

[email protected] MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$

[email protected] MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$

[email protected] MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$ ls -l training_1
total 1601
-rw-r--r-- 1 anliven 197121      71 5月   5 23:36 checkpoint
-rw-r--r-- 1 anliven 197121 1631508 5月   5 23:36 cp.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 5月   5 23:36 cp.ckpt.index

[email protected] MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$ ls -l training_2
total 16001
-rw-r--r-- 1 anliven 197121      81 5月   5 23:37 checkpoint
-rw-r--r-- 1 anliven 197121 1631508 5月   5 23:36 cp-0005.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 5月   5 23:36 cp-0005.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月   5 23:36 cp-0010.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 5月   5 23:36 cp-0010.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月   5 23:36 cp-0015.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 5月   5 23:36 cp-0015.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月   5 23:36 cp-0020.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 5月   5 23:36 cp-0020.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月   5 23:36 cp-0025.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 5月   5 23:36 cp-0025.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月   5 23:37 cp-0030.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 5月   5 23:37 cp-0030.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月   5 23:37 cp-0035.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 5月   5 23:37 cp-0035.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月   5 23:37 cp-0040.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 5月   5 23:37 cp-0040.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月   5 23:37 cp-0045.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 5月   5 23:37 cp-0045.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月   5 23:37 cp-0050.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 5月   5 23:37 cp-0050.ckpt.index

[email protected] MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$ ls -l training_3
total 1601
-rw-r--r-- 1 anliven 197121      83 5月   5 23:37 checkpoint
-rw-r--r-- 1 anliven 197121 1631517 5月   5 23:37 my_checkpoint.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 5月   5 23:37 my_checkpoint.index

[email protected] MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$ ls -l my_model.h5
-rw-r--r-- 1 anliven 197121 4909112 5月   5 23:37 my_model.h5

问题处理

问题描述:出现如下告警信息。

WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x00000280FD318780>) but is being saved in TensorFlow format with `save_weights`. The model‘s weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer‘s state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

问题处理:

正常告警,对脚本运行和结果无影响,暂不关注。

原文地址:https://www.cnblogs.com/anliven/p/10817233.html

时间: 2024-07-30 14:16:00

AI - TensorFlow - 示例05:保存和恢复模型的相关文章

AI - TensorFlow - 示例:影评文本分类

影评文本分类 文本分类(Text classification):https://www.tensorflow.org/tutorials/keras/basic_text_classification主要步骤: 1.加载IMDB数据集 2.探索数据:了解数据格式.将整数转换为字词 3.准备数据 4.构建模型:隐藏单元.损失函数和优化器 5.创建验证集 6.训练模型 7.评估模型 8.可视化:创建准确率和损失随时间变化的图 IMDB数据集 https://www.tensorflow.org/a

AI - TensorFlow - 示例03:基本回归

基本回归 回归(Regression):https://www.tensorflow.org/tutorials/keras/basic_regression 主要步骤:数据部分 获取数据(Get the data) 清洗数据(Clean the data) 划分训练集和测试集(Split the data into train and test) 检查数据(Inspect the data) 分离标签(Split features from labels) 规范化数据(Normalize th

AI - TensorFlow - 示例01:基本分类

基本分类 基本分类(Basic classification):https://www.tensorflow.org/tutorials/keras/basic_classification Fashion MNIST数据集 经典 MNIST 数据集(常用作计算机视觉机器学习程序的“Hello, World”入门数据集)的简易替换 包含训练数据60000个,测试数据10000个,每个图片是28x28像素的灰度图像,涵盖10个类别 TensorFlow:https://www.tensorflow

tensorflow 1.0 学习:模型的保存与恢复(Saver)

将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.Saver() 在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型.如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置

【Android】11.2 通过重写对应的方法保存和恢复实例的状态

分类:C#.Android.VS2015: 创建日期:2016-02-21 一.简介 通过重写(也叫回调)对应的方法来管理Activity的生命周期,比如用户旋转屏幕时应用程序要能自动保存和恢复实例的状态,这对于开发一个健壮而又灵活的应用程序而言至关重要. 1.本节要点 一旦真正理解了Activity的生命周期,就可以轻松自如地通过C#代码去控制它了.这一节我们主要学习如何用Boundle存储简单类型的数据(比如int.double.string.bool.--等). 当一个Activity停止

运维管理中IT故障定位、预警与智能恢复模型建立和应用实践解密

运维管理中IT故障定位.预警与智能恢复模型建立和应用实践解密各位大伽先不要说我理的对不对,我们来说用网管软件与IT运维管理系统来做IT的监测管理,先来看下面以SITEVIEW  ITOSS为例的一张图,图最能说明模型的意图:  我们可以从左向右,从上向下来看一下,一开始是需要采集,也即监测源端,监测源包括比如关键的服务器.网络设备.网络.日志和核心的业务应用系统,IT的环境.数据中心.机房环境等等.监测的参数就看如图SITEVIEW ITOSS一体化平台包括的五大模块功能中需要的参数状态数据,这

Android——保存和恢复用户状态

onSaveInstanceState 保存 在暂停之后和保存之前调用 onRestoreInstanceState 恢复 再启动之后和显示之前调用 package com.example.chenshuai.excise; import android.app.Activity; import android.content.Intent; import android.os.Bundle; import android.util.Log; import android.view.View;

ios开发——实用技术篇&amp;数据保存于恢复

数据保存于恢复 用户操作(输入数据)之后,应用程序退出并且终止之后,当用户再次打开应用的时候还是保持原来的状态 一:在storyBoard中设置恢复标志符 二:在AppDalegate中代理方法 1 -(BOOL) application:(UIApplication *)application shouldSaveApplicationState:(NSCoder *)coder 2 { 3 return YES; 4 } 5 6 -(BOOL) application:(UIApplicat

保存和恢复activity的状态数据[转]

转自:here 一般来说,调用onPause()和onStop()方法后的activity实例仍然存在于内存中,activity中的所有信息和状态数据都不会消失,当activity重新回到前台后,所有的改变都会保留. 但是当内存系统内存不足时,调用onPause()和onStop()方法的activity可能被摧毁.此时内存中就不会存在有该activity实例对象了. 为了避免这种情况,我们可以覆盖onSaveInstanceState()方法来接受一个Bundle类型的参数,我们可以将该act