浅谈keras保存模型中的save()和save_weights()区别


Posted in Python onMay 21, 2020

今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别。

我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5。同样是h5文件用save()和save_weight()保存效果是不一样的。

我们用宇宙最通用的数据集MNIST来做这个实验,首先设计一个两层全连接网络:

inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)

然后,导入MNIST数据训练,分别用两种方式保存模型,在这里我还把未训练的模型也保存下来,如下:

from keras.models import Model
from keras.layers import Input, Dense
from keras.datasets import mnist
from keras.utils import np_utils
 
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
x_test=x_test.reshape(x_test.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
 
model.save('m1.h5')
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=10)
#loss,accuracy=model.evaluate(x_test,y_test)
 
model.save('m2.h5')
model.save_weights('m3.h5')

如上可见,我一共保存了m1.h5, m2.h5, m3.h5 这三个h5文件。那么,我们来看看这三个玩意儿有什么区别。首先,看看大小:

浅谈keras保存模型中的save()和save_weights()区别

m2表示save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。所以它的size最大的。

m1表示save()保存的训练前的模型结果,它保存了模型的图结构,但应该没有保存模型的初始化参数,所以它的size要比m2小很多。

m3表示save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构。所以它的size也要比m2小很多。

通过可视化工具,我们发现:(打开m1和m2均可以显示出以下结构)

浅谈keras保存模型中的save()和save_weights()区别

而打开m3的时候,可视化工具报错了。由此可以论证, save_weights()是不含有模型结构信息的。

加载模型

两种不同方法保存的模型文件也需要用不同的加载方法。

from keras.models import load_model
 
model = load_model('m1.h5')
#model = load_model('m2.h5')
#model = load_model('m3.h5')
model.summary()

只有加载m3.h5的时候,这段代码才会报错。其他输出如下:

浅谈keras保存模型中的save()和save_weights()区别

可见,由save()保存下来的h5文件才可以直接通过load_model()打开!

那么,我们保存下来的参数(m3.h5)该怎么打开呢?

这就稍微复杂一点了,因为m3不含有模型结构信息,所以我们需要把模型结构再描述一遍才可以加载m3,如下:

from keras.models import Model
from keras.layers import Input, Dense
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
model.load_weights('m3.h5')

以上把m3换成m1和m2也是没有问题的!可见,save()保存的模型除了占用内存大一点以外,其他的优点太明显了。所以,在不怎么缺硬盘空间的情况下,还是建议大家多用save()来存。

注意!如果要load_weights(),必须保证你描述的有参数计算结构与h5文件中完全一致!什么叫有参数计算结构呢?就是有参数坑,直接填进去就行了。我们把上面的非参数结构换了一下,发现h5文件依然可以加载成功,比如将softmax换成relu,依然不影响加载。

对于keras的save()和save_weights(),完全没问题了吧

以上这篇浅谈keras保存模型中的save()和save_weights()区别就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 的 with 语句详解
Jun 13 Python
Python里隐藏的“禅”
Jun 16 Python
详解Python中with语句的用法
Apr 15 Python
深入浅析Python中join 和 split详解(推荐)
Jun 30 Python
利用pyinstaller或virtualenv将python程序打包详解
Mar 22 Python
python去除扩展名的实例讲解
Apr 23 Python
Python异常处理操作实例详解
Aug 28 Python
Python格式化输出字符串方法小结【%与format】
Oct 29 Python
Python面向对象程序设计中类的定义、实例化、封装及私有变量/方法详解
Feb 28 Python
Python输出指定字符串的方法
Feb 06 Python
python使用建议与技巧分享(二)
Aug 17 Python
python用Tkinter做自己的中文代码编辑器
Sep 07 Python
Python通过文本和图片生成词云图
May 21 #Python
解决在keras中使用model.save()函数保存模型失败的问题
May 21 #Python
Python 实现敏感目录扫描的示例代码
May 21 #Python
基于python检查矩阵计算结果
May 21 #Python
Django 解决由save方法引发的错误
May 21 #Python
Python OrderedDict字典排序方法详解
May 21 #Python
django中嵌套的try-except实例
May 21 #Python
You might like
德生H-501的评价与改造
2021/03/02 无线电
PHP date()函数警告: It is not safe to rely on the system解决方法
2014/08/20 PHP
PHP实现适用于自定义的验证码类
2016/06/15 PHP
PHP通过bypass disable functions执行系统命令的方法汇总
2018/05/02 PHP
php app支付宝回调(异步通知)详解
2018/07/25 PHP
Laravel 自定命令以及生成文件的例子
2019/10/23 PHP
解放web程序员的输入验证
2006/10/06 Javascript
jQuery为iframe的body添加click事件的实现代码
2011/04/07 Javascript
用Javascript获取页面元素的具体位置
2013/12/09 Javascript
js实现超酷的照片墙展示效果图附源码下载
2015/10/08 Javascript
学习vue.js表单控件绑定操作
2016/12/05 Javascript
详解Javascript几种跨域方式总结
2017/02/27 Javascript
详解angular用$sce服务来过滤HTML标签
2017/04/11 Javascript
Angular 4依赖注入学习教程之简介(一)
2017/06/04 Javascript
vue中component组件的props使用详解
2017/09/04 Javascript
最适应的vue.js的form提交涉及多种插件【推荐】
2018/08/27 Javascript
详解Vue 全局变量,局部变量
2019/04/17 Javascript
Nodejs监听日志文件的变化的过程解析
2019/08/04 NodeJs
layui 数据表格+分页+搜索+checkbox+缓存选中项数据的方法
2019/09/21 Javascript
[59:35]DOTA2上海特级锦标赛主赛事日 - 3 败者组第三轮#1COL VS Alliance第二局
2016/03/04 DOTA
Python3.6正式版新特性预览
2016/12/15 Python
Python数据结构之顺序表的实现代码示例
2017/11/15 Python
python脚本执行CMD命令并返回结果的例子
2019/08/14 Python
tensorflow estimator 使用hook实现finetune方式
2020/01/21 Python
python中使用input()函数获取用户输入值方式
2020/05/03 Python
利用Python中的Xpath实现一个在线汇率转换器
2020/09/09 Python
Django前后端分离csrf token获取方式
2020/12/25 Python
外贸主管求职简历的自我评价
2013/10/23 职场文书
2014基层党员干部学习全国两会心得体会
2014/03/17 职场文书
2015年车间安全管理工作总结
2015/05/13 职场文书
画展观后感
2015/06/17 职场文书
小学生读书笔记范文
2015/06/30 职场文书
2015年食品安全宣传周活动总结
2015/07/09 职场文书
课题研究阶段性总结
2015/08/13 职场文书
Django集成富文本编辑器summernote的实现步骤
2021/05/31 Python
JavaScript原型链中函数和对象的理解
2022/06/16 Javascript