浅谈keras保存模型中的save()和save_weights()区别


Posted in Python onMay 21, 2020

今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别。

我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5。同样是h5文件用save()和save_weight()保存效果是不一样的。

我们用宇宙最通用的数据集MNIST来做这个实验,首先设计一个两层全连接网络:

inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)

然后,导入MNIST数据训练,分别用两种方式保存模型,在这里我还把未训练的模型也保存下来,如下:

from keras.models import Model
from keras.layers import Input, Dense
from keras.datasets import mnist
from keras.utils import np_utils
 
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
x_test=x_test.reshape(x_test.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
 
model.save('m1.h5')
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=10)
#loss,accuracy=model.evaluate(x_test,y_test)
 
model.save('m2.h5')
model.save_weights('m3.h5')

如上可见,我一共保存了m1.h5, m2.h5, m3.h5 这三个h5文件。那么,我们来看看这三个玩意儿有什么区别。首先,看看大小:

浅谈keras保存模型中的save()和save_weights()区别

m2表示save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。所以它的size最大的。

m1表示save()保存的训练前的模型结果,它保存了模型的图结构,但应该没有保存模型的初始化参数,所以它的size要比m2小很多。

m3表示save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构。所以它的size也要比m2小很多。

通过可视化工具,我们发现:(打开m1和m2均可以显示出以下结构)

浅谈keras保存模型中的save()和save_weights()区别

而打开m3的时候,可视化工具报错了。由此可以论证, save_weights()是不含有模型结构信息的。

加载模型

两种不同方法保存的模型文件也需要用不同的加载方法。

from keras.models import load_model
 
model = load_model('m1.h5')
#model = load_model('m2.h5')
#model = load_model('m3.h5')
model.summary()

只有加载m3.h5的时候,这段代码才会报错。其他输出如下:

浅谈keras保存模型中的save()和save_weights()区别

可见,由save()保存下来的h5文件才可以直接通过load_model()打开!

那么,我们保存下来的参数(m3.h5)该怎么打开呢?

这就稍微复杂一点了,因为m3不含有模型结构信息,所以我们需要把模型结构再描述一遍才可以加载m3,如下:

from keras.models import Model
from keras.layers import Input, Dense
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
model.load_weights('m3.h5')

以上把m3换成m1和m2也是没有问题的!可见,save()保存的模型除了占用内存大一点以外,其他的优点太明显了。所以,在不怎么缺硬盘空间的情况下,还是建议大家多用save()来存。

注意!如果要load_weights(),必须保证你描述的有参数计算结构与h5文件中完全一致!什么叫有参数计算结构呢?就是有参数坑,直接填进去就行了。我们把上面的非参数结构换了一下,发现h5文件依然可以加载成功,比如将softmax换成relu,依然不影响加载。

对于keras的save()和save_weights(),完全没问题了吧

以上这篇浅谈keras保存模型中的save()和save_weights()区别就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
打开电脑上的QQ的python代码
Feb 10 Python
python简单获取数组元素个数的方法
Jul 13 Python
Django中传递参数到URLconf的视图函数中的方法
Jul 18 Python
Python的“二维”字典 (two-dimension dictionary)定义与实现方法
Apr 27 Python
深入学习Python中的装饰器使用
Jun 20 Python
python 判断是否为正小数和正整数的实例
Jul 23 Python
使用python画个小猪佩奇的示例代码
Jun 06 Python
pyqt5 实现多窗口跳转的方法
Jun 19 Python
python实现简单颜色识别程序
Feb 19 Python
基于Tensorflow一维卷积用法详解
May 22 Python
基于python tkinter的点名小程序功能的实例代码
Aug 22 Python
如何在scrapy中捕获并处理各种异常
Sep 28 Python
Python通过文本和图片生成词云图
May 21 #Python
解决在keras中使用model.save()函数保存模型失败的问题
May 21 #Python
Python 实现敏感目录扫描的示例代码
May 21 #Python
基于python检查矩阵计算结果
May 21 #Python
Django 解决由save方法引发的错误
May 21 #Python
Python OrderedDict字典排序方法详解
May 21 #Python
django中嵌套的try-except实例
May 21 #Python
You might like
PHP与javascript对多项选择的处理
2006/10/09 PHP
在apache下限制每个虚拟主机的并发数!!!!
2006/10/09 PHP
PHP 增加了对 .ZIP 文件的读取功能
2006/10/09 PHP
有关 PHP 和 MySQL 时区的一点总结
2008/03/26 PHP
vs中通过剪切板循环来循环粘贴不同内容
2011/04/30 PHP
php中操作memcached缓存进行增删改查数据的实现代码
2014/08/15 PHP
PHP读取CURL模拟登录时生成Cookie文件的方法
2014/11/04 PHP
老版本PHP转义Json里的特殊字符的函数
2015/06/08 PHP
PHP以json或xml格式返回请求数据的方法
2018/05/31 PHP
laravel框架模型和数据库基础操作实例详解
2020/01/25 PHP
PHP代码加密的方法总结
2020/03/13 PHP
jquery HotKeys轻松搞定键盘事件代码
2008/08/30 Javascript
让你的网站可编辑的实现js代码
2009/10/19 Javascript
THREE.JS入门教程(2)着色器-上
2013/01/24 Javascript
js文件包含的几种方式介绍
2014/09/28 Javascript
原生javascript实现图片滚动、延时加载功能
2015/01/12 Javascript
JavaScript hasOwnProperty() 函数实例详解
2017/08/04 Javascript
vue2.5.2使用http请求获取静态json数据的实例代码
2018/02/27 Javascript
nodejs实现的简单web服务器功能示例
2018/03/15 NodeJs
vue2.0 + ele的循环表单及验证字段方法
2018/09/18 Javascript
jQuery实现评论模块
2020/08/19 jQuery
[01:33]DOTA2上海特级锦标赛 LIQUID战队完整宣传片
2016/03/16 DOTA
Python提取Linux内核源代码的目录结构实现方法
2016/06/24 Python
python使用多进程的实例详解
2018/09/19 Python
python实现TCP文件传输
2020/03/20 Python
Python使用grequests并发发送请求的示例
2020/11/05 Python
使用django自带的user做外键的方法
2020/11/30 Python
CSS3动画:5种预载动画效果实例
2017/04/05 HTML / CSS
印度领先的在线时尚商店:Koovs
2016/08/28 全球购物
美国家用电器和电子产品商店:Abt
2016/09/06 全球购物
Snapfish爱尔兰:在线照片打印和个性化照片礼品
2018/09/17 全球购物
优秀学生干部个人的自我评价
2013/10/04 职场文书
饭店工作计划书
2014/01/10 职场文书
个人违纪检讨书
2014/09/15 职场文书
2015年政治教研组工作总结
2015/07/22 职场文书
将MySQL的表数据全量导入clichhouse库中
2022/03/21 MySQL