【tf.keras】tf.keras模型复现

keras 构建模型很简单,上手很方便,同时又是 tensorflow 的高级 API,所以学学也挺好。

模型复现在我们的实验中也挺重要的,跑出了一个模型,虽然我们可以将模型的 checkpoint 保存,但再跑一遍,怎么都得不到相同的结果,对我而言这是不能接受的。

用 keras 实现模型,想要能够复现,需要将设置各个可能的随机过程的 seed;而且,代码不要在 GPU 上跑,而是在 CPU 上跑。(也就是说,GPU 上得到的 keras 模型没办法再复现。)

我的 tensorflow+keras 版本:

print(tf.VERSION)    # '1.10.0'
print(tf.keras.__version__)    # '2.1.6-tf'

keras 模型可复现的配置:

import numpy as np
import tensorflow as tf
import random as rn

import os
# run on CPU only
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ["PYTHONHASHSEED"] = '0'

# The below is necessary for starting Numpy generated random numbers
# in a well-defined initial state.

np.random.seed(42)

# The below is necessary for starting core Python generated random numbers
# in a well-defined state.

rn.seed(12345)

# Force TensorFlow to use single thread.
# Multiple threads are a potential source of non-reproducible results.
# For further details, see: https://stackoverflow.com/questions/42022950/

session_conf = tf.ConfigProto(intra_op_parallelism_threads=1,
                              inter_op_parallelism_threads=1)

from keras import backend as K

# The below tf.set_random_seed() will make random number generation
# in the TensorFlow backend have a well-defined initial state.
# For further details, see:
# https://www.tensorflow.org/api_docs/python/tf/set_random_seed

tf.set_random_seed(1234)

sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)

# Rest of code follows ...

References

How can I obtain reproducible results using Keras during development? -- Keras Documentation
具有Tensorflow后端的Keras可以随意使用CPU或GPU吗?

原文地址:https://www.cnblogs.com/wuliytTaotao/p/10883749.html

时间: 2024-10-07 14:25:15

【tf.keras】tf.keras模型复现的相关文章

Keras Sequential顺序模型

keras是基于tensorflow封装的的高级API,Keras的优点是可以快速的开发实验,它能够以TensorFlow, CNTK, 或者 Theano 作为后端运行. 模型构建 最简单的模型是 Sequential 顺序模型,它由多个网络层线性堆叠.对于更复杂的结构,你应该使用 Keras 函数式 API,它允许构建任意的神经网络图. 用Keras定义网络模型有两种方式, Sequential 顺序模型 Keras 函数式 API模型 1.Sequential 顺序模型 from kera

Python Keras module 'keras.backend' has no attribute 'image_data_format'

问题: 当使用Keras运行示例程序mnist_cnn时,出现如下错误: 'keras.backend' has no attribute 'image_data_format' 程序路径https://github.com/fchollet/keras/blob/master/examples/mnist_cnn.py 使用的python conda环境是udacity自动驾驶课程的carnd-term1 故障程序段: if K.image_data_format() == 'channels

TF:TF定义两个变量相乘之placeholder先hold类似变量+feed_dict最后外界传入值—Jason niu

#TF:TF定义两个变量相乘之placeholder先hold类似变量+feed_dict最后外界传入值 import tensorflow as tf input1 = tf.placeholder(tf.float32) #TF一般只能处理float32的数据类型 input2 = tf.placeholder(tf.float32) #ouput = tf.mul(input1, input2) ouput = tf.multiply(input1, input2) #定义两个变量相乘 w

tensorflow-底层梯度tf.AggregationMethod,tf.gradients

(1)tf.AggregationMethod是一个类 Class?AggregationMethod类拥有的方法主要用于聚集梯度 ?计算偏导数需要聚集梯度贡献,这个类拥有在计算图中聚集梯度的很多方法.比如: ADD_N: 所有的梯度被求和汇总,使用 "AddN"操作.有一个特点:所有的梯度在聚集之前必须要准备好,DEFAULT: 默认聚集方法类方法 ADD_N DEFAULT EXPERIMENTAL_ACCUMULATE_N EXPERIMENTAL_TREE TensorFlow

TensorFlow 学习(二)—— tf Graph tf Session 与 tf Session run

session: with tf.Session() as sess:/ tf.InteractiveSession() 初始化: tf.global_variables_initializer() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) 1 2 0. tf.Graph 命名空间与 operation name(oper.name 获取操作名): c_0 = tf.constant(0, nam

tf.unstack\tf.unstack

tf.unstack 原型: unstack( value, num=None, axis=0, name='unstack' ) 官方解释:https://tensorflow.google.cn/api_docs/python/tf/unstack 解释:这是一个对矩阵进行分解的函数,以下为关键参数解释: value:代表需要分解的矩阵变量(其实就是一个多维数组,一般为二维): axis:指明对矩阵的哪个维度进行分解. 要理解tf.unstack函数,我们不妨先来看看tf.stack函数.T

论文阅读与模型复现——HAN

论文阅读论文链接:https://arxiv.org/pdf/1903.07293.pdf tensorflow版代码Github链接:https://github.com/Jhy1993/HAN 介绍视频:https://www.bilibili.com/video/av53418944/ 参考博客:https://blog.csdn.net/yyl424525/article/details/103804574 文中提出了一种新的基于注意力机制的异质图神经网络 Heterogeneous G

keras训练cnn模型时loss为nan

1.首先记下来如何解决这个问题的:由于我代码中 model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy']) 即损失函数用的是categorical_crossentropy所以,在pycharm中双击shift键,寻找该函数,会出现keras.loss模块中有该函数,进入该函数后, 原函数为: def categorical_crossentropy(y_true, y_pred):

ROS TF——learning tf

在机器人的控制中,坐标系统是非常重要的,在ROS使用tf软件库进行坐标转换. 相关链接:http://www.ros.org/wiki/tf/Tutorials#Learning_tf 一.tf简介 我们通过一个小小的实例来介绍tf的作用. 1.安装turtle包 $ rosdep install turtle_tf rviz $ rosmake turtle_tf rviz 2.运行demo 运行简单的demo: $ roslaunch turtle_tf turtle_tf_demo.lau