时间:2021-05-22
直接看代码例子,有详细注释!!
import tensorflow as tfimport numpy as npd = np.arange(0,60).reshape([6, 10])# 将array转化为tensordata = tf.data.Dataset.from_tensor_slices(d)# 从data数据集中按顺序抽取buffer_size个样本放在buffer中,然后打乱buffer中的样本# buffer中样本个数不足buffer_size,继续从data数据集中安顺序填充至buffer_size,# 此时会再次打乱data = data.shuffle(buffer_size=3)# 每次从buffer中抽取4个样本data = data.batch(4)# 将data数据集重复,其实就是2个epoch数据集data = data.repeat(2)# 构造获取数据的迭代器iters = data.make_one_shot_iterator()# 每次从迭代器中获取一批数据batch = iters.get_next()sess = tf.Session()sess.run(batch)# 数据集完成遍历完之后,继续抽取的话会报错:OutOfRangeErrorIn [21]: dOut[21]: array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], [30, 31, 32, 33, 34, 35, 36, 37, 38, 39], [40, 41, 42, 43, 44, 45, 46, 47, 48, 49], [50, 51, 52, 53, 54, 55, 56, 57, 58, 59]])In [22]: sess.run(batch)Out[22]: array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [30, 31, 32, 33, 34, 35, 36, 37, 38, 39], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]])In [23]: sess.run(batch)Out[23]: array([[40, 41, 42, 43, 44, 45, 46, 47, 48, 49], [50, 51, 52, 53, 54, 55, 56, 57, 58, 59]])从输出结果可以看出:
shuffle是按顺序将数据放入buffer里面的;
当repeat函数在shuffle之后的话,是将一个epoch的数据集抽取完毕,再进行下一个epoch的。
那么,当repeat函数在shuffle之前会怎么样呢?如下:
data = data.repeat(2)data = data.shuffle(buffer_size=3)data = data.batch(4)In [25]: sess.run(batch)Out[25]: array([[10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])In [26]: sess.run(batch)Out[26]: array([[50, 51, 52, 53, 54, 55, 56, 57, 58, 59], [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [30, 31, 32, 33, 34, 35, 36, 37, 38, 39], [30, 31, 32, 33, 34, 35, 36, 37, 38, 39]])In [27]: sess.run(batch)Out[27]: array([[10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [50, 51, 52, 53, 54, 55, 56, 57, 58, 59], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], [40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])可以看出,其实它就是先将数据集复制一遍,然后把两个epoch当成同一个新的数据集,一直shuffle和batch下去。
以上这篇TensorFlow dataset.shuffle、batch、repeat的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。
DataLoader完整的参数表如下:classtorch.utils.data.DataLoader(dataset,batch_size=1,shuffle
1.作用dataset.shuffle作用是将数据进行打乱操作,传入参数为buffer_size,改参数为设置“打乱缓存区大小”,也就是说程序会维持一个buff
batch很好理解,就是batchsize。注意在一个epoch中最后一个batch大小可能小于等于batchsizedataset.repeat就是俗称epo
今天踩过的两个小坑:一.用random的shuffle打乱数据集中的数据-标签对index=[iforiinrange(len(X_batch))]#print
numpy.random.shuffle在做将caffe模型和预训练的参数转化为tensorflow的模型和预训练的参数,以便微调,遇到如下函数:defgen_