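Save and restore models
Official tutorial: https://www.tensorflow.org/tutorials/keras/save_and_restore_models
Saving checkpoints during training
Checkpoints can be saved automatically during training or when training ends.
The weights are stored in a collection of checkpoint-format files that contain only the trained weights, in a binary format.
A trained model can then be used without retraining it, and training can resume from where it stopped if the process was interrupted.
- Checkpoint callback usage: create a ModelCheckpoint callback and pass it to the model when training. This produces a collection of checkpoint files that can be used to share the weights.
- Checkpoint callback options: the callback provides several options for giving the resulting checkpoints unique names and for adjusting how often checkpoints are created.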
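The naming and frequency options can be illustrated without TensorFlow. The following is a minimal sketch, assuming the callback fills in the {epoch} placeholder with str.format-style formatting and counts epochs from 1; checkpoint_names is a hypothetical helper written for this illustration and is not part of Keras.

```python
def checkpoint_names(template, epochs, period):
    """Return the checkpoint prefixes written when saving every `period` epochs.

    Hypothetical helper: it only mimics the naming scheme, it saves nothing.
    """
    return [template.format(epoch=epoch)
            for epoch in range(1, epochs + 1)
            if epoch % period == 0]


# Same template, epoch count and period as the example script in this post.
names = checkpoint_names("training_2/cp-{epoch:04d}.ckpt", epochs=50, period=5)
print(names[0])    # training_2/cp-0005.ckpt
print(names[-1])   # training_2/cp-0050.ckpt
print(len(names))  # 10
```

With 50 epochs and a period of 5, the sketch yields ten prefixes, from cp-0005.ckpt to cp-0050.ckpt.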
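Saving weights manually
Weights can be saved manually with the Model.save_weights method.
Saving the entire model
The entire model can be saved to a single file that contains the weight values, the model configuration (architecture), and the optimizer configuration.
This makes it possible to checkpoint a model and later resume training from exactly the same state, without access to the original code.
Keras saves models by inspecting the architecture, and provides a basic save format based on the HDF5 standard.
Notes:
- TensorFlow optimizers (from tf.train) currently cannot be saved.
- When using such an optimizer, the model must be recompiled after loading, and the optimizer state is lost.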
MNIST dataset
MNIST (Modified National Institute of Standards and Technology database) is a computer vision dataset
- Official download: http://yann.lecun.com/exdb/mnist/
- Contains 70,000 grayscale images of handwritten digits: 60,000 training images and 10,000 test images
- Each image is a 28×28-pixel grayscale image
- Keras: https://keras.io/datasets/#mnist-database-of-handwritten-digits
- TensorFlow: https://www.tensorflow.org/api_docs/python/tf/keras/datasets/mnist
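Because each image is 28×28 pixels, a dense network usually receives it as a flat vector of 784 values scaled to the range 0 to 1. The following is a minimal sketch of that preprocessing step. It uses randomly generated stand-in data rather than the real dataset, and assumes pixel intensities are 8-bit integers from 0 to 255.

```python
import numpy as np

# Stand-in for MNIST images (an assumption for illustration only):
# 1000 grayscale images of 28x28 pixels with integer intensities in [0, 255].
rng = np.random.RandomState(0)
images = rng.randint(0, 256, size=(1000, 28, 28)).astype(np.uint8)

# Flatten each 28x28 image into a 784-element vector and scale to [0, 1].
flat = images.reshape(-1, 28 * 28) / 255.0

print(flat.shape)  # (1000, 784)
print(flat.dtype)  # float64
```

The -1 in reshape lets NumPy infer the number of images, so the same line works for any subset size.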
Example
Script
GitHub: https://github.com/anliven/Hello-AI/blob/master/Google-Learn-and-use-ML/5_save_and_restore_models.py
# coding=utf-8
# Written for TensorFlow 1.12 (tf.VERSION and the `period` argument are TF 1.x APIs).
import tensorflow as tf
from tensorflow import keras
import numpy as np
import pathlib
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
print("# TensorFlow version: {} - tf.keras version: {}".format(tf.VERSION, tf.keras.__version__))  # show versions

# ### Load the example dataset

ds_file = str(pathlib.Path.cwd() / "datasets" / "mnist" / "mnist.npz")  # path to the local dataset file
np_data = np.load(ds_file)  # load the NumPy-format data
print("# np_data keys: ", list(np_data.keys()))  # list all keys

# Load the MNIST dataset and keep the first 1000 samples of each split
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data(path=ds_file)
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0


# ### Define the model
def create_model():
    model = tf.keras.models.Sequential([
        keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation=tf.nn.softmax)
    ])  # build a simple model
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.sparse_categorical_crossentropy,
                  metrics=['accuracy'])
    return model


mod = create_model()
mod.summary()

# ### Save checkpoints during training

# Checkpoint callback usage
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)  # directory that holds the checkpoints
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=2)  # create the checkpoint callback
model1 = create_model()
model1.fit(train_images, train_labels,
           epochs=10,
           validation_data=(test_images, test_labels),
           verbose=0,
           callbacks=[cp_callback]  # pass the ModelCheckpoint callback to the model
           )  # training creates a set of TensorFlow checkpoint files, updated at the end of every epoch

model2 = create_model()  # a new, untrained model (same architecture as the original, so the weights can be shared)
loss, acc = model2.evaluate(test_images, test_labels)  # evaluate on the test set
print("# Untrained model2, accuracy: {:5.2f}%".format(100 * acc))  # an untrained model scores about 10%

model2.load_weights(checkpoint_path)  # load the weights from the checkpoint
loss, acc = model2.evaluate(test_images, test_labels)  # evaluate again on the test set
print("# Restored model2, accuracy: {:5.2f}%".format(100 * acc))  # accuracy improves substantially

# Checkpoint callback options
checkpoint_path2 = "training_2/cp-{epoch:04d}.ckpt"  # str.format-style placeholder gives each checkpoint a unique name
# Fixed: the original script used os.path.dirname(checkpoint_path) here, which is why
# the recorded run output reports the latest checkpoint as training_1\cp.ckpt.
checkpoint_dir2 = os.path.dirname(checkpoint_path2)
cp_callback2 = tf.keras.callbacks.ModelCheckpoint(checkpoint_path2,
                                                  verbose=1,
                                                  save_weights_only=True,
                                                  period=5  # save a checkpoint every 5 epochs
                                                  )  # create the checkpoint callback
model3 = create_model()
model3.fit(train_images, train_labels,
           epochs=50,
           callbacks=[cp_callback2],  # pass the ModelCheckpoint callback to the model
           validation_data=(test_images, test_labels),
           verbose=0)  # train a new model, saving a uniquely named checkpoint every 5 epochs
latest = tf.train.latest_checkpoint(checkpoint_dir2)
print("# latest checkpoint: {}".format(latest))  # show the most recent checkpoint

model4 = create_model()  # create another new model
# Fixed: the original script evaluated model2 here, so the recorded run output shows
# 86.40% for the "untrained" model4.
loss, acc = model4.evaluate(test_images, test_labels)  # evaluate on the test set
print("# Untrained model4, accuracy: {:5.2f}%".format(100 * acc))  # an untrained model scores about 10%

model4.load_weights(latest)  # load the most recent checkpoint
loss, acc = model4.evaluate(test_images, test_labels)  # evaluate again on the test set
print("# Restored model4, accuracy: {:5.2f}%".format(100 * acc))  # accuracy improves substantially

# ### Save weights manually
model5 = create_model()
model5.fit(train_images, train_labels,
           epochs=10,
           validation_data=(test_images, test_labels),
           verbose=0)  # train the model
model5.save_weights('./training_3/my_checkpoint')  # save the weights manually

model6 = create_model()
loss, acc = model6.evaluate(test_images, test_labels)
# Fixed label: the original script printed "Restored model6" for this untrained evaluation too.
print("# Untrained model6, accuracy: {:5.2f}%".format(100 * acc))
model6.load_weights('./training_3/my_checkpoint')
loss, acc = model6.evaluate(test_images, test_labels)
print("# Restored model6, accuracy: {:5.2f}%".format(100 * acc))

# ### Save the entire model
model7 = create_model()
model7.fit(train_images, train_labels, epochs=5)
model7.save('my_model.h5')  # save the entire model to an HDF5 file

model8 = keras.models.load_model('my_model.h5')  # rebuild an identical model, including weights and optimizer
model8.summary()
loss, acc = model8.evaluate(test_images, test_labels)
print("Restored model8, accuracy: {:5.2f}%".format(100 * acc))
Output
C:\Users\anliven\AppData\Local\conda\conda\envs\mlcc\python.exe D:/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML/5_save_and_restore_models.py
# TensorFlow version: 1.12.0 - tf.keras version: 2.1.6-tf
# np_data keys:  ['x_test', 'x_train', 'y_train', 'y_test']
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 512)               401920
_________________________________________________________________
dropout (Dropout)            (None, 512)               0
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
Epoch 00001: saving model to training_1/cp.ckpt
Epoch 00002: saving model to training_1/cp.ckpt
Epoch 00003: saving model to training_1/cp.ckpt
Epoch 00004: saving model to training_1/cp.ckpt
Epoch 00005: saving model to training_1/cp.ckpt
Epoch 00006: saving model to training_1/cp.ckpt
Epoch 00007: saving model to training_1/cp.ckpt
Epoch 00008: saving model to training_1/cp.ckpt
Epoch 00009: saving model to training_1/cp.ckpt
Epoch 00010: saving model to training_1/cp.ckpt
  32/1000 [..............................] - ETA: 3s
1000/1000 [==============================] - 0s 140us/step
# Untrained model2, accuracy:  8.20%
  32/1000 [..............................] - ETA: 0s
1000/1000 [==============================] - 0s 40us/step
# Restored model2, accuracy: 86.40%
Epoch 00005: saving model to training_2/cp-0005.ckpt
Epoch 00010: saving model to training_2/cp-0010.ckpt
Epoch 00015: saving model to training_2/cp-0015.ckpt
Epoch 00020: saving model to training_2/cp-0020.ckpt
Epoch 00025: saving model to training_2/cp-0025.ckpt
Epoch 00030: saving model to training_2/cp-0030.ckpt
Epoch 00035: saving model to training_2/cp-0035.ckpt
Epoch 00040: saving model to training_2/cp-0040.ckpt
Epoch 00045: saving model to training_2/cp-0045.ckpt
Epoch 00050: saving model to training_2/cp-0050.ckpt
# latest checkpoint: training_1\cp.ckpt
  32/1000 [..............................] - ETA: 3s
1000/1000 [==============================] - 0s 140us/step
# Untrained model4, accuracy: 86.40%
  32/1000 [..............................] - ETA: 2s
1000/1000 [==============================] - 0s 110us/step
# Restored model4, accuracy: 86.40%
  32/1000 [..............................] - ETA: 5s
1000/1000 [==============================] - 0s 220us/step
# Restored model6, accuracy: 18.20%
  32/1000 [..............................] - ETA: 0s
1000/1000 [==============================] - 0s 40us/step
# Restored model6, accuracy: 87.40%
Epoch 1/5
  32/1000 [..............................] - ETA: 9s - loss: 2.4141 - acc: 0.0625
 320/1000 [========>.....................] - ETA: 0s - loss: 1.8229 - acc: 0.4469
 576/1000 [================>.............] - ETA: 0s - loss: 1.4932 - acc: 0.5694
 864/1000 [========================>.....] - ETA: 0s - loss: 1.2624 - acc: 0.6481
1000/1000 [==============================] - 1s 530us/step - loss: 1.1978 - acc: 0.6620
Epoch 2/5
  32/1000 [..............................] - ETA: 0s - loss: 0.5490 - acc: 0.8750
 320/1000 [========>.....................] - ETA: 0s - loss: 0.4832 - acc: 0.8594
 576/1000 [================>.............] - ETA: 0s - loss: 0.4630 - acc: 0.8715
 864/1000 [========================>.....] - ETA: 0s - loss: 0.4356 - acc: 0.8808
1000/1000 [==============================] - 0s 200us/step - loss: 0.4298 - acc: 0.8790
Epoch 3/5
  32/1000 [..............................] - ETA: 0s - loss: 0.1681 - acc: 0.9688
 320/1000 [========>.....................] - ETA: 0s - loss: 0.2826 - acc: 0.9437
 576/1000 [================>.............] - ETA: 0s - loss: 0.2774 - acc: 0.9340
 832/1000 [=======================>......] - ETA: 0s - loss: 0.2740 - acc: 0.9327
1000/1000 [==============================] - 0s 200us/step - loss: 0.2781 - acc: 0.9280
Epoch 4/5
  32/1000 [..............................] - ETA: 0s - loss: 0.1589 - acc: 0.9688
 288/1000 [=======>......................] - ETA: 0s - loss: 0.2169 - acc: 0.9410
 608/1000 [=================>............] - ETA: 0s - loss: 0.2186 - acc: 0.9457
 864/1000 [========================>.....] - ETA: 0s - loss: 0.2231 - acc: 0.9479
1000/1000 [==============================] - 0s 200us/step - loss: 0.2164 - acc: 0.9480
Epoch 5/5
  32/1000 [..............................] - ETA: 0s - loss: 0.1095 - acc: 1.0000
 352/1000 [=========>....................] - ETA: 0s - loss: 0.1631 - acc: 0.9744
 608/1000 [=================>............] - ETA: 0s - loss: 0.1671 - acc: 0.9638
 864/1000 [========================>.....] - ETA: 0s - loss: 0.1545 - acc: 0.9688
1000/1000 [==============================] - 0s 210us/step - loss: 0.1538 - acc: 0.9670
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_14 (Dense)             (None, 512)               401920
_________________________________________________________________
dropout_7 (Dropout)          (None, 512)               0
_________________________________________________________________
dense_15 (Dense)             (None, 10)                5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
  32/1000 [..............................] - ETA: 3s
1000/1000 [==============================] - 0s 150us/step
Restored model8, accuracy: 86.10%

Process finished with exit code 0
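The parameter counts in the model summary can be checked by hand: a fully connected layer has one weight per input-unit pair plus one bias per unit, and Dropout adds no parameters. A minimal sketch follows; dense_params is a hypothetical helper name used only for this illustration.

```python
def dense_params(inputs, units):
    """Number of trainable values in a dense layer: weights plus biases."""
    return inputs * units + units


hidden = dense_params(784, 512)  # 784 flattened pixels feeding 512 units
output = dense_params(512, 10)   # 512 hidden units feeding 10 classes
total = hidden + output

print(hidden)  # 401920
print(output)  # 5130
print(total)   # 407050
```

These match the 401920, 5130 and 407,050 figures printed by summary().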
Generated files
$ ll training_1
total 1601
-rw-r--r-- 1 anliven 197121      71 May  5 23:36 checkpoint
-rw-r--r-- 1 anliven 197121 1631508 May  5 23:36 cp.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 May  5 23:36 cp.ckpt.index

$ ls -l training_2
total 16001
-rw-r--r-- 1 anliven 197121      81 May  5 23:37 checkpoint
-rw-r--r-- 1 anliven 197121 1631508 May  5 23:36 cp-0005.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 May  5 23:36 cp-0005.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 May  5 23:36 cp-0010.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 May  5 23:36 cp-0010.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 May  5 23:36 cp-0015.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 May  5 23:36 cp-0015.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 May  5 23:36 cp-0020.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 May  5 23:36 cp-0020.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 May  5 23:36 cp-0025.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 May  5 23:36 cp-0025.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 May  5 23:37 cp-0030.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 May  5 23:37 cp-0030.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 May  5 23:37 cp-0035.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 May  5 23:37 cp-0035.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 May  5 23:37 cp-0040.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 May  5 23:37 cp-0040.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 May  5 23:37 cp-0045.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 May  5 23:37 cp-0045.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 May  5 23:37 cp-0050.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 May  5 23:37 cp-0050.ckpt.index

$ ls -l training_3
total 1601
-rw-r--r-- 1 anliven 197121      83 May  5 23:37 checkpoint
-rw-r--r-- 1 anliven 197121 1631517 May  5 23:37 my_checkpoint.data-00000-of-00001
-rw-r--r-- 1 anliven 197121     647 May  5 23:37 my_checkpoint.index

$ ls -l my_model.h5
-rw-r--r-- 1 anliven 197121 4909112 May  5 23:37 my_model.h5
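The file sizes can be related to the model's 407,050 parameters with a little arithmetic. The sketch below assumes the weights are stored as 32-bit floats (4 bytes each); the reading of the HDF5 size is an inference from the numbers, not something the listing itself states.

```python
PARAMS = 407050        # total parameters reported by model.summary()
BYTES_PER_FLOAT32 = 4  # assumption: weights are stored as 32-bit floats

weights_bytes = PARAMS * BYTES_PER_FLOAT32
ckpt_data_bytes = 1631508  # size of cp.ckpt.data-00000-of-00001
h5_bytes = 4909112         # size of my_model.h5

print(weights_bytes)                    # 1628200
print(ckpt_data_bytes - weights_bytes)  # 3308
print(h5_bytes - 3 * weights_bytes)     # 24512
```

Under that assumption, each weights-only checkpoint is the raw weights plus about 3 KB of overhead. The HDF5 file is roughly three times the raw weight size, which is consistent with the full model also storing the Adam optimizer's two extra values per weight.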
Troubleshooting
Problem: the following warning appears.
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x00000280FD318780>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
Resolution:
This warning is expected. It only means that the optimizer state is not stored with the weights, which does not affect this script's execution or results, so it can be ignored for now.
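If the warning is only noise, one option is to raise the threshold of the Python logger that emits it. The following is a minimal sketch, assuming TensorFlow's Python-side messages go through the standard logging logger named "tensorflow", as the "WARNING:tensorflow:" prefix suggests. In a real script it is safer to run this after importing tensorflow, in case the import configures that logger itself.

```python
import logging

# Assumption: the "WARNING:tensorflow:" prefix comes from the standard
# `logging` logger named "tensorflow".
tf_logger = logging.getLogger("tensorflow")
tf_logger.setLevel(logging.ERROR)  # hide WARNING and below, keep ERROR and above

print(tf_logger.isEnabledFor(logging.WARNING))  # False
print(tf_logger.isEnabledFor(logging.ERROR))    # True
```

Hiding warnings this way also hides other, possibly useful, TensorFlow warnings, so it is a trade-off rather than a fix.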
Original post: https://www.cnblogs.com/anliven/p/10817233.html
Time: 2024-10-09 04:42:47