浅谈keras保存模型中的save()和save_weights()区别


Posted in Python onMay 21, 2020

今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别。

我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5。同样是h5文件用save()和save_weight()保存效果是不一样的。

我们用宇宙最通用的数据集MNIST来做这个实验,首先设计一个两层全连接网络:

inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)

然后,导入MNIST数据训练,分别用两种方式保存模型,在这里我还把未训练的模型也保存下来,如下:

from keras.models import Model
from keras.layers import Input, Dense
from keras.datasets import mnist
from keras.utils import np_utils
 
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
x_test=x_test.reshape(x_test.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
 
model.save('m1.h5')
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=10)
#loss,accuracy=model.evaluate(x_test,y_test)
 
model.save('m2.h5')
model.save_weights('m3.h5')

如上可见,我一共保存了m1.h5, m2.h5, m3.h5 这三个h5文件。那么,我们来看看这三个玩意儿有什么区别。首先,看看大小:

浅谈keras保存模型中的save()和save_weights()区别

m2表示save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。所以它的size最大的。

m1表示save()保存的训练前的模型结果,它保存了模型的图结构,但应该没有保存模型的初始化参数,所以它的size要比m2小很多。

m3表示save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构。所以它的size也要比m2小很多。

通过可视化工具,我们发现:(打开m1和m2均可以显示出以下结构)

浅谈keras保存模型中的save()和save_weights()区别

而打开m3的时候,可视化工具报错了。由此可以论证, save_weights()是不含有模型结构信息的。

加载模型

两种不同方法保存的模型文件也需要用不同的加载方法。

from keras.models import load_model
 
model = load_model('m1.h5')
#model = load_model('m2.h5')
#model = load_model('m3.h5')
model.summary()

只有加载m3.h5的时候,这段代码才会报错。其他输出如下:

浅谈keras保存模型中的save()和save_weights()区别

可见,由save()保存下来的h5文件才可以直接通过load_model()打开!

那么,我们保存下来的参数(m3.h5)该怎么打开呢?

这就稍微复杂一点了,因为m3不含有模型结构信息,所以我们需要把模型结构再描述一遍才可以加载m3,如下:

from keras.models import Model
from keras.layers import Input, Dense
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
model.load_weights('m3.h5')

以上把m3换成m1和m2也是没有问题的!可见,save()保存的模型除了占用内存大一点以外,其他的优点太明显了。所以,在不怎么缺硬盘空间的情况下,还是建议大家多用save()来存。

注意!如果要load_weights(),必须保证你描述的有参数计算结构与h5文件中完全一致!什么叫有参数计算结构呢?就是有参数坑,直接填进去就行了。我们把上面的非参数结构换了一下,发现h5文件依然可以加载成功,比如将softmax换成relu,依然不影响加载。

对于keras的save()和save_weights(),完全没问题了吧

以上这篇浅谈keras保存模型中的save()和save_weights()区别就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现猜数字游戏(无重复数字)示例分享
Mar 29 Python
Python中使用scapy模拟数据包实现arp攻击、dns放大攻击例子
Oct 23 Python
使用Python中的线程进行网络编程的入门教程
Apr 15 Python
Python中的rfind()方法使用详解
May 19 Python
R vs. Python 数据分析中谁与争锋?
Oct 18 Python
Python验证文件是否可读写代码分享
Dec 11 Python
python版本的仿windows计划任务工具
Apr 30 Python
python文件转为exe文件的方法及用法详解
Jul 08 Python
Pyspark读取parquet数据过程解析
Mar 27 Python
Python多线程thread及模块使用实例
Apr 28 Python
详解pycharm连接远程linux服务器的虚拟环境的方法
Nov 13 Python
OpenCV项目实践之停车场车位实时检测
Apr 11 Python
Python通过文本和图片生成词云图
May 21 #Python
解决在keras中使用model.save()函数保存模型失败的问题
May 21 #Python
Python 实现敏感目录扫描的示例代码
May 21 #Python
基于python检查矩阵计算结果
May 21 #Python
Django 解决由save方法引发的错误
May 21 #Python
Python OrderedDict字典排序方法详解
May 21 #Python
django中嵌套的try-except实例
May 21 #Python
You might like
php生成SessionID和图片校验码的思路和实现代码
2009/03/10 PHP
php安全配置 如何配置使其更安全
2011/12/16 PHP
实现WordPress主题侧边栏切换功能的PHP脚本详解
2015/12/14 PHP
PHP Curl模拟登录微信公众平台、新浪微博实例代码
2016/01/28 PHP
PHP信号量基本用法实例详解
2016/02/12 PHP
PHP生成图像验证码的方法小结(2种方法)
2016/07/18 PHP
php pdo连接数据库操作示例
2019/11/18 PHP
在次封装easyui-Dialog插件实现代码
2010/11/14 Javascript
JavaScript原生对象之Date对象的属性和方法详解
2015/03/13 Javascript
jQuery+html5实现div弹出层并遮罩背景
2015/04/15 Javascript
Ext JS动态加载JavaScript创建窗体的方法
2016/06/23 Javascript
AngularJS ng-change 指令的详解及简单实例
2016/07/30 Javascript
JQuery学习总结【一】
2016/12/01 Javascript
jQuery EasyUI ProgressBar进度条组件
2017/02/28 Javascript
jQuery判断网页是否已经滚动到浏览器底部的实现方法
2017/10/27 jQuery
微信小程序实现弹出菜单功能
2018/06/12 Javascript
微信小程序首页的分类功能和搜索功能的实现思路及代码详解
2018/09/11 Javascript
小程序实现左右来回滚动字幕效果
2018/12/28 Javascript
在 Vue.js中优雅地使用全局事件的方法
2019/02/01 Javascript
使用vue2.6实现抖音【时间轮盘】屏保效果附源码
2019/04/24 Javascript
详解搭建一个vue-cli的移动端H5开发模板
2020/01/17 Javascript
微信小程序实现树莓派(raspberry pi)小车控制
2020/02/12 Javascript
Python使用xlrd模块操作Excel数据导入的方法
2015/05/26 Python
在Django的模型中执行原始SQL查询的方法
2015/07/21 Python
Django实现学生管理系统
2019/02/26 Python
Win10系统下安装labelme及json文件批量转化方法
2019/07/30 Python
Python将主机名转换为IP地址的方法
2019/08/14 Python
python conda操作方法
2019/09/11 Python
CSS3 实现footer 固定在底部(无论页面多高始终在底部)
2019/10/15 HTML / CSS
澳大利亚窗帘商店:Curtain Wonderland
2019/12/01 全球购物
医学专业毕业生个人的求职信
2013/12/04 职场文书
中西医专业毕业生职业规划书
2014/02/24 职场文书
会议邀请函
2015/01/30 职场文书
总经理助理岗位职责
2015/01/31 职场文书
病危通知单
2015/04/17 职场文书
无保留意见审计报告
2015/06/05 职场文书