tensorflow将图片保存为tfrecord和tfrecord的读取方式

时间:2021-05-22

tensorflow官方提供了3种方法来读取数据:

预加载数据(preloaded data):在TensorFlow图中定义常量或变量来保存所有的数据,适用于数据量不太大的情况。填充数据(feeding):通过Python产生数据,然后再把数据填充到后端。

从文件读取数据(reading from file):从文件中直接读取,然后通过队列管理器从文件中读取数据。

本文主要介绍第三种方法,通过tfrecord文件来保存和读取数据,对于前两种读取数据的方式也会进行一个简单的介绍。

项目下载github地址:https://github.com/steelOneself/tensorflow_learn/tree/master/tf_records_writer_read

一、预加载数据

a = tf.constant([1,2,3]) b = tf.constant([4,5,6]) c = tf.add(a,b) with tf.Session() as sess: print(sess.run(c))#[5 7 9]

这种方式加载数据比较简单,它是直接将数据嵌入在数据流图中,当训练数据较大时,比较消耗内存。

二、填充数据

通过先定义placeholder然后再通过feed_dict来喂养数据,这种方式在TensorFlow中使用的也是比较多的,但是也存在数据量大时比较消耗内存的缺点,下面介绍一种更高效的数据读取方式,通过tfrecord文件来读取数据。

x = tf.placeholder(tf.int16) y = tf.placeholder(tf.int16) z = tf.add(x,y) with tf.Session() as sess: print(sess.run(z,feed_dict={x:[1,2,3],y:[4,5,6]})) #[5 7 9]

三、从文件读取数据

通过slim来实现将图片保存为tfrecord文件和tfrecord文件的读取,slim是基于TensorFlow的一个更高级别的封装模型,通过slim来编程可以实现更高效率和更简洁的代码。

在本次实验中使用的数据集是kaggle的dog vs cat,数据集下载地址:https://mon_queue_min = 24 ) raw_image,img_label = data_provider.get(["image","label"]) #Perform the correct preprocessing for this image depending if it is training or evaluating image = preprocess_image(raw_image, height, width,True) #As for the raw images, we just do a simple reshape to batch it up raw_image = tf.expand_dims(raw_image, 0) raw_image = tf.image.resize_nearest_neighbor(raw_image, [height, width]) raw_image = tf.squeeze(raw_image) #获取一个batch数据 images,raw_image,labels = tf.train.batch( [image,raw_image,img_label], batch_size=batch_size, num_threads=4, capacity=4*batch_size, allow_smaller_final_batch=True ) return images,raw_image,labels

c、读取tfrecord文件

#读取tfrecord文件def read_tfrecord(): #从tfreocrd文件中读取数据 train_dataset = get_dataset_by_tfrecords("train",dataset_dir_path,"catVSdog",2,label_num_to_name) images,raw_images,labels = load_batch("train",train_dataset,batch_size,227,227) with tf.Session() as sess: threads = tf.train.start_queue_runners(sess) for i in range(6): train_img,train_label = sess.run([raw_images,labels]) plt.subplot(2,3,i+1) plt.imshow(np.array(train_img[0])) plt.title("image label:%s"%str(label_num_to_name[train_label[0]])) plt.show()

读取训练集的tfrecord文件,只从tfrecord文件中获取了图片数据和图片的标签,images表示的是预处理后的图片,raw_images表示的是没有经过预处理的图片。

以上这篇tensorflow将图片保存为tfrecord和tfrecord的读取方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。

声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。

相关文章