浅谈keras保存模型中的save()和save_weights()区别


Posted in Python onMay 21, 2020

今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别。

我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5。同样是h5文件用save()和save_weight()保存效果是不一样的。

我们用宇宙最通用的数据集MNIST来做这个实验,首先设计一个两层全连接网络:

inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)

然后,导入MNIST数据训练,分别用两种方式保存模型,在这里我还把未训练的模型也保存下来,如下:

from keras.models import Model
from keras.layers import Input, Dense
from keras.datasets import mnist
from keras.utils import np_utils
 
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
x_test=x_test.reshape(x_test.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
 
model.save('m1.h5')
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=10)
#loss,accuracy=model.evaluate(x_test,y_test)
 
model.save('m2.h5')
model.save_weights('m3.h5')

如上可见,我一共保存了m1.h5, m2.h5, m3.h5 这三个h5文件。那么,我们来看看这三个玩意儿有什么区别。首先,看看大小:

浅谈keras保存模型中的save()和save_weights()区别

m2表示save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。所以它的size最大的。

m1表示save()保存的训练前的模型结果,它保存了模型的图结构,但应该没有保存模型的初始化参数,所以它的size要比m2小很多。

m3表示save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构。所以它的size也要比m2小很多。

通过可视化工具,我们发现:(打开m1和m2均可以显示出以下结构)

浅谈keras保存模型中的save()和save_weights()区别

而打开m3的时候,可视化工具报错了。由此可以论证, save_weights()是不含有模型结构信息的。

加载模型

两种不同方法保存的模型文件也需要用不同的加载方法。

from keras.models import load_model
 
model = load_model('m1.h5')
#model = load_model('m2.h5')
#model = load_model('m3.h5')
model.summary()

只有加载m3.h5的时候,这段代码才会报错。其他输出如下:

浅谈keras保存模型中的save()和save_weights()区别

可见,由save()保存下来的h5文件才可以直接通过load_model()打开!

那么,我们保存下来的参数(m3.h5)该怎么打开呢?

这就稍微复杂一点了,因为m3不含有模型结构信息,所以我们需要把模型结构再描述一遍才可以加载m3,如下:

from keras.models import Model
from keras.layers import Input, Dense
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
model.load_weights('m3.h5')

以上把m3换成m1和m2也是没有问题的!可见,save()保存的模型除了占用内存大一点以外,其他的优点太明显了。所以,在不怎么缺硬盘空间的情况下,还是建议大家多用save()来存。

注意!如果要load_weights(),必须保证你描述的有参数计算结构与h5文件中完全一致!什么叫有参数计算结构呢?就是有参数坑,直接填进去就行了。我们把上面的非参数结构换了一下,发现h5文件依然可以加载成功,比如将softmax换成relu,依然不影响加载。

对于keras的save()和save_weights(),完全没问题了吧

以上这篇浅谈keras保存模型中的save()和save_weights()区别就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python mysqldb连接数据库
Mar 16 Python
从零学python系列之数据处理编程实例(一)
May 22 Python
Python接收Gmail新邮件并发送到gtalk的方法
Mar 10 Python
在Django的模型和公用函数中使用惰性翻译对象
Jul 27 Python
python爬虫实现中英翻译词典
Jun 25 Python
python实现windows倒计时锁屏功能
Jul 30 Python
Python从文件中读取指定的行以及在文件指定位置写入
Sep 06 Python
python实现的批量分析xml标签中各个类别个数功能示例
Dec 30 Python
Python递归及尾递归优化操作实例分析
Feb 01 Python
Python获取、格式化当前时间日期的方法
Feb 10 Python
pycharm设置当前工作目录的操作(working directory)
Feb 14 Python
如何基于python对接钉钉并获取access_token
Apr 21 Python
Python通过文本和图片生成词云图
May 21 #Python
解决在keras中使用model.save()函数保存模型失败的问题
May 21 #Python
Python 实现敏感目录扫描的示例代码
May 21 #Python
基于python检查矩阵计算结果
May 21 #Python
Django 解决由save方法引发的错误
May 21 #Python
Python OrderedDict字典排序方法详解
May 21 #Python
django中嵌套的try-except实例
May 21 #Python
You might like
PHP生成Gif图片验证码
2013/10/27 PHP
PHP eval函数使用介绍
2013/12/08 PHP
PHP使用正则表达式实现过滤非法字符串功能示例
2018/06/04 PHP
javascript的对话框详解与参数
2007/03/08 Javascript
JS字符串函数扩展代码
2011/09/13 Javascript
firefox下jquery ajax返回object XMLDocument处理方法
2014/01/26 Javascript
你可能不知道的JavaScript的new Function()方法
2014/04/17 Javascript
Javascript中的异步编程规范Promises/A详细介绍
2014/06/06 Javascript
Javascript递归打印Document层次关系实例分析
2015/05/15 Javascript
基于JavaScript制作霓虹灯文字 代码 特效
2015/09/01 Javascript
原生javascript实现匀速运动动画效果
2016/02/26 Javascript
Swiper实现轮播图效果
2017/07/03 Javascript
详解JS实现系统登录页的登录和验证
2019/04/29 Javascript
JavaScript实现简单计算器功能
2019/12/19 Javascript
JS三级联动代码格式实例详解
2019/12/30 Javascript
Vue插槽_特殊特性slot,slot-scope与指令v-slot说明
2020/09/04 Javascript
Python 字典与字符串的互转实例
2017/01/13 Python
matplotlib设置legend图例代码示例
2017/12/19 Python
对numpy和pandas中数组的合并和拆分详解
2018/04/11 Python
Anaconda 离线安装 python 包的操作方法
2018/06/11 Python
Python之指数与E记法的区别详解
2019/11/21 Python
Django media static外部访问Django中的图片设置教程
2020/04/07 Python
Python数据库封装实现代码示例解析
2020/09/05 Python
python录音并调用百度语音识别接口的示例
2020/12/01 Python
Shell如何接收变量输入
2012/09/24 面试题
应届生体育教师自荐信
2013/10/03 职场文书
高考备战决心书
2014/03/11 职场文书
项目投资建议书
2014/05/16 职场文书
酒店优秀员工推荐信
2015/03/24 职场文书
2015年大学班主任工作总结
2015/04/30 职场文书
酒桌上的开场白
2015/06/01 职场文书
钱学森观后感
2015/06/04 职场文书
地心历险记观后感
2015/06/15 职场文书
PySwarms(Python粒子群优化工具包)的使用:GlobalBestPSO例子解析
2021/04/05 Python
详解JavaScript中的执行上下文及调用堆栈
2021/04/29 Javascript
Tomcat执行startup.bat出现闪退的原因及解决办法
2022/04/20 Servers