吴裕雄--天生自然python Google深度学习框架:Tensorflow实现迁移学习

import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
import tensorflow.contrib.slim as slim

# 加载通过TensorFlow-Slim定义好的inception_v3模型。
import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3

# 处理好之后的数据文件。
INPUT_DATA = ‘../../datasets/flower_processed_data.npy‘
# 保存训练好的模型的路径。
TRAIN_FILE = ‘train_dir/model‘
# 谷歌提供的训练好的模型文件地址。因为GitHub无法保存大于100M的文件,所以
# 在运行时需要先自行从Google下载inception_v3.ckpt文件。
CKPT_FILE = ‘../../datasets/inception_v3.ckpt‘

# 定义训练中使用的参数。
LEARNING_RATE = 0.0001
STEPS = 300
BATCH = 32
N_CLASSES = 5

# 不需要从谷歌训练好的模型中加载的参数。
CHECKPOINT_EXCLUDE_SCOPES = ‘InceptionV3/Logits,InceptionV3/AuxLogits‘
# 需要训练的网络层参数明层,在fine-tuning的过程中就是最后的全联接层。
TRAINABLE_SCOPES=‘InceptionV3/Logits,InceptionV3/AuxLogit‘
def get_tuned_variables():
    exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(‘,‘)]

    variables_to_restore = []
    # 枚举inception-v3模型中所有的参数,然后判断是否需要从加载列表中移除。
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        if not excluded:
            variables_to_restore.append(var)
    return variables_to_restore
def get_trainable_variables():
    scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(‘,‘)]
    variables_to_train = []

    # 枚举所有需要训练的参数前缀,并通过这些前缀找到所有需要训练的参数。
    for scope in scopes:
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        variables_to_train.extend(variables)
    return variables_to_train
def main():
    # 加载预处理好的数据。
    processed_data = np.load(INPUT_DATA)
    training_images = processed_data[0]
    n_training_example = len(training_images)
    training_labels = processed_data[1]

    validation_images = processed_data[2]
    validation_labels = processed_data[3]

    testing_images = processed_data[4]
    testing_labels = processed_data[5]
    print("%d training examples, %d validation examples and %d testing examples." % (
        n_training_example, len(validation_labels), len(testing_labels)))

    # 定义inception-v3的输入,images为输入图片,labels为每一张图片对应的标签。
    images = tf.placeholder(tf.float32, [None, 299, 299, 3], name=‘input_images‘)
    labels = tf.placeholder(tf.int64, [None], name=‘labels‘)

    # 定义inception-v3模型。因为谷歌给出的只有模型参数取值,所以这里
    # 需要在这个代码中定义inception-v3的模型结构。虽然理论上需要区分训练和
    # 测试中使用到的模型,也就是说在测试时应该使用is_training=False,但是
    # 因为预先训练好的inception-v3模型中使用的batch normalization参数与
    # 新的数据会有出入,所以这里直接使用同一个模型来做测试。
    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        logits, _ = inception_v3.inception_v3(
            images, num_classes=N_CLASSES, is_training=True)

    trainable_variables = get_trainable_variables()
    # 定义损失函数和训练过程。
    tf.losses.softmax_cross_entropy(
        tf.one_hot(labels, N_CLASSES), logits, weights=1.0)
    total_loss = tf.losses.get_total_loss()
    train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(total_loss)

    # 计算正确率。
    with tf.name_scope(‘evaluation‘):
        correct_prediction = tf.equal(tf.argmax(logits, 1), labels)
        evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # 定义加载Google训练好的Inception-v3模型的Saver。
    load_fn = slim.assign_from_checkpoint_fn(
      CKPT_FILE,
      get_tuned_variables(),
      ignore_missing_vars=True)

    # 定义保存新模型的Saver。
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # 初始化没有加载进来的变量。
        init = tf.global_variables_initializer()
        sess.run(init)

        # 加载谷歌已经训练好的模型。
        print(‘Loading tuned variables from %s‘ % CKPT_FILE)
        load_fn(sess)

        start = 0
        end = BATCH
        for i in range(STEPS):
            _, loss = sess.run([train_step, total_loss], feed_dict={
                images: training_images[start:end],
                labels: training_labels[start:end]})

            if i % 30 == 0 or i + 1 == STEPS:
                saver.save(sess, TRAIN_FILE, global_step=i)

                validation_accuracy = sess.run(evaluation_step, feed_dict={
                    images: validation_images, labels: validation_labels})
                print(‘Step %d: Training loss is %.1f Validation accuracy = %.1f%%‘ % (
                    i, loss, validation_accuracy * 100.0))

            start = end
            if start == n_training_example:
                start = 0

            end = start + BATCH
            if end > n_training_example:
                end = n_training_example

        # 在最后的测试数据上测试正确率。
        test_accuracy = sess.run(evaluation_step, feed_dict={
            images: testing_images, labels: testing_labels})
        print(‘Final test accuracy = %.1f%%‘ % (test_accuracy * 100))

原文地址:https://www.cnblogs.com/tszr/p/12051611.html

时间: 2024-10-10 00:30:49

吴裕雄--天生自然python Google深度学习框架:Tensorflow实现迁移学习的相关文章

吴裕雄--天生自然python Google深度学习框架:经典卷积神经网络模型

import tensorflow as tf INPUT_NODE = 784 OUTPUT_NODE = 10 IMAGE_SIZE = 28 NUM_CHANNELS = 1 NUM_LABELS = 10 CONV1_DEEP = 32 CONV1_SIZE = 5 CONV2_DEEP = 64 CONV2_SIZE = 5 FC_SIZE = 512 def inference(input_tensor, train, regularizer): with tf.variable_s

吴裕雄--天生自然python Google深度学习框架:图像识别与卷积神经网络

原文地址:https://www.cnblogs.com/tszr/p/12051477.html

吴裕雄--天生自然python编程:正则表达式

re.match函数 re.match 尝试从字符串的起始位置匹配一个模式,如果不是起始位置匹配成功的话,match()就返回none. 函数语法: re.match(pattern, string, flags=0) 函数参数说明: 参数 描述 pattern 匹配的正则表达式 string 要匹配的字符串. flags 标志位,用于控制正则表达式的匹配方式,如:是否区分大小写,多行匹配等等. 匹配成功re.match方法返回一个匹配的对象,否则返回None. 我们可以使用group(num)

吴裕雄--天生自然python机器学习:决策树算法

我们经常使用决策树处理分类问题’近来的调查表明决策树也是最经常使用的数据挖掘算法. 它之所以如此流行,一个很重要的原因就是使用者基本上不用了解机器学习算法,也不用深究它 是如何工作的. K-近邻算法可以完成很多分类任务,但是它最大的缺点就是无法给出数据的内 在含义,决策树的主要优势就在于数据形式非常容易理解. 决策树很多任务都 是为了数据中所蕴含的知识信息,因此决策树可以使用不熟悉的数据集合,并从中提取出一系列 规则,机器学习算法最终将使用这些机器从数据集中创造的规则.专家系统中经常使用决策树,

吴裕雄--天生自然python机器学习:朴素贝叶斯算法

分类器有时会产生错误结果,这时可以要求分类器给出一个最优的类别猜测结果,同 时给出这个猜测的概率估计值. 概率论是许多机器学习算法的基础 在计算 特征值取某个值的概率时涉及了一些概率知识,在那里我们先统计特征在数据集中取某个特定值 的次数,然后除以数据集的实例总数,就得到了特征取该值的概率. 首先从一个最简单的概率分类器开始,然后给 出一些假设来学习朴素贝叶斯分类器.我们称之为“朴素”,是因为整个形式化过程只做最原始.最简单的假设. 基于贝叶斯决策理论的分类方法 朴素贝叶斯是贝叶斯决策理论的一部

吴裕雄--天生自然 PYTHON数据分析:糖尿病视网膜病变数据分析(完整版)

# This Python 3 environment comes with many helpful analytics libraries installed # It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python # For example, here's several helpful packages to load in import numpy as np

吴裕雄--天生自然python编程:turtle模块绘图(3)

turtle(海龟)是Python重要的标准库之一,它能够进行基本的图形绘制.turtle图形绘制的概念诞生于1969年,成功应用于LOGO编程语言. turtle库绘制图形有一个基本框架:一个小海龟在坐标系中爬行,其爬行轨迹形成了绘制图形.刚开始绘制时,小海龟位于画布正中央,此处坐标为(0,0),前进方向为水平右方. Python——turtle库 turtle库包含100多个功能函数,主要包括窗体函数.画笔状态函数和画笔运动函数3类. 画笔运动函数 turtle通过一组函数控制画笔的行进动作

吴裕雄--天生自然 PYTHON数据分析:人类发展报告——HDI, GDI,健康,全球人口数据数据分析

import pandas as pd # Data analysis import numpy as np #Data analysis import seaborn as sns # Data visualization import matplotlib.pyplot as plt # Data Visualization import matplotlib.gridspec as gridspec # subplots and grid from wordcloud import Wor

吴裕雄--天生自然 PYTHON语言数据分析:ESA的火星快车操作数据集分析

import os import numpy as np import pandas as pd from datetime import datetime import matplotlib import matplotlib.pyplot as plt import seaborn as sns sns.set_style('white') %matplotlib inline %load_ext autoreload %autoreload 2 def to_utms(ut): retur