keras topN显示,自编写代码案例


Posted in Python onJuly 03, 2020

对于使用已经训练好的模型,比如VGG,RESNET等,keras都自带了一个keras.applications.imagenet_utils.decode_predictions的方法,有很多限制:

def decode_predictions(preds, top=5):
 """Decodes the prediction of an ImageNet model.

 # Arguments
 preds: Numpy tensor encoding a batch of predictions.
 top: Integer, how many top-guesses to return.

 # Returns
 A list of lists of top class prediction tuples
 `(class_name, class_description, score)`.
 One list of tuples per sample in batch input.

 # Raises
 ValueError: In case of invalid shape of the `pred` array
  (must be 2D).
 """
 global CLASS_INDEX
 if len(preds.shape) != 2 or preds.shape[1] != 1000:
 raise ValueError('`decode_predictions` expects '
    'a batch of predictions '
    '(i.e. a 2D array of shape (samples, 1000)). '
    'Found array with shape: ' + str(preds.shape))
 if CLASS_INDEX is None:
 fpath = get_file('imagenet_class_index.json',
    CLASS_INDEX_PATH,
    cache_subdir='models',
    file_hash='c2c37ea517e94d9795004a39431a14cb')
 with open(fpath) as f:
  CLASS_INDEX = json.load(f)
 results = []
 for pred in preds:
 top_indices = pred.argsort()[-top:][::-1]
 result = [tuple(CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices]
 result.sort(key=lambda x: x[2], reverse=True)
 results.append(result)
 return results

把重要的东西挖出来,然后自己敲,这样就OK了,下例以MNIST数据集为例:

import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
import tflearn
import tflearn.datasets.mnist as mnist

def decode_predictions_custom(preds, top=5):
 CLASS_CUSTOM = ["0","1","2","3","4","5","6","7","8","9"]
 results = []
 for pred in preds:
 top_indices = pred.argsort()[-top:][::-1]
 result = [tuple(CLASS_CUSTOM[i]) + (pred[i]*100,) for i in top_indices]
 results.append(result)
 return results

x_train, y_train, x_test, y_test = mnist.load_data(one_hot=True)

model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=10, activation='softmax'))
model.compile(loss='categorical_crossentropy',
  optimizer='sgd',
  metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, batch_size=128)
# score = model.evaluate(x_test, y_test, batch_size=128)
# print(score)
preds = model.predict(x_test[0:1,:])
p = decode_predictions_custom(preds)
for (i,(label,prob)) in enumerate(p[0]):
 print("{}. {}: {:.2f}%".format(i+1, label,prob)) 
# 1. 7: 99.43%
# 2. 9: 0.24%
# 3. 3: 0.23%
# 4. 0: 0.05%
# 5. 2: 0.03%

补充知识:keras简单的去噪自编码器代码和各种类型自编码器代码

我就废话不多说了,大家还是直接看代码吧~

start = time()
 
from keras.models import Sequential
from keras.layers import Dense, Dropout,Input
from keras.layers import Embedding
from keras.layers import Conv1D, GlobalAveragePooling1D, MaxPooling1D
from keras import layers
from keras.models import Model
 
# Parameters for denoising autoencoder
nb_visible = 120
nb_hidden = 64
batch_size = 16
# Build autoencoder model
input_img = Input(shape=(nb_visible,))
 
encoded = Dense(nb_hidden, activation='relu')(input_img)
decoded = Dense(nb_visible, activation='sigmoid')(encoded)
 
autoencoder = Model(input=input_img, output=decoded)
autoencoder.compile(loss='mean_squared_error',optimizer='adam',metrics=['mae'])
autoencoder.summary()
 
# Train
### 加一个early_stooping
import keras 
 
early_stopping = keras.callbacks.EarlyStopping(
  monitor='val_loss',
  min_delta=0.0001,
  patience=5, 
  verbose=0, 
  mode='auto'
)
autoencoder.fit(X_train_np, y_train_np, nb_epoch=50, batch_size=batch_size , shuffle=True,
        callbacks = [early_stopping],verbose = 1,validation_data=(X_test_np, y_test_np))
# Evaluate
evaluation = autoencoder.evaluate(X_test_np, y_test_np, batch_size=batch_size , verbose=1)
print('val_loss: %.6f, val_mean_absolute_error: %.6f' % (evaluation[0], evaluation[1]))
 
end = time()
print('耗时:'+str((end-start)/60))

keras各种自编码代码

以上这篇keras topN显示,自编写代码案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python操作摄像头截图实现远程监控的例子
Mar 25 Python
浅析Python中的赋值和深浅拷贝
Aug 15 Python
python 查找文件名包含指定字符串的方法
Jun 05 Python
Python学习小技巧总结
Jun 10 Python
浅谈tensorflow中几个随机函数的用法
Jul 27 Python
python 中文件输入输出及os模块对文件系统的操作方法
Aug 27 Python
python 使用plt画图,去除图片四周的白边方法
Jul 09 Python
python使用turtle库绘制奥运五环
Feb 24 Python
浅谈pandas.cut与pandas.qcut的使用方法及区别
Mar 03 Python
pycharm2020.2 配置使用的方法详解
Sep 16 Python
python爬虫中PhantomJS加载页面的实例方法
Nov 12 Python
Python 使用SFTP和FTP实现对服务器的文件下载功能
Dec 17 Python
python如何使用代码运行助手
Jul 03 #Python
Python 3.10 的首个 PEP 诞生,内置类型 zip() 迎来新特性(推荐)
Jul 03 #Python
python3 简单实现组合设计模式
Jul 02 #Python
Django Session和Cookie分别实现记住用户登录状态操作
Jul 02 #Python
django 装饰器 检测登录状态操作
Jul 02 #Python
详解用Python爬虫获取百度企业信用中企业基本信息
Jul 02 #Python
django 实现后台从富文本提取纯文本
Jul 02 #Python
You might like
Discuz!5的PHP代码高亮显示插件(黑暗中的舞者更新)
2007/01/29 PHP
PHP中$_SERVER使用说明
2015/07/05 PHP
Yii2 assets清除缓存的方法
2016/05/16 PHP
php+ajax实现带进度条的上传图片功能【附demo源码下载】
2016/09/14 PHP
PHP处理Ajax请求与Ajax跨域问题
2017/02/13 PHP
php常用字符串查找函数strstr()与strpos()实例分析
2019/06/21 PHP
JQuery 表格操作(交替显示、拖动表格行、选择行等)
2009/07/29 Javascript
javascript测试题练习代码
2012/10/10 Javascript
常见表单重复提交问题整理及解决方法
2013/11/13 Javascript
JS OffsetParent属性深入解析
2014/01/13 Javascript
扒一扒JavaScript 预解释
2015/01/28 Javascript
JS实现随页面滚动显示/隐藏窗口固定位置元素
2016/02/26 Javascript
javascript中闭包概念与用法深入理解
2016/12/15 Javascript
jQuery焦点图轮播效果实现方法
2016/12/19 Javascript
AngularJS路由实现页面跳转实例
2017/03/03 Javascript
Javascript ES6中数据类型Symbol的使用详解
2017/05/02 Javascript
解决Extjs下拉框不显示的问题
2017/06/21 Javascript
微信小程序之蓝牙的链接
2017/09/26 Javascript
Koa2微信公众号开发之本地开发调试环境搭建
2018/05/16 Javascript
小程序组件之仿微信通讯录的实现代码
2018/09/12 Javascript
Nodejs文件上传、监听上传进度的代码
2020/03/27 NodeJs
JS内置对象和Math对象知识点详解
2020/04/03 Javascript
[19:14]DOTA2 HEROS教学视频教你分分钟做大人-维萨吉
2014/06/24 DOTA
Django内容增加富文本功能的实例
2017/10/17 Python
Python开发网站目录扫描器的实现
2019/02/21 Python
python实现在cmd窗口显示彩色文字
2019/06/24 Python
python函数不定长参数使用方法解析
2019/12/14 Python
详解Python 最短匹配模式
2020/07/29 Python
python复合条件下的字典排序
2020/12/18 Python
婚鞋、新娘鞋、礼服鞋、童鞋:Nina Shoes
2019/09/04 全球购物
工作自荐信
2013/12/11 职场文书
办公室秘书自我鉴定
2014/01/18 职场文书
早读迟到检讨书
2014/01/24 职场文书
留学顾问岗位职责
2014/04/14 职场文书
2015年污水处理厂工作总结
2015/05/26 职场文书
用几道面试题来看JavaScript执行机制
2021/04/30 Javascript