tensorflow2.0保存和恢复模型3种方法


Posted in Python onFebruary 03, 2020

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(
 filepath,
 overwrite=True,
 save_format=None
)

Arguments:

filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(
 filepath,
 by_name=False
)

实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')
 
# step7 删除模型
del model
 
# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')
 
# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(
 filepath,
 overwrite=True,
 include_optimizer=True,
 save_format=None
)

Arguments:

filepath: String, path to SavedModel or H5 file to save the model.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

include_optimizer: If True, save optimizer's state together.

save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
 
# step7 删除模型
del model # deletes the existing model
 
 
# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')
 
# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python在Windows8下获取本机ip地址的方法
Mar 14 Python
详解Python中的join()函数的用法
Apr 07 Python
一篇文章快速了解Python的GIL
Jan 12 Python
Django restframework 源码分析之认证详解
Feb 22 Python
PyQt5 多窗口连接实例
Jun 19 Python
python:删除离群值操作(每一行为一类数据)
Jun 08 Python
Python使用socket模块实现简单tcp通信
Aug 18 Python
如何从csv文件构建Tensorflow的数据集
Sep 21 Python
详解Python中string模块除去Str还剩下什么
Nov 30 Python
Python 如何利用ffmpeg 处理视频素材
Nov 27 Python
Python机器学习应用之基于线性判别模型的分类篇详解
Jan 18 Python
python字符串的一些常见实用操作
Apr 06 Python
详解字符串在Python内部是如何省内存的
Feb 03 #Python
python自动化unittest yaml使用过程解析
Feb 03 #Python
Python类如何定义私有变量
Feb 03 #Python
python异常处理try except过程解析
Feb 03 #Python
利用Python脚本实现自动刷网课
Feb 03 #Python
tensorflow 限制显存大小的实现
Feb 03 #Python
基于tensorflow指定GPU运行及GPU资源分配的几种方式小结
Feb 03 #Python
You might like
php 遍历数据表数据并列表横向排列的代码
2009/09/05 PHP
IIS下PHP的三种配置方式对比
2014/11/20 PHP
PHP中strnatcmp()函数“自然排序算法”进行字符串比较用法分析(对比strcmp函数)
2016/01/07 PHP
PHP实现的微信APP支付功能示例【基于TP5框架】
2019/09/16 PHP
Javascript优化技巧(文件瘦身篇)
2008/01/28 Javascript
JS 对象介绍
2010/01/20 Javascript
jquery miniui 教程 表格控件 合并单元格应用
2012/11/25 Javascript
jQuery里filter()函数与find()函数用法分析
2015/06/24 Javascript
Bootstrap入门书籍之(一)排版
2016/02/17 Javascript
火狐和ie下获取javascript 获取event的方法(推荐)
2016/11/26 Javascript
AngularJS实现tab选项卡的方法详解
2017/07/05 Javascript
2种简单的js倒计时方式
2017/10/20 Javascript
jQuery代码优化方法总结
2018/01/29 jQuery
vue中设置、获取、删除cookie的方法
2018/09/21 Javascript
浅谈让你的代码更简短,更整洁,更易读的ES6小技巧
2018/10/25 Javascript
[09:43]DOTA2每周TOP10 精彩击杀集锦vol.5
2014/06/25 DOTA
[36:43]NB vs Optic 2018国际邀请赛小组赛BO1 B组加赛 8.19
2018/08/21 DOTA
python简单猜数游戏实例
2015/07/09 Python
使用Python简单的实现树莓派的WEB控制
2016/02/18 Python
详解python 字符串和日期之间转换 StringAndDate
2017/05/04 Python
使用Python进行QQ批量登录的实例代码
2018/06/11 Python
python爬取淘宝商品销量信息
2018/11/16 Python
pyqt 实现QlineEdit 输入密码显示成圆点的方法
2019/06/24 Python
python把转列表为集合的方法
2019/06/28 Python
python中时间、日期、时间戳的转换的实现方法
2019/07/06 Python
Python3.7在anaconda里面使用IDLE编译器的步骤详解
2020/04/29 Python
西班牙伏林航空公司:Vueling
2016/08/05 全球购物
高中毕业生自我鉴定
2013/11/03 职场文书
自强之星事迹材料
2014/05/12 职场文书
道德演讲稿
2014/05/21 职场文书
公司口号大全
2014/06/11 职场文书
工伤事故证明
2014/10/20 职场文书
医学生自荐信范文
2015/03/05 职场文书
小学体育队列队形教学反思
2016/02/16 职场文书
学习nginx基础知识
2021/09/04 Servers
Canvas绘制像素风图片的示例代码
2021/09/25 HTML / CSS