keras 模型参数,模型保存,中间结果输出操作


Posted in Python onJuly 06, 2020

我就废话不多说了,大家还是直接看代码吧~

'''
Created on 2018-4-16
'''
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.models import Model
from keras.callbacks import ModelCheckpoint,Callback
import numpy as np
import tflearn
import tflearn.datasets.mnist as mnist

x_train, y_train, x_test, y_test = mnist.load_data(one_hot=True)
x_valid = x_test[:5000]
y_valid = y_test[:5000]
x_test = x_test[5000:]
y_test = y_test[5000:]
print(x_valid.shape)
print(x_test.shape)

model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=10, activation='softmax'))
model.compile(loss='categorical_crossentropy',
       optimizer='sgd',
       metrics=['accuracy'])
filepath = 'D:\\machineTest\\model-ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5'
# filepath = 'D:\\machineTest\\model-ep{epoch:03d}-loss{loss:.3f}.h5'
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min')
print(model.get_config())
# [{'class_name': 'Dense', 'config': {'bias_regularizer': None, 'use_bias': True, 'kernel_regularizer': None, 'batch_input_shape': (None, 784), 'trainable': True, 'kernel_constraint': None, 'bias_constraint': None, 'kernel_initializer': {'class_name': 'VarianceScaling', 'config': {'scale': 1.0, 'distribution': 'uniform', 'mode': 'fan_avg', 'seed': None}}, 'activity_regularizer': None, 'units': 64, 'dtype': 'float32', 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, 'activation': 'relu', 'name': 'dense_1'}}, {'class_name': 'Dense', 'config': {'bias_regularizer': None, 'use_bias': True, 'kernel_regularizer': None, 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, 'kernel_constraint': None, 'bias_constraint': None, 'kernel_initializer': {'class_name': 'VarianceScaling', 'config': {'scale': 1.0, 'distribution': 'uniform', 'mode': 'fan_avg', 'seed': None}}, 'activity_regularizer': None, 'trainable': True, 'units': 10, 'activation': 'softmax', 'name': 'dense_2'}}]
# model.fit(x_train, y_train, epochs=1, batch_size=128, callbacks=[checkpoint],validation_data=(x_valid, y_valid))
model.fit(x_train, y_train, epochs=1,validation_data=(x_valid, y_valid),steps_per_epoch=10,validation_steps=1)
# score = model.evaluate(x_test, y_test, batch_size=128)
# print(score)
# #获取模型结构状况
# model.summary()
# _________________________________________________________________
# Layer (type)         Output Shape       Param #  
# =================================================================
# dense_1 (Dense)       (None, 64)        50240(784*64+64(b))   
# _________________________________________________________________
# dense_2 (Dense)       (None, 10)        650(64*10 + 10 )    
# =================================================================
# #根据下标和名称返回层对象
# layer = model.get_layer(index = 0)
# 获取模型权重,设置权重model.set_weights()
weights = np.array(model.get_weights())
print(weights.shape)
# (4,)权重由4部分组成
print(weights[0].shape)
# (784, 64)dense_1 w1
print(weights[1].shape)
# (64,)dense_1 b1
print(weights[2].shape)
# (64, 10)dense_2 w2
print(weights[3].shape)
# (10,)dense_2 b2

# # 保存权重和加载权重
# model.save_weights("D:\\xxx\\weights.h5")
# model.load_weights("D:\\xxx\\weights.h5", by_name=False)#by_name=True,可以根据名字匹配和层载入权重

# 查看中间结果,必须要先声明个函数式模型
dense1_layer_model = Model(inputs=model.input,outputs=model.get_layer('dense_1').output)
out = dense1_layer_model.predict(x_test)
print(out.shape)
# (5000, 64)

# 如果是函数式模型,则可以直接输出
# import keras
# from keras.models import Model
# from keras.callbacks import ModelCheckpoint,Callback
# import numpy as np
# from keras.layers import Input,Conv2D,MaxPooling2D
# import cv2
# 
# image = cv2.imread("D:\\machineTest\\falali.jpg")
# print(image.shape)
# cv2.imshow("1",image)
# 
# # 第一层conv
# image = image.reshape([-1, 386, 580, 3])
# img_input = Input(shape=(386, 580, 3))
# x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(img_input)
# x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
# x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
# model = Model(inputs=img_input, outputs=x)
# out = model.predict(image)
# print(out.shape)
# out = out.reshape(193, 290,64)
# image_conv1 = out[:,:,1].reshape(193, 290)
# image_conv2 = out[:,:,20].reshape(193, 290)
# image_conv3 = out[:,:,40].reshape(193, 290)
# image_conv4 = out[:,:,60].reshape(193, 290)
# cv2.imshow("conv1",image_conv1)
# cv2.imshow("conv2",image_conv2)
# cv2.imshow("conv3",image_conv3)
# cv2.imshow("conv4",image_conv4)
# cv2.waitKey(0)

中间结果输出可以查看conv过之后的图像:

原始图像:

keras 模型参数,模型保存,中间结果输出操作

经过一层conv以后,输出其中4张图片:

keras 模型参数,模型保存,中间结果输出操作

keras 模型参数,模型保存,中间结果输出操作

keras 模型参数,模型保存,中间结果输出操作

keras 模型参数,模型保存,中间结果输出操作

以上这篇keras 模型参数,模型保存,中间结果输出操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python time模块用法实例详解
Sep 11 Python
Ubuntu 14.04+Django 1.7.1+Nginx+uwsgi部署教程
Nov 18 Python
win8下python3.4安装和环境配置图文教程
Jul 31 Python
在numpy矩阵中令小于0的元素改为0的实例
Jan 26 Python
Python 给定的经纬度标注在地图上的实现方法
Jul 05 Python
python 实现生成均匀分布的点
Dec 05 Python
django连接mysql数据库及建表操作实例详解
Dec 10 Python
python numpy 矩阵堆叠实例
Jan 17 Python
新手入门学习python Numpy基础操作
Mar 02 Python
IntelliJ 中配置 Anaconda的过程图解
Jun 01 Python
django序列化时使用外键的真实值操作
Jul 15 Python
python 实现端口扫描工具
Dec 18 Python
Python自省及反射原理实例详解
Jul 06 #Python
如何通过命令行进入python
Jul 06 #Python
解决TensorFlow调用Keras库函数存在的问题
Jul 06 #Python
python else语句在循环中的运用详解
Jul 06 #Python
Keras模型转成tensorflow的.pb操作
Jul 06 #Python
python如何进入交互模式
Jul 06 #Python
python3.4中清屏的处理方法
Jul 06 #Python
You might like
《一拳超人》埼玉一拳下去,他们存在了800年毫无意义!
2020/03/02 日漫
PHP配置心得包含MYSQL5乱码解决
2006/11/20 PHP
php使用smtp发送支持附件的邮件示例
2014/04/13 PHP
ThinkPHP的SAE开发相关注意事项详解
2016/10/09 PHP
浅析php-fpm静态和动态执行方式的比较
2016/11/09 PHP
php使用curl实现ftp文件下载功能
2017/05/16 PHP
基于swoole实现多人聊天室
2018/06/14 PHP
javascript YUI 读码日记之 YAHOO.util.Dom - Part.4
2008/03/22 Javascript
js控制div及网页相关属性的代码
2009/12/19 Javascript
javascript算法题 求任意一个1-9位不重复的N位数在该组合中的大小排列序号
2012/07/21 Javascript
JavaScript建立一个语法高亮输入框实现思路
2013/02/26 Javascript
GridView中获取被点击行中的DropDownList和TextBox中的值
2013/07/18 Javascript
JavaScript实现快速排序的方法
2015/07/31 Javascript
jquery实现Li滚动时滚动条自动添加样式的方法
2015/08/10 Javascript
基于jquery实现的银行卡号每隔4位自动插入空格的实现代码
2016/11/22 Javascript
vue.js 使用v-if v-else发现没有执行解决办法
2017/05/15 Javascript
实现微信小程序的wxml文件和wxss文件在webstrom的支持
2017/06/12 Javascript
bootstrap paginator分页前后台用法示例
2017/06/17 Javascript
微信小程序实现的一键连接wifi功能示例
2019/04/24 Javascript
[00:52]DOTA2齐天大圣预告片
2016/08/13 DOTA
[25:45]2018DOTA2亚洲邀请赛4.5SOLO赛 Sylar vs Paparazi
2018/04/06 DOTA
[10:21]2018DOTA2国际邀请赛寻真——Winstrike
2018/08/11 DOTA
解决python3中自定义wsgi函数,make_server函数报错的问题
2017/11/21 Python
Python实现京东秒杀功能代码
2019/05/16 Python
django框架使用orm实现批量更新数据的方法
2019/06/21 Python
Pycharm简单使用教程(入门小结)
2019/07/04 Python
对Pytorch神经网络初始化kaiming分布详解
2019/08/18 Python
使用matlab或python将txt文件转为excel表格
2019/11/01 Python
Jupyter notebook设置背景主题,字体大小及自动补全代码的操作
2020/04/13 Python
CSS3中颜色线性渐变实战
2015/07/18 HTML / CSS
世界上最大的街头服饰网站:Karmaloop
2017/02/04 全球购物
Wojas罗马尼亚网站:波兰皮鞋品牌
2018/11/01 全球购物
性能测试工程师的面试题
2015/02/20 面试题
预备党员入党自我评价范文
2014/03/10 职场文书
法学自荐信
2014/06/20 职场文书
Pandas实现DataFrame的简单运算、统计与排序
2022/03/31 Python