浅谈keras保存模型中的save()和save_weights()区别


Posted in Python onMay 21, 2020

今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别。

我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5。同样是h5文件用save()和save_weight()保存效果是不一样的。

我们用宇宙最通用的数据集MNIST来做这个实验,首先设计一个两层全连接网络:

inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)

然后,导入MNIST数据训练,分别用两种方式保存模型,在这里我还把未训练的模型也保存下来,如下:

from keras.models import Model
from keras.layers import Input, Dense
from keras.datasets import mnist
from keras.utils import np_utils
 
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
x_test=x_test.reshape(x_test.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
 
model.save('m1.h5')
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=10)
#loss,accuracy=model.evaluate(x_test,y_test)
 
model.save('m2.h5')
model.save_weights('m3.h5')

如上可见,我一共保存了m1.h5, m2.h5, m3.h5 这三个h5文件。那么,我们来看看这三个玩意儿有什么区别。首先,看看大小:

浅谈keras保存模型中的save()和save_weights()区别

m2表示save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。所以它的size最大的。

m1表示save()保存的训练前的模型结果,它保存了模型的图结构,但应该没有保存模型的初始化参数,所以它的size要比m2小很多。

m3表示save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构。所以它的size也要比m2小很多。

通过可视化工具,我们发现:(打开m1和m2均可以显示出以下结构)

浅谈keras保存模型中的save()和save_weights()区别

而打开m3的时候,可视化工具报错了。由此可以论证, save_weights()是不含有模型结构信息的。

加载模型

两种不同方法保存的模型文件也需要用不同的加载方法。

from keras.models import load_model
 
model = load_model('m1.h5')
#model = load_model('m2.h5')
#model = load_model('m3.h5')
model.summary()

只有加载m3.h5的时候,这段代码才会报错。其他输出如下:

浅谈keras保存模型中的save()和save_weights()区别

可见,由save()保存下来的h5文件才可以直接通过load_model()打开!

那么,我们保存下来的参数(m3.h5)该怎么打开呢?

这就稍微复杂一点了,因为m3不含有模型结构信息,所以我们需要把模型结构再描述一遍才可以加载m3,如下:

from keras.models import Model
from keras.layers import Input, Dense
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
model.load_weights('m3.h5')

以上把m3换成m1和m2也是没有问题的!可见,save()保存的模型除了占用内存大一点以外,其他的优点太明显了。所以,在不怎么缺硬盘空间的情况下,还是建议大家多用save()来存。

注意!如果要load_weights(),必须保证你描述的有参数计算结构与h5文件中完全一致!什么叫有参数计算结构呢?就是有参数坑,直接填进去就行了。我们把上面的非参数结构换了一下,发现h5文件依然可以加载成功,比如将softmax换成relu,依然不影响加载。

对于keras的save()和save_weights(),完全没问题了吧

以上这篇浅谈keras保存模型中的save()和save_weights()区别就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中is与==判断的区别
Mar 28 Python
Python利用pandas处理Excel数据的应用详解
Jun 18 Python
django 中的聚合函数,分组函数,F 查询,Q查询
Jul 25 Python
Python 使用 prettytable 库打印表格美化输出功能
Dec 26 Python
15行Python代码实现免费发送手机短信推送消息功能
Feb 27 Python
Python使用sys.exc_info()方法获取异常信息
Jul 23 Python
Python容器类型公共方法总结
Aug 19 Python
python 贪心算法的实现
Sep 18 Python
python statsmodel的使用
Dec 21 Python
python基于opencv 实现图像时钟
Jan 04 Python
详解python的变量缓存机制
Jan 24 Python
详解python网络进程
Jun 15 Python
Python通过文本和图片生成词云图
May 21 #Python
解决在keras中使用model.save()函数保存模型失败的问题
May 21 #Python
Python 实现敏感目录扫描的示例代码
May 21 #Python
基于python检查矩阵计算结果
May 21 #Python
Django 解决由save方法引发的错误
May 21 #Python
Python OrderedDict字典排序方法详解
May 21 #Python
django中嵌套的try-except实例
May 21 #Python
You might like
php预定义变量使用帮助(带实例)
2013/10/30 PHP
php商品对比功能代码分享
2015/09/24 PHP
thinkPHP框架对接支付宝即时到账接口回调操作示例
2016/11/14 PHP
利用PHP扩展Xhprof分析项目性能实践教程
2018/09/05 PHP
微信公众平台开发教程④ ThinkPHP框架下微信支付功能图文详解
2019/04/10 PHP
WordPress伪静态规则设置代码实例
2020/12/10 PHP
jQuery ajax serialize()方法的使用以及常见问题解决
2013/01/27 Javascript
jquery 判断滚动条到达了底部和顶端的方法
2014/04/02 Javascript
详解JavaScript中的every()方法
2015/06/08 Javascript
基于jQuery和hwSlider实现内容左右滑动切换效果附源码下载(一)
2016/06/22 Javascript
JS+CSS3实现超炫的散列画廊特效
2016/07/16 Javascript
详解如何制作并发布一个vue的组件的npm包
2018/11/10 Javascript
详解vuex之store拆分即多模块状态管理(modules)篇
2018/11/13 Javascript
vue双向绑定及观察者模式详解
2019/03/19 Javascript
在Vue mounted方法中使用data变量详解
2019/11/05 Javascript
原生JS实现留言板
2020/03/26 Javascript
vue:el-input输入时限制输入的类型操作
2020/08/05 Javascript
Python中基本的日期时间处理的学习教程
2015/10/16 Python
教你用Python创建微信聊天机器人
2020/03/31 Python
Python实现DDos攻击实例详解
2019/02/02 Python
Python人脸识别第三方库face_recognition接口说明文档
2019/05/03 Python
python飞机大战pygame游戏背景设计详解
2019/12/17 Python
最新PyCharm 2020.2.3永久激活码(亲测有效)
2020/11/26 Python
美国男女折扣服饰百货连锁店:Stein Mart
2017/05/02 全球购物
js实现弹框效果
2021/03/24 Javascript
大学生大二自我鉴定
2013/10/28 职场文书
交通安全演讲稿
2014/01/07 职场文书
大学生军训自我鉴定
2014/02/12 职场文书
竞聘报告优秀范文
2014/11/06 职场文书
2014年接待工作总结
2014/11/26 职场文书
2015年父亲节寄语
2015/03/23 职场文书
原告离婚代理词
2015/05/23 职场文书
法律意见书范文
2015/06/04 职场文书
小学大队长竞选稿
2015/11/20 职场文书
利用ajax+php实现商品价格计算
2021/03/31 PHP
python使用glob检索文件的操作
2021/05/20 Python