tf.train.string_input_producer()

处理从文件中读数据

官方说明

简单使用

示例中读取的是csv文件,如果要读tfrecord的文件,需要换成 tf.TFRecordReader

import tensorflow as tf
filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])

reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# Default values, in case of empty columns. Also specifies the type of the decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults)
features = tf.stack([col1, col2, col3, col4])

with tf.Session() as sess:
    # Start populating the filename queue.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(12):
        # Retrieve a single instance:
        example, label = sess.run([features, col5])
        print(example, label)

    coord.request_stop()
    coord.join(threads)

运行结果:

结合批处理

import tensorflow as tf
def read_my_file_format(filename_queue):
#     reader = tf.SomeReader()
    reader = tf.TextLineReader()
    key, record_string = reader.read(filename_queue)
#     example, label = tf.some_decoder(record_string)
    record_defaults = [[1], [1], [1], [1], [1]]
    col1, col2, col3, col4, col5 = tf.decode_csv(record_string, record_defaults=record_defaults)
#     processed_example = some_processing(example)
    features = tf.stack([col1, col2, col3, col4])
    return features, col5

def input_pipeline(filenames, batch_size, num_epochs=None):
    filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True)
    example, label = read_my_file_format(filename_queue)
    #   min_after_dequeue + (num_threads + a small safety margin) * batch_size
    min_after_dequeue = 100
    capacity = min_after_dequeue + 3 * batch_size
    example_batch, label_batch = tf.train.shuffle_batch([example, label], batch_size=batch_size, capacity=capacity,
                              min_after_dequeue=min_after_dequeue)
    return example_batch, label_batch

x,y = input_pipeline(["file0.csv", "file1.csv"],5,4)

sess = tf.Session()
sess.run([tf.global_variables_initializer(),tf.initialize_local_variables()])

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    print("in try")
    while not coord.should_stop():
        # Run training steps or whatever
        example, label = sess.run([x,y])
        print(example, label)
        print("ssss")

except tf.errors.OutOfRangeError:
    print (‘Done training -- epoch limit reached‘)
finally:
    # When done, ask the threads to stop.
    coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()

运行结果:

原文地址:https://www.cnblogs.com/helloworld0604/p/10044748.html

时间: 2024-10-08 17:37:46

tf.train.string_input_producer()的相关文章

tensorflow API _ 3 (tf.train.polynomial_decay)

学习率的三种调整方式:固定的,指数的,多项式的 def _configure_learning_rate(num_samples_per_epoch, global_step): """Configures the learning rate. Args: num_samples_per_epoch: The number of samples in each epoch of training. global_step: The global_step tensor. Re

跟我学算法- tensorflow模型的保存与读取 tf.train.Saver()

save =  tf.train.Saver() 通过save. save() 实现数据的加载 通过save.restore() 实现数据的导出 第一步: 数据的载入 import tensorflow as tf #创建变量 v1 = tf.Variable(tf.random_normal([1, 2], name='v1')) v2 = tf.Variable(tf.random_normal([2, 3], name='v2')) #初始化变量 init_op = tf.global_v

TensorFlow 中的 tf.train.exponential_decay() 指数衰减法

exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False, name=None) 使用方式为 tf.train.exponential_decay( ) 在 Tensorflow 中,exponential_decay()是应用于学习率的指数衰减函数. 在训练模型时,通常建议随着训练的进行逐步降低学习率.该函数需要`global_step`值来计算衰减的学习速率. 该函数返回衰减后

图融合之加载子图:Tensorflow.contrib.slim与tf.train.Saver之坑

import tensorflow as tf import tensorflow.contrib.slim as slim import rawpy import numpy as np import tensorflow as tf import struct import glob import os from PIL import Image import time __sony__ = 0 __huawei__ = 1 __blackberry__ = 2 __stage_raw2ra

tensorflow-训练检查点tf.train.Saver

#!/usr/bin/env python2 # -*- coding: utf-8 -*- """ Created on Thu Sep 6 10:16:37 2018 @author: myhaspl @email:[email protected] """ import tensorflow as tf g1=tf.Graph() with g1.as_default(): with tf.name_scope("input_Va

机器学习与Tensorflow(7)——tf.train.Saver()、inception-v3的应用

1. tf.train.Saver() tf.train.Saver()是一个类,提供了变量.模型(也称图Graph)的保存和恢复模型方法. TensorFlow是通过构造Graph的方式进行深度学习,任何操作(如卷积.池化等)都需要operator,保存和恢复操作也不例外. 在tf.train.Saver()类初始化时,用于保存和恢复的save和restore operator会被加入Graph.所以,下列类初始化操作应在搭建Graph时完成. saver = tf.train.Saver()

TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model

TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-in solution for a very common Tensorflow use-case: keeping track of the best model checkpoints during training. The BestCheckpointSaver is a wrapper arou

TensorFlow 实战(二)—— tf train(优化算法)

Training | TensorFlow tf 下以大写字母开头的含义为名词的一般表示一个类(class) 1. 优化器(optimizer) 优化器的基类(Optimizer base class)主要实现了两个接口,一是计算损失函数的梯度,二是将梯度作用于变量.tf.train 主要提供了如下的优化函数: tf.train.Optimizer tf.train.GradientDescentOptimizer tf.train.AdadeltaOpzimizer Ada delta tf.

tf.train.Saver()-tensorflow中模型的保存及读取

作用:训练网络之后保存训练好的模型,以及在程序中读取已保存好的模型 使用步骤: 实例化一个Saver对象 saver = tf.train.Saver() 在训练过程中,定期调用saver.save方法,像文件夹中写入包含当前模型中所有可训练变量的checkpoint文件 saver.save(sess,FLAGG.train_dir,global_step=step) 之后可以使用saver.restore()方法,重载模型的参数,继续训练或者用于测试数据 saver.restore(sess