时间:2021-05-23
ImageDataGenerator的参数自己看文档
from keras.preprocessing import imageimport numpy as npX_train=np.ones((3,123,123,1))Y_train=np.array([[1],[2],[2]])generator=image.ImageDataGenerator(featurewise_center=False, samplewise_center=False, featurewise_std_normalization=False, samplewise_std_normalization=False, zca_whitening=False, zca_epsilon=1e-6, rotation_range=180, width_shift_range=0.2, height_shift_range=0.2, shear_range=0, zoom_range=0.001, channel_shift_range=0, fill_mode='nearest', cval=0., horizontal_flip=True, vertical_flip=True, rescale=None, preprocessing_function=None, data_format='channels_last')a=generator.flow(X_train,Y_train,batch_size=20)#生成的是一个迭代器,可直接用于for循环'''batch_size如果小于X的第一维m,next生成的多维矩阵的第一维是为batch_size,输出是从输入中随机选取batch_size个数据batch_size如果大于X的第一维m,next生成的多维矩阵的第一维是m,输出是m个数据,不过顺序随机,输出的X,Y是一一对对应的如果要直接用于tf.placeholder(),要求生成的矩阵和要与tf.placeholder相匹配'''X,Y=next(a)print(Y)X,Y=next(a)print(Y)X,Y=next(a)print(Y)X,Y=next(a)输出
[[2] [1] [2]][[2] [2] [1]][[2] [2] [1]][[2] [2] [1]]补充知识:tensorflow 与keras 混用之坑
在使用tensorflow与keras混用是model.save 是正常的但是在load_model的时候报错了在这里mark 一下
其中错误为:TypeError: tuple indices must be integers, not list
再一一番百度后无结果,上谷歌后找到了类似的问题。但是是一对鸟文不知道什么东西(翻译后发现是俄文)。后来谷歌翻译了一下找到了解决方法。故将原始问题文章贴上来警示一下
原训练代码
from tensorflow.python.keras.preprocessing.image import ImageDataGeneratorfrom tensorflow.python.keras.models import Sequentialfrom tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalizationfrom tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense #Каталог с данными для обученияtrain_dir = 'train'# Каталог с данными для проверкиval_dir = 'val'# Каталог с данными для тестированияtest_dir = 'val' # Размеры изображенияimg_width, img_height = 800, 800# Размерность тензора на основе изображения для входных данных в нейронную сеть# backend Tensorflow, channels_lastinput_shape = (img_width, img_height, 3)# Количество эпохepochs = 1# Размер мини-выборкиbatch_size = 4# Количество изображений для обученияnb_train_samples = 300# Количество изображений для проверкиnb_validation_samples = 25# Количество изображений для тестированияnb_test_samples = 25 model = Sequential() model.add(Conv2D(32, (7, 7), padding="same", input_shape=input_shape))model.add(BatchNormalization())model.add(Activation('tanh'))model.add(MaxPooling2D(pool_size=(10, 10))) model.add(Conv2D(64, (5, 5), padding="same"))model.add(BatchNormalization())model.add(Activation('tanh'))model.add(MaxPooling2D(pool_size=(10, 10))) model.add(Flatten())model.add(Dense(512))model.add(Activation('relu'))model.add(Dropout(0.5))model.add(Dense(10, activation='softmax')) model.compile(loss='categorical_crossentropy', optimizer="Nadam", metrics=['accuracy'])print(model.summary())datagen = ImageDataGenerator(rescale=1. / 255) train_generator = datagen.flow_from_directory( train_dir, target_size=(img_width, img_height), batch_size=batch_size, class_mode='categorical') val_generator = datagen.flow_from_directory( val_dir, target_size=(img_width, img_height), batch_size=batch_size, class_mode='categorical') test_generator = datagen.flow_from_directory( test_dir, target_size=(img_width, img_height), batch_size=batch_size, class_mode='categorical') model.fit_generator( train_generator, steps_per_epoch=nb_train_samples // batch_size, epochs=epochs, validation_data=val_generator, validation_steps=nb_validation_samples // batch_size) print('Сохраняем сеть')model.save("grib.h5")print("Сохранение завершено!")模型载入
from tensorflow.python.keras.preprocessing.image import ImageDataGeneratorfrom tensorflow.python.keras.models import Sequentialfrom tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalizationfrom tensorflow.python.keras.layers import Activation, Dropout, Flatten, Densefrom keras.models import load_model print("Загрузка сети")model = load_model("grib.h5")print("Загрузка завершена!")报错
/usr/bin/python3.5 /home/disk2/py/neroset/do.py/home/mama/.local/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`. from ._conv import register_converters as _register_convertersUsing TensorFlow backend.Загрузка сетиTraceback (most recent call last): File "/home/disk2/py/neroset/do.py", line 13, in <module> model = load_model("grib.h5") File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 243, in load_model model = model_from_config(model_config, custom_objects=custom_objects) File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 317, in model_from_config return layer_module.deserialize(config, custom_objects=custom_objects) File "/usr/local/lib/python3.5/dist-packages/keras/layers/__init__.py", line 55, in deserialize printable_module_name='layer') File "/usr/local/lib/python3.5/dist-packages/keras/utils/generic_utils.py", line 144, in deserialize_keras_object list(custom_objects.items()))) File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 1350, in from_config model.add(layer) File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 492, in add output_tensor = layer(self.outputs[0]) File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 590, in __call__ self.build(input_shapes[0]) File "/usr/local/lib/python3.5/dist-packages/keras/layers/normalization.py", line 92, in build dim = input_shape[self.axis]TypeError: tuple indices must be integers or slices, not list Process finished with exit code 1战斗种族解释
убераю BatchNormalization всё работает хорошо. Не подскажите в чём ошибка?Выяснил что сохранение keras и нормализация tensorflow не работают вместе нужно просто изменить строку импорта.(译文:整理BatchNormalization一切正常。 不要告诉我错误是什么?我发现保存keras和规范化tensorflow不能一起工作;只需更改导入字符串即可。)
强调文本 强调文本
keras.preprocessing.image import ImageDataGeneratorkeras.models import Sequentialkeras.layers import Conv2D, MaxPooling2D, BatchNormalizationkeras.layers import Activation, Dropout, Flatten, Dense##完美解决
##附上原文链接
https://qa-help.ru/questions/keras-batchnormalization
以上这篇keras的ImageDataGenerator和flow()的用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。
mysqldump常用于MySQL数据库逻辑备份。1、各种用法说明A.最简单的用法:mysqldump-uroot-pPassword[databasename
下面介绍jquery字符串切割函数substring的用法 代码如下:jquery字符串切割函数substring的用法说明
1、各种用法说明A.最简单的用法:复制代码代码如下:mysqldump-uroot-pPassword[databasename]>[dumpfile]上述命令
mysqldump常用于MySQL数据库逻辑备份。1、各种用法说明A.最简单的用法:mysqldump-uroot-pPassword[databasename
本文总结了PHP数组相关的函数。分享给大家供大家参考。具体如下:这里包括函数名和用法说明,没有详细的代码范例。感兴趣的朋友可以查阅本站相关的函数用法。数组的相关