时间:2021-05-22
tensorflow官方提供了3种方法来读取数据:
预加载数据(preloaded data):在TensorFlow图中定义常量或变量来保存所有的数据,适用于数据量不太大的情况。填充数据(feeding):通过Python产生数据,然后再把数据填充到后端。
从文件读取数据(reading from file):从文件中直接读取,然后通过队列管理器从文件中读取数据。
本文主要介绍第三种方法,通过tfrecord文件来保存和读取数据,对于前两种读取数据的方式也会进行一个简单的介绍。
项目下载github地址:https://github.com/steelOneself/tensorflow_learn/tree/master/tf_records_writer_read
一、预加载数据
a = tf.constant([1,2,3]) b = tf.constant([4,5,6]) c = tf.add(a,b) with tf.Session() as sess: print(sess.run(c))#[5 7 9]这种方式加载数据比较简单,它是直接将数据嵌入在数据流图中,当训练数据较大时,比较消耗内存。
二、填充数据
通过先定义placeholder然后再通过feed_dict来喂养数据,这种方式在TensorFlow中使用的也是比较多的,但是也存在数据量大时比较消耗内存的缺点,下面介绍一种更高效的数据读取方式,通过tfrecord文件来读取数据。
x = tf.placeholder(tf.int16) y = tf.placeholder(tf.int16) z = tf.add(x,y) with tf.Session() as sess: print(sess.run(z,feed_dict={x:[1,2,3],y:[4,5,6]})) #[5 7 9]三、从文件读取数据
通过slim来实现将图片保存为tfrecord文件和tfrecord文件的读取,slim是基于TensorFlow的一个更高级别的封装模型,通过slim来编程可以实现更高效率和更简洁的代码。
在本次实验中使用的数据集是kaggle的dog vs cat,数据集下载地址:https://mon_queue_min = 24 ) raw_image,img_label = data_provider.get(["image","label"]) #Perform the correct preprocessing for this image depending if it is training or evaluating image = preprocess_image(raw_image, height, width,True) #As for the raw images, we just do a simple reshape to batch it up raw_image = tf.expand_dims(raw_image, 0) raw_image = tf.image.resize_nearest_neighbor(raw_image, [height, width]) raw_image = tf.squeeze(raw_image) #获取一个batch数据 images,raw_image,labels = tf.train.batch( [image,raw_image,img_label], batch_size=batch_size, num_threads=4, capacity=4*batch_size, allow_smaller_final_batch=True ) return images,raw_image,labels
c、读取tfrecord文件
#读取tfrecord文件def read_tfrecord(): #从tfreocrd文件中读取数据 train_dataset = get_dataset_by_tfrecords("train",dataset_dir_path,"catVSdog",2,label_num_to_name) images,raw_images,labels = load_batch("train",train_dataset,batch_size,227,227) with tf.Session() as sess: threads = tf.train.start_queue_runners(sess) for i in range(6): train_img,train_label = sess.run([raw_images,labels]) plt.subplot(2,3,i+1) plt.imshow(np.array(train_img[0])) plt.title("image label:%s"%str(label_num_to_name[train_label[0]])) plt.show()读取训练集的tfrecord文件,只从tfrecord文件中获取了图片数据和图片的标签,images表示的是预处理后的图片,raw_images表示的是没有经过预处理的图片。
以上这篇tensorflow将图片保存为tfrecord和tfrecord的读取方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。
tensorflow中如果要对神经网络模型进行训练,需要把训练数据转换为tfrecord格式才能被读取,tensorflow的model文件里直接提供了相应的脚
TensorFlow提供了一种统一的格式来存储数据,就是TFRecord,它可以统一不同的原始数据格式,并且更加有效地管理不同的属性。TFRecord格式TFR
训练模型时,我们并不是直接将图像送入模型,而是先将图像转换为tfrecord文件,再将tfrecord文件送入模型。为进一步理解tfrecord文件,本例先将6
tensorflow模型保存为saver=tf.train.Saver()函数,saver.save()保存模型,代码如下:importtensorflowas
在上一篇文章tensorflow入门:tfrecord和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecor