浅谈keras保存模型中的save()和save_weights()区别


Posted in Python onMay 21, 2020

今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别。

我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5。同样是h5文件用save()和save_weight()保存效果是不一样的。

我们用宇宙最通用的数据集MNIST来做这个实验,首先设计一个两层全连接网络:

inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)

然后,导入MNIST数据训练,分别用两种方式保存模型,在这里我还把未训练的模型也保存下来,如下:

from keras.models import Model
from keras.layers import Input, Dense
from keras.datasets import mnist
from keras.utils import np_utils
 
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
x_test=x_test.reshape(x_test.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
 
model.save('m1.h5')
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=10)
#loss,accuracy=model.evaluate(x_test,y_test)
 
model.save('m2.h5')
model.save_weights('m3.h5')

如上可见,我一共保存了m1.h5, m2.h5, m3.h5 这三个h5文件。那么,我们来看看这三个玩意儿有什么区别。首先,看看大小:

浅谈keras保存模型中的save()和save_weights()区别

m2表示save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。所以它的size最大的。

m1表示save()保存的训练前的模型结果,它保存了模型的图结构,但应该没有保存模型的初始化参数,所以它的size要比m2小很多。

m3表示save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构。所以它的size也要比m2小很多。

通过可视化工具,我们发现:(打开m1和m2均可以显示出以下结构)

浅谈keras保存模型中的save()和save_weights()区别

而打开m3的时候,可视化工具报错了。由此可以论证, save_weights()是不含有模型结构信息的。

加载模型

两种不同方法保存的模型文件也需要用不同的加载方法。

from keras.models import load_model
 
model = load_model('m1.h5')
#model = load_model('m2.h5')
#model = load_model('m3.h5')
model.summary()

只有加载m3.h5的时候,这段代码才会报错。其他输出如下:

浅谈keras保存模型中的save()和save_weights()区别

可见,由save()保存下来的h5文件才可以直接通过load_model()打开!

那么,我们保存下来的参数(m3.h5)该怎么打开呢?

这就稍微复杂一点了,因为m3不含有模型结构信息,所以我们需要把模型结构再描述一遍才可以加载m3,如下:

from keras.models import Model
from keras.layers import Input, Dense
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
model.load_weights('m3.h5')

以上把m3换成m1和m2也是没有问题的!可见,save()保存的模型除了占用内存大一点以外,其他的优点太明显了。所以,在不怎么缺硬盘空间的情况下,还是建议大家多用save()来存。

注意!如果要load_weights(),必须保证你描述的有参数计算结构与h5文件中完全一致!什么叫有参数计算结构呢?就是有参数坑,直接填进去就行了。我们把上面的非参数结构换了一下,发现h5文件依然可以加载成功,比如将softmax换成relu,依然不影响加载。

对于keras的save()和save_weights(),完全没问题了吧

以上这篇浅谈keras保存模型中的save()和save_weights()区别就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的设计模式编程入门指南
Apr 02 Python
Python字典操作简明总结
Apr 13 Python
python中单下划线_的常见用法总结
Jul 10 Python
python看某个模块的版本方法
Oct 16 Python
python将一个英文语句以单词为单位逆序排放的方法
Dec 20 Python
Python 调用PIL库失败的解决方法
Jan 08 Python
学习python分支结构
May 17 Python
django将数组传递给前台模板的方法
Aug 06 Python
Python完全识别验证码自动登录实例详解
Nov 24 Python
如何利用Python识别图片中的文字
May 31 Python
Python实现列表中非负数保留,负数转化为指定的数值方式
Jun 04 Python
Python用户自定义异常的实现
Dec 25 Python
Python通过文本和图片生成词云图
May 21 #Python
解决在keras中使用model.save()函数保存模型失败的问题
May 21 #Python
Python 实现敏感目录扫描的示例代码
May 21 #Python
基于python检查矩阵计算结果
May 21 #Python
Django 解决由save方法引发的错误
May 21 #Python
Python OrderedDict字典排序方法详解
May 21 #Python
django中嵌套的try-except实例
May 21 #Python
You might like
PHP代码判断设备是手机还是平板电脑(两种方法)
2015/10/19 PHP
PHP+Ajax实现验证码的实时验证
2016/07/20 PHP
实例讲解YII2中多表关联的使用方法
2017/07/21 PHP
PHP7 新增常量
2021/03/09 PHP
脚本吧 - 幻宇工作室用到js,超强推荐base.js
2006/12/23 Javascript
关于js datetime的那点事
2011/11/15 Javascript
基于jquery点击自以外任意处,关闭自身的代码
2012/02/10 Javascript
Nodejs中调用系统命令、Shell脚本和Python脚本的方法和实例
2015/01/01 NodeJs
JS打字效果的动态菜单代码分享
2015/08/21 Javascript
常用的JQuery函数及功能小结
2016/03/24 Javascript
深入理解JS中的substr和substring
2016/04/26 Javascript
利用jQuery对无序列表排序的简单方法
2016/10/16 Javascript
jQuery实现返回顶部按钮和scroll滚动功能[带动画效果]
2017/07/05 jQuery
angular框架实现全选与单选chekbox的自定义
2017/07/06 Javascript
如何去除vue项目中的#及其ie9兼容性
2018/01/11 Javascript
JS函数内部属性之arguments和this实例解析
2018/10/07 Javascript
js实现tab栏切换效果
2020/08/02 Javascript
[03:09]显微镜下的DOTA2第一期——带你走进华丽的DOTA2世界
2014/06/20 DOTA
[54:29]2018DOTA2亚洲邀请赛 4.7 淘汰赛 VP vs LGD 第二场
2018/04/09 DOTA
在Python中操作文件之truncate()方法的使用教程
2015/05/25 Python
python中管道用法入门实例
2015/06/04 Python
在Python的Django框架中调用方法和处理无效变量
2015/07/15 Python
PyTorch学习笔记之回归实战
2018/05/28 Python
TensorFlow的权值更新方法
2018/06/14 Python
python 用所有标点符号分隔句子的示例
2019/07/15 Python
python GUI库图形界面开发之PyQt5访问系统剪切板QClipboard类详细使用方法与实例
2020/02/27 Python
浅谈python 调用open()打开文件时路径出错的原因
2020/06/05 Python
HTML5 Canvas中绘制椭圆的4种方法
2015/04/24 HTML / CSS
加拿大奢华时装品牌:Mackage
2018/01/10 全球购物
大学生的应聘自我评价
2013/12/13 职场文书
中层干部培训方案
2014/06/16 职场文书
员工三分钟演讲稿
2014/08/19 职场文书
教师政风行风评议心得体会
2014/10/21 职场文书
2014年新农村建设工作总结
2014/12/01 职场文书
2015年保送生自荐信
2015/03/24 职场文书
js 实现验证码输入框示例详解
2022/09/23 Javascript