浅谈keras保存模型中的save()和save_weights()区别


Posted in Python onMay 21, 2020

今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别。

我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5。同样是h5文件用save()和save_weight()保存效果是不一样的。

我们用宇宙最通用的数据集MNIST来做这个实验,首先设计一个两层全连接网络:

inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)

然后,导入MNIST数据训练,分别用两种方式保存模型,在这里我还把未训练的模型也保存下来,如下:

from keras.models import Model
from keras.layers import Input, Dense
from keras.datasets import mnist
from keras.utils import np_utils
 
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
x_test=x_test.reshape(x_test.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
 
model.save('m1.h5')
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=10)
#loss,accuracy=model.evaluate(x_test,y_test)
 
model.save('m2.h5')
model.save_weights('m3.h5')

如上可见,我一共保存了m1.h5, m2.h5, m3.h5 这三个h5文件。那么,我们来看看这三个玩意儿有什么区别。首先,看看大小:

浅谈keras保存模型中的save()和save_weights()区别

m2表示save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。所以它的size最大的。

m1表示save()保存的训练前的模型结果,它保存了模型的图结构,但应该没有保存模型的初始化参数,所以它的size要比m2小很多。

m3表示save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构。所以它的size也要比m2小很多。

通过可视化工具,我们发现:(打开m1和m2均可以显示出以下结构)

浅谈keras保存模型中的save()和save_weights()区别

而打开m3的时候,可视化工具报错了。由此可以论证, save_weights()是不含有模型结构信息的。

加载模型

两种不同方法保存的模型文件也需要用不同的加载方法。

from keras.models import load_model
 
model = load_model('m1.h5')
#model = load_model('m2.h5')
#model = load_model('m3.h5')
model.summary()

只有加载m3.h5的时候,这段代码才会报错。其他输出如下:

浅谈keras保存模型中的save()和save_weights()区别

可见,由save()保存下来的h5文件才可以直接通过load_model()打开!

那么,我们保存下来的参数(m3.h5)该怎么打开呢?

这就稍微复杂一点了,因为m3不含有模型结构信息,所以我们需要把模型结构再描述一遍才可以加载m3,如下:

from keras.models import Model
from keras.layers import Input, Dense
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
model.load_weights('m3.h5')

以上把m3换成m1和m2也是没有问题的!可见,save()保存的模型除了占用内存大一点以外,其他的优点太明显了。所以,在不怎么缺硬盘空间的情况下,还是建议大家多用save()来存。

注意!如果要load_weights(),必须保证你描述的有参数计算结构与h5文件中完全一致!什么叫有参数计算结构呢?就是有参数坑,直接填进去就行了。我们把上面的非参数结构换了一下,发现h5文件依然可以加载成功,比如将softmax换成relu,依然不影响加载。

对于keras的save()和save_weights(),完全没问题了吧

以上这篇浅谈keras保存模型中的save()和save_weights()区别就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python发布模块的步骤分享
Feb 21 Python
Python 抓取动态网页内容方案详解
Dec 25 Python
Python中的多重装饰器
Apr 11 Python
Python中的条件判断语句与循环语句用法小结
Mar 21 Python
python 画三维图像 曲面图和散点图的示例
Dec 29 Python
Kears+Opencv实现简单人脸识别
Aug 28 Python
Python网络编程之使用TCP方式传输文件操作示例
Nov 01 Python
python3正则模块re的使用方法详解
Feb 11 Python
Python使用urllib模块对URL网址中的中文编码与解码实例详解
Feb 18 Python
浅谈ROC曲线的最佳阈值如何选取
Feb 28 Python
Python turtle库的画笔控制说明
Jun 28 Python
手残删除python之后的补救方法
Jun 26 Python
Python通过文本和图片生成词云图
May 21 #Python
解决在keras中使用model.save()函数保存模型失败的问题
May 21 #Python
Python 实现敏感目录扫描的示例代码
May 21 #Python
基于python检查矩阵计算结果
May 21 #Python
Django 解决由save方法引发的错误
May 21 #Python
Python OrderedDict字典排序方法详解
May 21 #Python
django中嵌套的try-except实例
May 21 #Python
You might like
php实现首页链接查询 友情链接检查的代码
2010/01/05 PHP
解析web文件操作常见安全漏洞(目录、文件名检测漏洞)
2013/06/29 PHP
提高PHP性能的编码技巧以及性能优化详细解析
2013/08/24 PHP
php5.4以上版本GBK编码下htmlspecialchars输出为空问题解决方法汇总
2015/04/03 PHP
PHP后台微信支付和支付宝支付开发
2017/04/28 PHP
javascript入门·动态的时钟,显示完整的一些方法,新年倒计时
2007/10/01 Javascript
jQuery提交表单ajax查询实例代码
2012/10/07 Javascript
Js 冒泡事件阻止实现代码
2013/01/27 Javascript
动态的创建一个元素createElement及删除一个元素
2014/01/24 Javascript
跟我学Nodejs(三)--- Node.js模块
2014/05/25 NodeJs
jQuery中;function($,undefined) 前面的分号的用处
2014/12/17 Javascript
javascript操作字符串的原生方法
2014/12/22 Javascript
JS实现仿微博可关闭弹出层效果
2015/09/21 Javascript
基于jQuery实现拖拽图标到回收站并删除功能
2015/11/25 Javascript
原生js实现验证码功能
2017/03/16 Javascript
Angular2使用Augury来调试Angular2程序
2017/05/21 Javascript
Vuejs 页面的区域化与组件封装的实现
2017/09/11 Javascript
微信小程序之GET请求的实例详解
2017/09/29 Javascript
关于Angularjs中跨域设置白名单问题
2018/04/17 Javascript
vue中解决chrome浏览器自动播放音频和MP3语音打包到线上的实现方法
2020/10/09 Javascript
vue使用swiper实现左右滑动切换图片
2020/10/16 Javascript
Python中不同进制的语法及转换方法分析
2016/07/27 Python
Python Pandas找到缺失值的位置方法
2018/04/12 Python
Python模块的加载讲解
2019/01/15 Python
Python3中列表list合并的四种方法
2019/04/19 Python
解决python3中的requests解析中文页面出现乱码问题
2019/04/19 Python
python顺序执行多个py文件的方法
2019/06/29 Python
Django执行源生mysql语句实现过程解析
2020/11/12 Python
韩国休闲女装品牌网站:ANAIS
2016/08/24 全球购物
国际鲜花速递专家:Floraqueen
2016/11/24 全球购物
销售总监工作职责
2013/11/21 职场文书
酒店开业庆典策划方案
2014/05/28 职场文书
格列佛游记读书笔记
2015/06/30 职场文书
2017寒假社会实践心得体会范文
2016/01/14 职场文书
Python 多线程之threading 模块的使用
2021/04/14 Python
python基础之匿名函数详解
2021/04/21 Python