tensorflow2.0保存和恢复模型3种方法


Posted in Python onFebruary 03, 2020

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(
 filepath,
 overwrite=True,
 save_format=None
)

Arguments:

filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(
 filepath,
 by_name=False
)

实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')
 
# step7 删除模型
del model
 
# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')
 
# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(
 filepath,
 overwrite=True,
 include_optimizer=True,
 save_format=None
)

Arguments:

filepath: String, path to SavedModel or H5 file to save the model.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

include_optimizer: If True, save optimizer's state together.

save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
 
# step7 删除模型
del model # deletes the existing model
 
 
# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')
 
# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入理解Python中的元类(metaclass)
Feb 14 Python
简介二分查找算法与相关的Python实现示例
Aug 26 Python
python实现用户登录系统
May 21 Python
Python表示矩阵的方法分析
May 26 Python
Python 3.6 性能测试框架Locust安装及使用方法(详解)
Oct 11 Python
Python实现字符型图片验证码识别完整过程详解
May 10 Python
Python中Numpy mat的使用详解
May 24 Python
python实现桌面托盘气泡提示
Jul 29 Python
在vscode中配置python环境过程解析
Sep 28 Python
Selenium使用Chrome模拟手机浏览器方法解析
Apr 10 Python
python异常中else的实例用法
Jun 15 Python
Python基础之变量的相关知识总结
Jun 23 Python
详解字符串在Python内部是如何省内存的
Feb 03 #Python
python自动化unittest yaml使用过程解析
Feb 03 #Python
Python类如何定义私有变量
Feb 03 #Python
python异常处理try except过程解析
Feb 03 #Python
利用Python脚本实现自动刷网课
Feb 03 #Python
tensorflow 限制显存大小的实现
Feb 03 #Python
基于tensorflow指定GPU运行及GPU资源分配的几种方式小结
Feb 03 #Python
You might like
国内咖啡文化
2021/03/03 咖啡文化
PHP使用get_headers函数判断远程文件是否存在的方法
2014/11/28 PHP
php绘制圆形的方法
2015/01/24 PHP
php无限分类使用concat如何实现
2015/11/05 PHP
Yii Framework框架使用PHPExcel组件的方法示例
2019/07/24 PHP
JavaScript CSS修改学习第一章 查找位置
2010/02/19 Javascript
兼容IE和Firefox的javascript获取iframe文档内容的函数
2011/08/15 Javascript
完美兼容各大浏览器获取HTTP_REFERER方法总结
2014/06/24 Javascript
基于jQuery实现网页进度显示插件
2015/03/04 Javascript
JQuery中DOM事件绑定用法详解
2015/06/13 Javascript
jQuery实现默认是闭合的FAQ展开效果菜单
2015/09/14 Javascript
Bootstrap源码解读按钮(5)
2016/12/23 Javascript
微信小程序 slider 详解及实例代码
2017/01/10 Javascript
防止重复发送 Ajax 请求
2017/02/15 Javascript
angularjs 动态从后台获取下拉框的值方法
2018/08/13 Javascript
对vue v-if v-else-if v-else 的简单使用详解
2018/09/29 Javascript
JS调用安卓手机摄像头扫描二维码
2018/10/16 Javascript
Vue2.x通用条件搜索组件的封装及应用详解
2019/05/28 Javascript
基于Vue实现微前端的示例代码
2020/04/24 Javascript
python实现得到一个给定类的虚函数
2014/09/28 Python
python检查序列seq是否含有aset中项的方法
2015/06/30 Python
django用户注册、登录、注销和用户扩展的示例
2018/03/19 Python
Tensorflow 训练自己的数据集将数据直接导入到内存
2018/06/19 Python
解决python 未发现数据源名称并且未指定默认驱动程序的问题
2018/12/07 Python
Python实现获取汉字偏旁部首的方法示例【测试可用】
2018/12/18 Python
Python之Class&Object用法详解
2019/12/25 Python
python 用opencv实现霍夫线变换
2020/11/27 Python
CSS3 3D制作实战案例分析
2016/09/18 HTML / CSS
开办化妆品公司创业计划书
2013/12/26 职场文书
主题酒店策划书
2014/01/28 职场文书
个人求职信范例
2014/01/29 职场文书
电子商务专业学生职业生涯规划
2014/03/07 职场文书
晚自修旷课检讨书怎么写
2014/11/17 职场文书
幼儿教师个人总结
2015/02/05 职场文书
地球一小时活动总结
2015/02/27 职场文书
redis配置文件中常用配置详解
2021/04/14 Redis