时间:2021-05-22
在进行大量数据训练神经网络的时候,可能需要批量读取数据。于是参考了这篇文章的代码,结果发现数据一直批量循环输出,不会在数据的末尾自动停止。
然后发现这篇博文说slice_input_producer()这个函数有一个形参num_epochs,通过设置它的值就可以控制全部数据循环输出几次。
于是我设置之后出现以下的报错:
tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value input_producer/input_producer/limit_epochs/epochs [[Node: input_producer/input_producer/limit_epochs/CountUpTo = CountUpTo[T=DT_INT64, _class=["loc:@input_producer/input_producer/limit_epochs/epochs"], limit=2, _device="/job:localhost/replica:0/task:0/cpu:0"](input_producer/input_producer/limit_epochs/epochs)]]找了好久,都不知道为什么会错,于是只好去看看slice_input_producer()函数的源码,结果在源码中发现作者说这个num_epochs如果不是空的话,就是一个局部变量,需要先调用global_variables_initializer()函数初始化。
于是我调用了之后,一切就正常了,特此记录下来,希望其他人遇到的时候能够及时找到原因。
哈哈,这是笔者第一次通过阅读源码解决了问题,心情还是有点小激动。啊啊,扯远了,上最终成功的代码:
import pandas as pdimport numpy as npimport tensorflow as tfdef generate_data(): num = 25 label = np.asarray(range(0, num)) images = np.random.random([num, 5]) print('label size :{}, image size {}'.format(label.shape, images.shape)) return images,labeldef get_batch_data(): label, images = generate_data() input_queue = tf.train.slice_input_producer([images, label], shuffle=False,num_epochs=2) image_batch, label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=False) return image_batch,label_batchimages,label = get_batch_data()sess = tf.Session()sess.run(tf.global_variables_initializer())sess.run(tf.local_variables_initializer())#就是这一行coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess,coord)try: while not coord.should_stop(): i,l = sess.run([images,label]) print(i) print(l)except tf.errors.OutOfRangeError: print('Done training')finally: coord.request_stop()coord.join(threads)sess.close()以上这篇tensorflow tf.train.batch之数据批量读取方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。
一、TensorFlow模型保存和提取方法1.TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取。tf.train.Saver对
读取tfrecord数据从TFRecords文件中读取数据,首先需要用tf.train.string_input_producer生成一个解析队列。之后调用tf
使用tensorflow训练模型时,我们可以使用tensorflow自带的Save模块tf.train.Saver()来保存模型,使用方式很简单就是在训练完模型
从tensorflow训练后保存的模型中打印训变量:使用tf.train.NewCheckpointReader()importtensorflowastfre
tensorflow模型保存为saver=tf.train.Saver()函数,saver.save()保存模型,代码如下:importtensorflowas