浅谈keras保存模型中的save()和save_weights()区别


Posted in Python onMay 21, 2020

今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别。

我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5。同样是h5文件用save()和save_weight()保存效果是不一样的。

我们用宇宙最通用的数据集MNIST来做这个实验,首先设计一个两层全连接网络:

inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)

然后,导入MNIST数据训练,分别用两种方式保存模型,在这里我还把未训练的模型也保存下来,如下:

from keras.models import Model
from keras.layers import Input, Dense
from keras.datasets import mnist
from keras.utils import np_utils
 
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
x_test=x_test.reshape(x_test.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
 
model.save('m1.h5')
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=10)
#loss,accuracy=model.evaluate(x_test,y_test)
 
model.save('m2.h5')
model.save_weights('m3.h5')

如上可见,我一共保存了m1.h5, m2.h5, m3.h5 这三个h5文件。那么,我们来看看这三个玩意儿有什么区别。首先,看看大小:

浅谈keras保存模型中的save()和save_weights()区别

m2表示save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。所以它的size最大的。

m1表示save()保存的训练前的模型结果,它保存了模型的图结构,但应该没有保存模型的初始化参数,所以它的size要比m2小很多。

m3表示save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构。所以它的size也要比m2小很多。

通过可视化工具,我们发现:(打开m1和m2均可以显示出以下结构)

浅谈keras保存模型中的save()和save_weights()区别

而打开m3的时候,可视化工具报错了。由此可以论证, save_weights()是不含有模型结构信息的。

加载模型

两种不同方法保存的模型文件也需要用不同的加载方法。

from keras.models import load_model
 
model = load_model('m1.h5')
#model = load_model('m2.h5')
#model = load_model('m3.h5')
model.summary()

只有加载m3.h5的时候,这段代码才会报错。其他输出如下:

浅谈keras保存模型中的save()和save_weights()区别

可见,由save()保存下来的h5文件才可以直接通过load_model()打开!

那么,我们保存下来的参数(m3.h5)该怎么打开呢?

这就稍微复杂一点了,因为m3不含有模型结构信息,所以我们需要把模型结构再描述一遍才可以加载m3,如下:

from keras.models import Model
from keras.layers import Input, Dense
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
model.load_weights('m3.h5')

以上把m3换成m1和m2也是没有问题的!可见,save()保存的模型除了占用内存大一点以外,其他的优点太明显了。所以,在不怎么缺硬盘空间的情况下,还是建议大家多用save()来存。

注意!如果要load_weights(),必须保证你描述的有参数计算结构与h5文件中完全一致!什么叫有参数计算结构呢?就是有参数坑,直接填进去就行了。我们把上面的非参数结构换了一下,发现h5文件依然可以加载成功,比如将softmax换成relu,依然不影响加载。

对于keras的save()和save_weights(),完全没问题了吧

以上这篇浅谈keras保存模型中的save()和save_weights()区别就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python的SimPy库简化复杂的编程模型的介绍
Apr 13 Python
使用Python实现BT种子和磁力链接的相互转换
Nov 09 Python
python Django模板的使用方法
Jan 14 Python
Python处理命令行参数模块optpars用法实例分析
May 31 Python
python画一个玫瑰和一个爱心
Aug 18 Python
解决python3捕获cx_oracle抛出的异常错误问题
Oct 18 Python
Python爬虫文件下载图文教程
Dec 23 Python
神经网络相关之基础概念的讲解
Dec 29 Python
Pandas中DataFrame的分组/分割/合并的实现
Jul 16 Python
Python *args和**kwargs用法实例解析
Mar 02 Python
python3 循环读取excel文件并写入json操作
Jul 14 Python
Pandas||过滤缺失数据||pd.dropna()函数的用法说明
May 14 Python
Python通过文本和图片生成词云图
May 21 #Python
解决在keras中使用model.save()函数保存模型失败的问题
May 21 #Python
Python 实现敏感目录扫描的示例代码
May 21 #Python
基于python检查矩阵计算结果
May 21 #Python
Django 解决由save方法引发的错误
May 21 #Python
Python OrderedDict字典排序方法详解
May 21 #Python
django中嵌套的try-except实例
May 21 #Python
You might like
php下几个常用的去空、分组、调试数组函数
2009/02/22 PHP
PHP数组的交集array_intersect(),array_intersect_assoc(),array_inter_key()函数的小问题
2011/05/29 PHP
PHPThumb PHP 图片缩略图库
2012/03/11 PHP
php自定义分页类完整实例
2015/12/25 PHP
Centos PHP 扩展Xchche的安装教程
2016/07/09 PHP
thinkPHP显示不出验证码的原因与解决方法分析
2017/05/20 PHP
在Yii2特定页面如何禁用调试工具栏Debug Toolbar详解
2017/08/07 PHP
PHP实现的mysql读写分离操作示例
2018/05/22 PHP
laravel excel 上传文件保存到本地服务器功能
2019/11/14 PHP
JS和Jquery获取和修改label的值的示例代码
2014/01/15 Javascript
浅谈JS中String()与 .toString()的区别
2016/10/20 Javascript
原生JavaScript实现Tooltip浮动提示框特效
2017/03/07 Javascript
js中的触发事件对象event.srcElement与event.target详解
2017/03/15 Javascript
VS Code转换大小写、修改选中文字或代码颜色的方法
2017/12/15 Javascript
jQuery实现的自定义轮播图功能详解
2018/12/28 jQuery
ES6知识点整理之函数对象参数默认值及其解构应用示例
2019/04/17 Javascript
JavaScript函数式编程(Functional Programming)箭头函数(Arrow functions)用法分析
2019/05/22 Javascript
Vue使用axios出现options请求方法
2019/05/30 Javascript
实现vuex与组件data之间的数据同步更新方式
2019/11/12 Javascript
浅谈vue 多个变量同时赋相同值互相影响
2020/08/05 Javascript
Python多进程同步Lock、Semaphore、Event实例
2014/11/21 Python
Python实现线程池代码分享
2015/06/21 Python
Python编程pygame模块实现移动的小车示例代码
2018/01/03 Python
快速了解Python中的装饰器
2018/01/11 Python
对python中矩阵相加函数sum()的使用详解
2019/01/28 Python
python 使用装饰器并记录log的示例代码
2019/07/12 Python
ansible-playbook实现自动部署KVM及安装python3的详细教程
2020/05/11 Python
CSS3.0实现霓虹灯按钮动画特效的示例代码
2021/01/12 HTML / CSS
英国设计的甲板鞋和船鞋:Chatham
2018/12/06 全球购物
应付会计岗位职责
2013/12/12 职场文书
简单租房协议书范本
2014/08/20 职场文书
政府个人对照检查材料
2014/08/28 职场文书
2014年幼儿园工作总结
2014/11/10 职场文书
办公室年度工作总结2015
2015/05/21 职场文书
节约用水广告语60条
2019/11/14 职场文书
Mybatis-Plus 使用 @TableField 自动填充日期
2022/04/26 Java/Android