时间:2021-05-22
ImageDataGenerator位于keras.preprocessing.image模块当中,可用于做数据增强,或者仅仅用于一个批次一个批次的读进图片数据.一开始以为ImageDataGenerator是用来做数据增强的,但我的目的只是想一个batch一个batch的读进图片而已,所以一开始没用它,后来发现它是有这个功能的,而且使用起来很方便.
ImageDataGenerator类包含了如下参数:(keras中文教程)
ImageDataGenerator(featurewise_center=False, #布尔值。将输入数据的均值设置为 0,逐特征进行 samplewise_center=False, #布尔值。将每个样本的均值设置为 0 featurewise_std_normalization=False, #布尔值。将输入除以数据标准差,逐特征进行 samplewise_std_normalization=False, #布尔值。将每个输入除以其标准差 zca_whitening=False, #是否进行ZAC白化 zca_epsilon=1e-06, #ZCA 白化的 epsilon 值 rotation_range=0, #整数。随机旋转的度数范围 width_shift_range=0.0, height_shift_range=0.0, brightness_range=None, shear_range=0.0, #浮点数。剪切强度(以弧度逆时针方向剪切角度) zoom_range=0.0, #浮点数 或 [lower, upper]。随机缩放范围。如果是浮点数,[lower, upper] = [1-zoom_range, 1+zoom_range]。 channel_shift_range=0.0, #浮点数。随机通道转换的范围 fill_mode='nearest', #输入边界以外的点的模式填充 cval=0.0, #当 fill_mode = "constant",边界点的填充值 horizontal_flip=False, #随机水平翻转 vertical_flip=False, #随机垂直翻 rescale=None, #默认为 None。如果是 None 或 0,不进行缩放,否则将数据乘以所提供的值(在应用任何其他转换之前) preprocessing_function=None, #应用于每个输入的函数。这个函数会在任何其他改变之前运行。这个函数需要一个参数:一张图像(秩为 3 的 Numpy 张量),并且应该输出一个同尺寸的 Numpy 张量。 data_format=None, #图像数据格式,{"channels_first", "channels_last"} 之一 validation_split=0.0, dtype=None) #生成数组使用的数据类型虽然包含了很多参数,但实际应用时用到的并不会很多,假设我的目的只是一个batch一个batch的读进图片,那么,我在实例化对象的时候什么参数都不需要设置,然后再调用ImageDataGenerator类的成员函数flow_from_directory()就可以从目录中读图.
我放图片的目录如下图,在train文件夹中包含了两个子文件夹,然后在两个子文件夹里面分别包含了猫和狗的图片.
先看看flow_from_directory()的参数.需要注意的是,第一个参数directory不是图片的路径,而是子文件夹的路径,还有就是第四个参数classes,它填写是子文件夹的名称,比如此处的为['cat', 'dog'],然后该函数就会自动把两个子文件夹看成是2个类别,cat文件夹里面所有图片的标签都为0,dog文件夹里面所有图片的标签都为1.而且可以通过设置第5个参数class_mode把标签设置为ont-hot形式(默认的categorical就是one-hot 形式).可以看出,这个函数有多方便,直接把标签和原图对应起来了.
def flow_from_directory(self, directory, #子文件夹所在的目录 target_size=(256, 256), #输出的图片的尺寸 color_mode='rgb', #单通道还是三通道 classes=None, #类别,有多少个子文件夹就有多少个类别,填写的是子文件夹的名称 class_mode='categorical', #通常默认,表示标签采用one-hot形式, batch_size=32, shuffle=True, #是否随机打乱顺序 seed=None, save_to_dir=None, #把图片保存,输入的是路径 save_prefix='', #图像前缀名, save_format='png', #图像后缀名 follow_links=False, subset=None, interpolation='nearest')接下来看一个例子,部分代码.
from tensorflow.keras.preprocessing.image import ImageDataGenerator #我是直接装tensorflow,然后使用里面的keras的, #实例化对象datagendatagen=ImageDataGenerator() #读训练集图片train_generator = datagen.flow_from_directory( '/home/hky/folder/kaggle/DataGenerator/train', classes=['cat','dog'], target_size=(227, 227), class_mode='categorical', batch_size=batch_size) #读验证集图片validation_generator = datagen.flow_from_directory( '/home/hky/folder/kaggle/DataGenerator/validation', classes=['cat','dog'], target_size=(227, 227), class_mode='categorical', batch_size=batch_size) '''开始训练'''#steps_per_epoch是为了判断是否完成了一个epoch,这里我训练集有20000张图片,然后batch_size=16,所以是10000/16#同样,validation_steps=2496/16是因为我的验证集有2496张图片model.fit_generator(generator=train_generator,steps_per_epoch=20000/16,epochs=10,validation_data=validation_generator,validation_steps=2496/16)下面是完整代码,实现了一个AlexNet模型.
import tensorflow as tffrom tensorflow.keras.models import Sequentialfrom tensorflow.keras.preprocessing import imagefrom tensorflow.keras.callbacks import EarlyStoppingfrom tensorflow.keras import optimizersfrom tensorflow.keras.preprocessing.image import ImageDataGeneratorimport numpy as npimport cv2import yamlfrom tensorflow.keras.models import model_from_yaml batch_size = 16 ''' 搭建模型'''l=tf.keras.layersmodel=Sequential() #第一层卷积和池化model.add(l.Conv2D(filters=96,kernel_size=(11,11),strides=(4,4),padding='valid',input_shape=(227,227,3),activation='relu'))model.add(l.BatchNormalization())model.add(l.MaxPooling2D(pool_size=(3,3),strides=(2,2),padding='valid')) #第二层卷积和池化model.add(l.Conv2D(256,(5,5),(1,1),padding='same',activation='relu'))model.add(l.BatchNormalization())model.add(l.MaxPooling2D((3,3),(2,2),padding='valid')) #第三层卷积model.add(l.Conv2D(384,(3,3),(1,1),'same',activation='relu')) #第四层卷积model.add(l.Conv2D(384,(3,3),(1,1),'same',activation='relu')) #第五层卷积和池化model.add(l.Conv2D(256,(3,3),(1,1),'same',activation='relu'))model.add(l.MaxPooling2D((3,3),(2,2),'valid')) #全连接层model.add(l.Flatten())model.add(l.Dense(4096,activation='relu'))model.add(l.Dropout(0.5)) model.add(l.Dense(4096,activation='relu'))model.add(l.Dropout(0.5)) model.add(l.Dense(1000,activation='relu'))model.add(l.Dropout(0.5)) #输出层model.add(l.Dense(2,activation='softmax'))model.compile(optimizer='sgd',loss='categorical_crossentropy',metrics=['accuracy']) '''导入图片数据'''#利用ImageDataGenerator生成一个batch一个batch的数据 datagen=ImageDataGenerator(samplewise_center=True,rescale=1.0/255) #samplewise_center:使输入数据的每个样本均值为0,rescale:归一化train_generator = datagen.flow_from_directory( '/home/hky/folder/kaggle/DataGenerator/train', classes=['cat','dog'], target_size=(227, 227), class_mode='categorical', batch_size=batch_size) validation_generator = datagen.flow_from_directory( '/home/hky/folder/kaggle/DataGenerator/validation', classes=['cat','dog'], target_size=(227, 227), class_mode='categorical', batch_size=batch_size) '''开始训练'''model.fit_generator(generator=train_generator,steps_per_epoch=20000/16,epochs=10,validation_data=validation_generator,validation_steps=2496/16) yaml_string = model.to_yaml() # 保存模型结构到yaml文件open('./model_architecture.yaml', 'w').write(yaml_string)model.save_weights('./AlexNet_model.h5') #保存模型参数 '''导入模型'''#model = model_from_yaml(open('./model_architecture.yaml').read())#model.load_weights('./AlexNet_model.h5') '''随便输入一张图片测试一下'''imgs=[]img=cv2.imread('/home/hky/folder/kaggle/test/120.jpg')img=cv2.resize(img,(227,227))imgs.append(img)a=np.array(imgs) result=model.predict(a)idx=np.argmax(result) if idx==0: print('the image is cat\n')else: print('the image is dog\n') cv2.imshow("image",img)cv2.waitKey(0)以上这篇使用Keras中的ImageDataGenerator进行批次读图方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。
fit_generator是keras提供的用来进行批次训练的函数,使用方法如下:model.fit_generator(generator,steps_per
最近在工作中进行了NLP的内容,使用的还是Keras中embedding的词嵌入来做的。Keras中embedding层做一下介绍。中文文档地址:https:/
问题我们使用anoconda创建envs环境下的Tensorflow-gpu版的,但是当我们在Pycharm设置里的工程中安装Keras后,发现调用keras无
前言在使用keras时候报错Keyerror‘acc',这是一个keras版本问题,acc和accuracy本意是一样的,但是不同keras版本使用不同命名,因
使用keras实现性别识别,模型数据使用的是oarriaga/face_classification的模型实现效果准备工作在开始之前先要安装keras和tens