浅谈keras保存模型中的save()和save_weights()区别


Posted in Python onMay 21, 2020

今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别。

我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5。同样是h5文件用save()和save_weight()保存效果是不一样的。

我们用宇宙最通用的数据集MNIST来做这个实验,首先设计一个两层全连接网络:

inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)

然后,导入MNIST数据训练,分别用两种方式保存模型,在这里我还把未训练的模型也保存下来,如下:

from keras.models import Model
from keras.layers import Input, Dense
from keras.datasets import mnist
from keras.utils import np_utils
 
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
x_test=x_test.reshape(x_test.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
 
model.save('m1.h5')
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=10)
#loss,accuracy=model.evaluate(x_test,y_test)
 
model.save('m2.h5')
model.save_weights('m3.h5')

如上可见,我一共保存了m1.h5, m2.h5, m3.h5 这三个h5文件。那么,我们来看看这三个玩意儿有什么区别。首先,看看大小:

浅谈keras保存模型中的save()和save_weights()区别

m2表示save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。所以它的size最大的。

m1表示save()保存的训练前的模型结果,它保存了模型的图结构,但应该没有保存模型的初始化参数,所以它的size要比m2小很多。

m3表示save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构。所以它的size也要比m2小很多。

通过可视化工具,我们发现:(打开m1和m2均可以显示出以下结构)

浅谈keras保存模型中的save()和save_weights()区别

而打开m3的时候,可视化工具报错了。由此可以论证, save_weights()是不含有模型结构信息的。

加载模型

两种不同方法保存的模型文件也需要用不同的加载方法。

from keras.models import load_model
 
model = load_model('m1.h5')
#model = load_model('m2.h5')
#model = load_model('m3.h5')
model.summary()

只有加载m3.h5的时候,这段代码才会报错。其他输出如下:

浅谈keras保存模型中的save()和save_weights()区别

可见,由save()保存下来的h5文件才可以直接通过load_model()打开!

那么,我们保存下来的参数(m3.h5)该怎么打开呢?

这就稍微复杂一点了,因为m3不含有模型结构信息,所以我们需要把模型结构再描述一遍才可以加载m3,如下:

from keras.models import Model
from keras.layers import Input, Dense
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
model.load_weights('m3.h5')

以上把m3换成m1和m2也是没有问题的!可见,save()保存的模型除了占用内存大一点以外,其他的优点太明显了。所以,在不怎么缺硬盘空间的情况下,还是建议大家多用save()来存。

注意!如果要load_weights(),必须保证你描述的有参数计算结构与h5文件中完全一致!什么叫有参数计算结构呢?就是有参数坑,直接填进去就行了。我们把上面的非参数结构换了一下,发现h5文件依然可以加载成功,比如将softmax换成relu,依然不影响加载。

对于keras的save()和save_weights(),完全没问题了吧

以上这篇浅谈keras保存模型中的save()和save_weights()区别就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过scapy获取局域网所有主机mac地址示例
May 04 Python
简单的抓取淘宝图片的Python爬虫
Dec 25 Python
用Python进行一些简单的自然语言处理的教程
Mar 31 Python
Python中设置变量作为默认值时容易遇到的错误
Apr 03 Python
python简单线程和协程学习心得(分享)
Jun 14 Python
Python面向对象程序设计OOP深入分析【构造函数,组合类,工具类等】
Jan 05 Python
python 将字符串完成特定的向右移动方法
Jun 11 Python
python Dijkstra算法实现最短路径问题的方法
Sep 19 Python
django框架两个使用模板实例
Dec 11 Python
python实现简单的购物程序代码实例
Mar 03 Python
Python一行代码实现自动发邮件功能
May 30 Python
利用python进行数据加载
Jun 20 Python
Python通过文本和图片生成词云图
May 21 #Python
解决在keras中使用model.save()函数保存模型失败的问题
May 21 #Python
Python 实现敏感目录扫描的示例代码
May 21 #Python
基于python检查矩阵计算结果
May 21 #Python
Django 解决由save方法引发的错误
May 21 #Python
Python OrderedDict字典排序方法详解
May 21 #Python
django中嵌套的try-except实例
May 21 #Python
You might like
PHP 的 __FILE__ 常量
2007/01/15 PHP
实用函数8
2007/11/08 PHP
PHP生成sitemap.xml地图函数
2013/11/13 PHP
thinkPHP下ueditor的使用方法详解
2015/12/26 PHP
PHP实现双链表删除与插入节点的方法示例
2017/11/11 PHP
alixixi runcode.asp的代码不错的应用
2007/08/08 Javascript
javascript 拖动表格行实现代码
2011/05/05 Javascript
JS去除右边逗号的简单方法
2013/07/03 Javascript
js和php如何获取当前url的内容
2013/09/22 Javascript
jQuery选择器源码解读(六):Sizzle选择器匹配逻辑分析
2015/03/31 Javascript
JS中获取函数调用链所有参数的方法
2015/05/07 Javascript
使用javascript提交form表单方法汇总
2015/06/25 Javascript
javascript先序遍历DOM树的方法
2016/02/27 Javascript
使用 stylelint检查CSS_StyleLint
2016/04/28 Javascript
js完整倒计时代码分享
2016/09/18 Javascript
JS去掉字符串前后空格或去掉所有空格的用法
2017/03/25 Javascript
Textarea输入字数限制实例(兼容iOS&安卓)
2017/07/06 Javascript
在一般处理程序(ashx)中弹出js提示语
2017/08/16 Javascript
echarts实现词云自定义形状的示例代码
2019/02/20 Javascript
详解element-ui中el-select的默认选择项问题
2019/08/02 Javascript
vue悬浮可拖拽悬浮按钮的实例代码
2019/08/20 Javascript
[03:05]《我与DAC》之xiao8:DAC与BG
2018/03/27 DOTA
进一步理解Python中的函数编程
2015/04/13 Python
python的pdb调试命令的命令整理及实例
2017/07/12 Python
python遍历文件夹找出文件夹后缀为py的文件方法
2018/10/21 Python
Python docx库用法示例分析
2019/02/16 Python
flask框架路由常用定义方式总结
2019/07/23 Python
python实现替换word中的关键文字(使用通配符)
2020/02/13 Python
Python进程间通信multiprocess代码实例
2020/03/18 Python
美国Randolph太阳镜官网:美国制造的飞行员太阳镜和射击眼镜
2018/06/15 全球购物
介绍一下Ruby中的对象,属性和方法
2012/07/11 面试题
低碳生活的宣传标语
2014/06/23 职场文书
离婚协议书范本2014
2014/10/27 职场文书
2014年数学教师工作总结
2014/12/03 职场文书
债务纠纷起诉书
2015/05/20 职场文书
postgresql 删除重复数据案例详解
2021/08/02 PostgreSQL