时间:2021-05-22
对于使用已经训练好的模型,比如VGG,RESNET等,keras都自带了一个keras.applications.imagenet_utils.decode_predictions的方法,有很多限制:
def decode_predictions(preds, top=5): """Decodes the prediction of an ImageNet model. # Arguments preds: Numpy tensor encoding a batch of predictions. top: Integer, how many top-guesses to return. # Returns A list of lists of top class prediction tuples `(class_name, class_description, score)`. One list of tuples per sample in batch input. # Raises ValueError: In case of invalid shape of the `pred` array (must be 2D). """ global CLASS_INDEX if len(preds.shape) != 2 or preds.shape[1] != 1000: raise ValueError('`decode_predictions` expects ' 'a batch of predictions ' '(i.e. a 2D array of shape (samples, 1000)). ' 'Found array with shape: ' + str(preds.shape)) if CLASS_INDEX is None: fpath = get_file('imagenet_class_index.json', CLASS_INDEX_PATH, cache_subdir='models', file_hash='c2c37ea517e94d9795004a39431a14cb') with open(fpath) as f: CLASS_INDEX = json.load(f) results = [] for pred in preds: top_indices = pred.argsort()[-top:][::-1] result = [tuple(CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices] result.sort(key=lambda x: x[2], reverse=True) results.append(result) return results把重要的东西挖出来,然后自己敲,这样就OK了,下例以MNIST数据集为例:
import kerasfrom keras.models import Sequentialfrom keras.layers import Denseimport numpy as npimport tflearnimport tflearn.datasets.mnist as mnistdef decode_predictions_custom(preds, top=5): CLASS_CUSTOM = ["0","1","2","3","4","5","6","7","8","9"] results = [] for pred in preds: top_indices = pred.argsort()[-top:][::-1] result = [tuple(CLASS_CUSTOM[i]) + (pred[i]*100,) for i in top_indices] results.append(result) return resultsx_train, y_train, x_test, y_test = mnist.load_data(one_hot=True)model = Sequential()model.add(Dense(units=64, activation='relu', input_dim=784))model.add(Dense(units=10, activation='softmax'))model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])model.fit(x_train, y_train, epochs=10, batch_size=128)# score = model.evaluate(x_test, y_test, batch_size=128)# print(score)preds = model.predict(x_test[0:1,:])p = decode_predictions_custom(preds)for (i,(label,prob)) in enumerate(p[0]): print("{}. {}: {:.2f}%".format(i+1, label,prob)) # 1. 7: 99.43%# 2. 9: 0.24%# 3. 3: 0.23%# 4. 0: 0.05%# 5. 2: 0.03%补充知识:keras简单的去噪自编码器代码和各种类型自编码器代码
我就废话不多说了,大家还是直接看代码吧~
start = time() from keras.models import Sequentialfrom keras.layers import Dense, Dropout,Inputfrom keras.layers import Embeddingfrom keras.layers import Conv1D, GlobalAveragePooling1D, MaxPooling1Dfrom keras import layersfrom keras.models import Model # Parameters for denoising autoencodernb_visible = 120nb_hidden = 64batch_size = 16# Build autoencoder modelinput_img = Input(shape=(nb_visible,)) encoded = Dense(nb_hidden, activation='relu')(input_img)decoded = Dense(nb_visible, activation='sigmoid')(encoded) autoencoder = Model(input=input_img, output=decoded)autoencoder.compile(loss='mean_squared_error',optimizer='adam',metrics=['mae'])autoencoder.summary() # Train### 加一个early_stoopingimport keras early_stopping = keras.callbacks.EarlyStopping( monitor='val_loss', min_delta=0.0001, patience=5, verbose=0, mode='auto')autoencoder.fit(X_train_np, y_train_np, nb_epoch=50, batch_size=batch_size , shuffle=True, callbacks = [early_stopping],verbose = 1,validation_data=(X_test_np, y_test_np))# Evaluateevaluation = autoencoder.evaluate(X_test_np, y_test_np, batch_size=batch_size , verbose=1)print('val_loss: %.6f, val_mean_absolute_error: %.6f' % (evaluation[0], evaluation[1])) end = time()print('耗时:'+str((end-start)/60))keras各种自编码代码
以上这篇keras topN显示,自编写代码案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。
Keras是一个用Python编写的高级神经网络API,它能够以TensorFlow,CNTK,或者Theano作为后端运行。Keras的开发重点是支持快速的实
利用Splinter开发浏览器自动化操作,编写代码比较简单。案例一:fromsplinterimportBrowserwithBrowser()asbrowse
我就废话不多说了,大家还是直接看代码吧!model=keras.models.Sequential([#卷积层1keras.layers.Conv2D(32,k
最近每天都在空闲时间努力编写Apworks框架的案例代码WeText。在文本发布和处理微服务中,我打算使用微软的SQLServerforLinux来做演示,于是
MySQL分组排序求TopN表结构按照grp分组,按照num排序,每组取Top3,输出结果如下:源代码:SELECT*FROMscoreASt3WHERE(SE