tensorflow2.0保存和恢复模型3种方法


Posted in Python onFebruary 03, 2020

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(
 filepath,
 overwrite=True,
 save_format=None
)

Arguments:

filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(
 filepath,
 by_name=False
)

实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')
 
# step7 删除模型
del model
 
# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')
 
# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(
 filepath,
 overwrite=True,
 include_optimizer=True,
 save_format=None
)

Arguments:

filepath: String, path to SavedModel or H5 file to save the model.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

include_optimizer: If True, save optimizer's state together.

save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
 
# step7 删除模型
del model # deletes the existing model
 
 
# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')
 
# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基于queue和threading实现多线程下载实例
Oct 08 Python
Python操作串口的方法
Jun 17 Python
python字符串的常用操作方法小结
May 21 Python
Python的爬虫框架scrapy用21行代码写一个爬虫
Apr 24 Python
查找python项目依赖并生成requirements.txt的方法
Jul 10 Python
Python对CSV、Excel、txt、dat文件的处理
Sep 18 Python
Python tkinter的grid布局及Text动态显示方法
Oct 11 Python
python读取有密码的zip压缩文件实例
Feb 08 Python
Python正则表达式匹配日期与时间的方法
Jul 07 Python
对django 模型 unique together的示例讲解
Aug 06 Python
Python time库基本使用方法分析
Dec 13 Python
Pyinstaller 打包发布经验总结
Jun 02 Python
详解字符串在Python内部是如何省内存的
Feb 03 #Python
python自动化unittest yaml使用过程解析
Feb 03 #Python
Python类如何定义私有变量
Feb 03 #Python
python异常处理try except过程解析
Feb 03 #Python
利用Python脚本实现自动刷网课
Feb 03 #Python
tensorflow 限制显存大小的实现
Feb 03 #Python
基于tensorflow指定GPU运行及GPU资源分配的几种方式小结
Feb 03 #Python
You might like
如何使用PHP获取网络上文件
2006/10/09 PHP
我的论坛源代码(二)
2006/10/09 PHP
php 实现进制转换(二进制、八进制、十六进制)互相转换实现代码
2010/10/22 PHP
PHP微信开发之微信消息自动回复下所遇到的坑
2016/05/09 PHP
PHP验证终端类型是否为手机的简单实例
2017/02/07 PHP
Aster vs KG BO3 第二场2.19
2021/03/10 DOTA
javascript[js]获取url参数的代码
2007/10/17 Javascript
JavaScript Accessor实现说明
2010/12/06 Javascript
JavaScript测试工具之Karma-Jasmine的安装和使用详解
2015/12/03 Javascript
bootstrap模态框实现拖拽效果
2016/12/14 Javascript
详解angularJs指令的3种绑定策略
2017/04/13 Javascript
微信小程序实现锚点定位楼层跳跃的实例
2017/05/18 Javascript
深入理解nodejs中Express的中间件
2017/05/19 NodeJs
JavaScript实现简单的树形菜单效果
2017/06/23 Javascript
详解使用element-ui table组件的筛选功能的一个小坑
2018/11/02 Javascript
vue-cli3使用 DllPlugin 实现预编译提升构建速度
2019/04/24 Javascript
JS数据类型分类及常用判断方法
2020/11/19 Javascript
Angular处理未可知异常错误的方法详解
2021/01/17 Javascript
python计算程序开始到程序结束的运行时间和程序运行的CPU时间
2013/11/28 Python
Python统计日志中每个IP出现次数的方法
2015/07/06 Python
Python 多线程抓取图片效率对比
2016/02/27 Python
pip matplotlib报错equired packages can not be built解决
2018/01/06 Python
详解如何从TensorFlow的mnist数据集导出手写体数字图片
2019/08/05 Python
python根据时间获取周数代码实例
2019/09/30 Python
VSCode中自动为Python文件添加头部注释
2019/11/14 Python
python 实现两个npy档案合并
2020/07/01 Python
澳大利亚小众服装品牌:Maurie & Eve
2018/03/27 全球购物
美国美食礼品篮网站:Gourmet Gift Baskets
2019/12/15 全球购物
渗透攻击的测试步骤
2014/06/07 面试题
工厂搬迁方案
2014/05/11 职场文书
求职信范文大全
2014/05/26 职场文书
见习报告格式要求
2014/11/04 职场文书
我们的节日中秋节活动总结
2015/03/23 职场文书
Pytorch可视化的几种实现方法
2021/06/10 Python
Python基本数据类型之字符串str
2021/07/21 Python
vue ant design 封装弹窗表单的使用
2022/06/01 Vue.js