在tensorflow中实现去除不足一个batch的数据

时间:2021-05-22

我就废话不多说了,直接上代码吧!

#-*- coding:utf-8 -*-import tensorflow as tfimport numpy as np value1 = tf.placeholder(dtype=tf.float32)value2 = tf.placeholder(dtype=tf.float32)value3 = value1 + value2 #定义的dataset有参数,只能使用参数化迭代器dataset = tf.data.Dataset.range(10)# 定义参数化迭代器dataset = dataset.shuffle(100)dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(3)) #每个batch3个数据,不足3个舍弃iterator = dataset.make_initializable_iterator()next_element = iterator.get_next() with tf.Session() as sess: # 需要用参数初始化迭代器 for i in range(2): sess.run(iterator.initializer) while True: try: value = sess.run(next_element) result = sess.run(value3,feed_dict={value1:value,value2:value}) print(result) except tf.errors.OutOfRangeError: print("End of epoch %d" % i) break

以上这篇在tensorflow中实现去除不足一个batch的数据就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。

声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。

相关文章