tensorflow2.0保存和恢复模型3种方法


Posted in Python onFebruary 03, 2020

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(
 filepath,
 overwrite=True,
 save_format=None
)

Arguments:

filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(
 filepath,
 by_name=False
)

实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')
 
# step7 删除模型
del model
 
# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')
 
# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(
 filepath,
 overwrite=True,
 include_optimizer=True,
 save_format=None
)

Arguments:

filepath: String, path to SavedModel or H5 file to save the model.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

include_optimizer: If True, save optimizer's state together.

save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
 
# step7 删除模型
del model # deletes the existing model
 
 
# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')
 
# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
pycharm 使用心得(七)一些实用功能介绍
Jun 06 Python
Python 爬虫多线程详解及实例代码
Oct 08 Python
python递归实现快速排序
Aug 18 Python
PyCharm设置护眼背景色的方法
Oct 29 Python
解决python中 f.write写入中文出错的问题
Oct 31 Python
python 执行文件时额外参数获取的实例
Dec 18 Python
Python完成哈夫曼树编码过程及原理详解
Jul 29 Python
使用OpenCV circle函数图像上画圆的示例代码
Dec 27 Python
Django实现将一个字典传到前端显示出来
Apr 03 Python
pyecharts在数据可视化中的应用详解
Jun 08 Python
如何解决安装python3.6.1失败
Jul 01 Python
python中取整数的几种方法
Nov 07 Python
详解字符串在Python内部是如何省内存的
Feb 03 #Python
python自动化unittest yaml使用过程解析
Feb 03 #Python
Python类如何定义私有变量
Feb 03 #Python
python异常处理try except过程解析
Feb 03 #Python
利用Python脚本实现自动刷网课
Feb 03 #Python
tensorflow 限制显存大小的实现
Feb 03 #Python
基于tensorflow指定GPU运行及GPU资源分配的几种方式小结
Feb 03 #Python
You might like
php过滤表单提交的html等危险代码
2014/11/03 PHP
php实现博客,论坛图片防盗链的方法
2016/10/15 PHP
PHP+MYSQL实现读写分离简单实战
2017/03/13 PHP
不用MOUSEMOVE也能滑动啊
2007/05/23 Javascript
jquery实现盒子下拉效果示例代码
2013/09/12 Javascript
javascript中数组的定义及使用实例
2015/01/21 Javascript
js实现简单选项卡与自动切换效果的方法
2015/04/10 Javascript
Javascript将数值转换为金额格式(分隔千分位和自动增加小数点)
2016/06/22 Javascript
javascript时间差插件分享
2016/07/18 Javascript
vue2.0开发实践总结之疑难篇
2016/12/07 Javascript
js异步编程小技巧详解
2017/08/14 Javascript
JS实现点击下拉菜单把选择的内容同步到input输入框内的实例
2018/01/23 Javascript
除Console.log()外更多的Javascript调试命令
2018/01/24 Javascript
angularjs获取到My97DatePicker选中的值方法
2018/10/02 Javascript
node.js中ws模块创建服务端和客户端,网页WebSocket客户端
2019/03/06 Javascript
vue 更改连接后台的api示例
2019/11/11 Javascript
Python程序退出方式小结
2017/12/09 Python
Python3.6简单反射操作示例
2018/06/14 Python
对web.py设置favicon.ico的方法详解
2018/12/04 Python
在Pycharm中修改文件默认打开方式的方法
2019/01/17 Python
python实现坦克大战游戏 附详细注释
2020/03/27 Python
详解python中各种文件打开模式
2020/01/19 Python
Python第三方库的几种安装方式(小结)
2020/04/03 Python
pandas将list数据拆分成行或列的实现
2020/12/13 Python
Python try except else使用详解
2021/01/12 Python
java关于string最常出现的面试题整理
2021/01/18 Python
什么是serialVersionUID
2016/03/04 面试题
工程管理造价应届生求职信
2013/11/13 职场文书
化妆师职业生涯规划书
2014/02/16 职场文书
党员四风问题对照检查材料思想汇报
2014/09/16 职场文书
2014年人事科工作总结
2014/11/19 职场文书
家属答谢词
2015/01/05 职场文书
幼儿园教师求职信
2015/03/20 职场文书
交通事故案件代理词
2015/05/23 职场文书
重阳节简报
2015/07/20 职场文书
python神经网络学习 使用Keras进行回归运算
2022/05/04 Python