浅谈keras保存模型中的save()和save_weights()区别


Posted in Python onMay 21, 2020

今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别。

我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5。同样是h5文件用save()和save_weight()保存效果是不一样的。

我们用宇宙最通用的数据集MNIST来做这个实验,首先设计一个两层全连接网络:

inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)

然后,导入MNIST数据训练,分别用两种方式保存模型,在这里我还把未训练的模型也保存下来,如下:

from keras.models import Model
from keras.layers import Input, Dense
from keras.datasets import mnist
from keras.utils import np_utils
 
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
x_test=x_test.reshape(x_test.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
 
model.save('m1.h5')
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=10)
#loss,accuracy=model.evaluate(x_test,y_test)
 
model.save('m2.h5')
model.save_weights('m3.h5')

如上可见,我一共保存了m1.h5, m2.h5, m3.h5 这三个h5文件。那么,我们来看看这三个玩意儿有什么区别。首先,看看大小:

浅谈keras保存模型中的save()和save_weights()区别

m2表示save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。所以它的size最大的。

m1表示save()保存的训练前的模型结果,它保存了模型的图结构,但应该没有保存模型的初始化参数,所以它的size要比m2小很多。

m3表示save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构。所以它的size也要比m2小很多。

通过可视化工具,我们发现:(打开m1和m2均可以显示出以下结构)

浅谈keras保存模型中的save()和save_weights()区别

而打开m3的时候,可视化工具报错了。由此可以论证, save_weights()是不含有模型结构信息的。

加载模型

两种不同方法保存的模型文件也需要用不同的加载方法。

from keras.models import load_model
 
model = load_model('m1.h5')
#model = load_model('m2.h5')
#model = load_model('m3.h5')
model.summary()

只有加载m3.h5的时候,这段代码才会报错。其他输出如下:

浅谈keras保存模型中的save()和save_weights()区别

可见,由save()保存下来的h5文件才可以直接通过load_model()打开!

那么,我们保存下来的参数(m3.h5)该怎么打开呢?

这就稍微复杂一点了,因为m3不含有模型结构信息,所以我们需要把模型结构再描述一遍才可以加载m3,如下:

from keras.models import Model
from keras.layers import Input, Dense
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
model.load_weights('m3.h5')

以上把m3换成m1和m2也是没有问题的!可见,save()保存的模型除了占用内存大一点以外,其他的优点太明显了。所以,在不怎么缺硬盘空间的情况下,还是建议大家多用save()来存。

注意!如果要load_weights(),必须保证你描述的有参数计算结构与h5文件中完全一致!什么叫有参数计算结构呢?就是有参数坑,直接填进去就行了。我们把上面的非参数结构换了一下,发现h5文件依然可以加载成功,比如将softmax换成relu,依然不影响加载。

对于keras的save()和save_weights(),完全没问题了吧

以上这篇浅谈keras保存模型中的save()和save_weights()区别就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python生成随机验证码(中文验证码)示例
Apr 03 Python
python写的ARP攻击代码实例
Jun 04 Python
Python基于pygame实现的弹力球效果(附源码)
Nov 11 Python
Python实现PS滤镜的万花筒效果示例
Jan 23 Python
Python使用Dijkstra算法实现求解图中最短路径距离问题详解
May 16 Python
Python中将两个或多个list合成一个list的方法小结
May 12 Python
python 整数越界问题详解
Jun 27 Python
查看端口并杀进程python脚本代码
Dec 17 Python
python获取引用对象的个数方式
Dec 20 Python
python Matplotlib数据可视化(2):详解三大容器对象与常用设置
Sep 30 Python
利用Python第三方库实现预测NBA比赛结果
Jun 21 Python
python基础之类方法和静态方法
Oct 24 Python
Python通过文本和图片生成词云图
May 21 #Python
解决在keras中使用model.save()函数保存模型失败的问题
May 21 #Python
Python 实现敏感目录扫描的示例代码
May 21 #Python
基于python检查矩阵计算结果
May 21 #Python
Django 解决由save方法引发的错误
May 21 #Python
Python OrderedDict字典排序方法详解
May 21 #Python
django中嵌套的try-except实例
May 21 #Python
You might like
用PHP实现文件上传二法
2006/10/09 PHP
如何给phpadmin一个保护
2006/10/09 PHP
SESSION信息保存在哪个文件目录下以及能够用来保存什么类型的数据
2012/06/17 PHP
解析如何在PHP下载文件名中解决乱码的问题
2013/06/20 PHP
php判断一个数组是否为有序的方法
2015/03/27 PHP
PHP实现对文件锁进行加锁、解锁操作的方法
2017/07/04 PHP
YII2框架中使用RBAC对模块,控制器,方法的权限控制及规则的使用示例
2020/03/18 PHP
Extjs学习笔记之二 初识Extjs之Form
2010/01/07 Javascript
Javascript 面向对象 重载
2010/05/13 Javascript
ASP 过滤数组重复数据函数(加强版)
2010/05/31 Javascript
javascript题目,重写函数让其无限相加
2012/02/15 Javascript
JavaScript中的style.cssText使用教程
2014/11/06 Javascript
javascript事件冒泡和事件捕获详解
2015/05/26 Javascript
javascript 判断一个对象为数组的方法
2017/05/03 Javascript
jQuery实现切换隐藏与显示同时切换图标功能
2017/10/29 jQuery
vue中的计算属性的使用和vue实例的方法示例
2017/12/04 Javascript
小程序中this.setData的使用和注意事项
2019/08/28 Javascript
jQuery实现手风琴效果(蒙版)
2020/01/11 jQuery
[09:37]2018DOTA2国际邀请赛寻真——不懈追梦的Team Serenity
2018/08/13 DOTA
[11:12]2018DOTA2国际邀请赛寻真——绿色长城OpTic
2018/08/10 DOTA
[59:30]VG vs LGD 2019国际邀请赛淘汰赛 胜者组 BO3 第二场 8.22
2019/09/05 DOTA
[42:06]2019国际邀请赛全明星赛 8.23
2019/09/05 DOTA
Python常用算法学习基础教程
2017/04/13 Python
在Python 不同级目录之间模块的调用方法
2019/01/19 Python
Django集成搜索引擎Elasticserach的方法示例
2019/06/04 Python
使用python制作游戏下载进度条的代码(程序说明见注释)
2019/10/24 Python
Python中如何将一个类方法变为多个方法
2019/12/30 Python
pycharm通过ssh连接远程服务器教程
2020/02/12 Python
绝对令人的惊叹的CSS3折叠效果(3D效果)整理
2012/12/30 HTML / CSS
公司营业员的工作总结自我评价
2013/10/05 职场文书
个人社会实践自我鉴定
2014/03/24 职场文书
食品销售计划书
2014/04/26 职场文书
政风行风整改报告
2014/11/06 职场文书
网络妈妈观后感
2015/06/08 职场文书
2016国培学习心得体会
2016/01/08 职场文书
导游词之舟山普陀山
2019/11/06 职场文书