tensorflow2.0保存和恢复模型3种方法


Posted in Python onFebruary 03, 2020

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(
 filepath,
 overwrite=True,
 save_format=None
)

Arguments:

filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(
 filepath,
 by_name=False
)

实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')
 
# step7 删除模型
del model
 
# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')
 
# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(
 filepath,
 overwrite=True,
 include_optimizer=True,
 save_format=None
)

Arguments:

filepath: String, path to SavedModel or H5 file to save the model.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

include_optimizer: If True, save optimizer's state together.

save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
 
# step7 删除模型
del model # deletes the existing model
 
 
# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')
 
# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python脚本生成Android SALT扰码的方法
Sep 18 Python
老生常谈Python进阶之装饰器
May 11 Python
python使用邻接矩阵构造图代码示例
Nov 10 Python
python使用xpath中遇到:到底是什么?
Jan 04 Python
Python使用SQLite和Excel操作进行数据分析
Jan 20 Python
Python实现的径向基(RBF)神经网络示例
Feb 06 Python
python中数组和矩阵乘法及使用总结(推荐)
May 18 Python
Python 根据日志级别打印不同颜色的日志的方法示例
Aug 08 Python
Python Print实现在输出中插入变量的例子
Dec 25 Python
Python实现桌面翻译工具【新手必学】
Feb 12 Python
Nginx+Uwsgi+Django 项目部署到服务器的思路详解
May 08 Python
python获取百度热榜链接的实例方法
Aug 25 Python
详解字符串在Python内部是如何省内存的
Feb 03 #Python
python自动化unittest yaml使用过程解析
Feb 03 #Python
Python类如何定义私有变量
Feb 03 #Python
python异常处理try except过程解析
Feb 03 #Python
利用Python脚本实现自动刷网课
Feb 03 #Python
tensorflow 限制显存大小的实现
Feb 03 #Python
基于tensorflow指定GPU运行及GPU资源分配的几种方式小结
Feb 03 #Python
You might like
PHP与已存在的Java应用程序集成
2006/10/09 PHP
PHP发明人谈MVC和网站设计架构 貌似他不支持php用mvc
2011/06/04 PHP
PHP的几个常用数字判断函数代码
2012/04/24 PHP
php中is_null,empty,isset,unset 的区别详细介绍
2013/04/28 PHP
PHP中常用的输出函数总结
2014/09/22 PHP
php安装php_rar扩展实现rar文件读取和解压的方法
2016/11/17 PHP
JS实现浏览器菜单命令
2006/09/05 Javascript
原生js和jquery中有关透明度设置的相关问题
2014/01/08 Javascript
jQuery html()方法使用不了无法显示内容的问题
2014/08/06 Javascript
JS中捕获console.log()输出的方法
2015/04/16 Javascript
js实现类似菜单风格的TAB选项卡效果代码
2015/08/28 Javascript
JavaScript中的this到底是什么(一)
2015/12/09 Javascript
图文详解Heap Sort堆排序算法及JavaScript的代码实现
2016/05/04 Javascript
jquery使用FormData实现异步上传文件
2018/10/25 jQuery
vue+canvas实现炫酷时钟效果的倒计时插件(已发布到npm的vue2插件,开箱即用)
2018/11/05 Javascript
浅谈Vue页面级缓存解决方案feb-alive(上)
2019/04/14 Javascript
通过javascript实现段落的收缩与展开
2019/06/26 Javascript
vue中h5端打开app(判断是安卓还是苹果)
2021/02/26 Vue.js
Python使用scrapy抓取网站sitemap信息的方法
2015/04/08 Python
Python实例一个类背后发生了什么
2016/02/09 Python
python爬虫入门教程--利用requests构建知乎API(三)
2017/05/25 Python
django 常用orm操作详解
2017/09/13 Python
django+xadmin+djcelery实现后台管理定时任务
2018/08/14 Python
详解Python3网络爬虫(二):利用urllib.urlopen向有道翻译发送数据获得翻译结果
2019/05/07 Python
python批量将excel内容进行翻译写入功能
2019/10/10 Python
Python数据持久化存储实现方法分析
2019/12/21 Python
CSS3 实现的缩略图悬停效果
2020/12/09 HTML / CSS
关于HTML5的安全问题开发人员需要牢记的
2012/06/21 HTML / CSS
Linux中如何设置Java环境变量(Ubuntu)
2016/07/24 面试题
数控加工专业毕业生自荐信
2013/09/27 职场文书
小学生检讨书大全
2014/02/06 职场文书
小学生元旦感言
2014/02/26 职场文书
摄影专业毕业生求职信
2014/03/13 职场文书
效能监察建议书
2014/05/19 职场文书
综合素质评价自我评价
2015/03/06 职场文书
go 原生http web 服务跨域restful api的写法介绍
2021/04/27 Golang