tensorflow2.0保存和恢复模型3种方法


Posted in Python onFebruary 03, 2020

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(
 filepath,
 overwrite=True,
 save_format=None
)

Arguments:

filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(
 filepath,
 by_name=False
)

实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')
 
# step7 删除模型
del model
 
# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')
 
# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(
 filepath,
 overwrite=True,
 include_optimizer=True,
 save_format=None
)

Arguments:

filepath: String, path to SavedModel or H5 file to save the model.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

include_optimizer: If True, save optimizer's state together.

save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
 
# step7 删除模型
del model # deletes the existing model
 
 
# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')
 
# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python的dict,set,list,tuple应用详解
Jul 24 Python
Python列表list数组array用法实例解析
Oct 28 Python
python使用socket连接远程服务器的方法
Apr 29 Python
让Python代码更快运行的5种方法
Jun 21 Python
通过实例浅析Python对比C语言的编程思想差异
Aug 30 Python
浅谈Python浅拷贝、深拷贝及引用机制
Dec 15 Python
Python 出现错误TypeError: ‘NoneType’ object is not iterable解决办法
Jan 12 Python
python实现五子棋游戏
Jun 18 Python
python中hasattr()、getattr()、setattr()函数的使用
Aug 16 Python
Tensorflow tf.nn.atrous_conv2d如何实现空洞卷积的
Apr 20 Python
Python 使用Opencv实现目标检测与识别的示例代码
Sep 08 Python
python对输出的奇数偶数排序实例代码
Dec 04 Python
详解字符串在Python内部是如何省内存的
Feb 03 #Python
python自动化unittest yaml使用过程解析
Feb 03 #Python
Python类如何定义私有变量
Feb 03 #Python
python异常处理try except过程解析
Feb 03 #Python
利用Python脚本实现自动刷网课
Feb 03 #Python
tensorflow 限制显存大小的实现
Feb 03 #Python
基于tensorflow指定GPU运行及GPU资源分配的几种方式小结
Feb 03 #Python
You might like
php fsockopen中多线程问题的解决办法[翻译]
2011/11/09 PHP
php获取参数的几种方法总结
2014/02/18 PHP
php为字符串前后添加指定数量字符的方法
2015/05/04 PHP
php中smarty模板条件判断用法实例
2015/06/11 PHP
WordPress中注册菜单与调用菜单的方法详解
2015/12/18 PHP
php编译安装php-amq扩展简明教程
2016/06/25 PHP
PHP代码重构方法漫谈
2018/04/17 PHP
PHPMailer ThinkPHP实现自动发送邮件功能
2018/06/10 PHP
PHP whois查询类定义与用法示例
2019/04/03 PHP
PHP远程连接oracle数据库操作实现方法图文详解
2019/04/11 PHP
laravel框架实现后台登录、退出功能示例
2019/10/31 PHP
jQuery 使用手册(四)
2009/09/23 Javascript
jquery左右滚动焦点图banner图片鼠标经过显示上下页按钮
2013/10/11 Javascript
JavaScript中的toUTCString()方法使用详解
2015/06/12 Javascript
介绍JavaScript的一个微型模版
2015/06/24 Javascript
js document.getElementsByClassName的使用介绍与自定义函数
2016/11/25 Javascript
纯javaScript、jQuery实现个性化图片轮播【推荐】
2017/01/08 Javascript
Ionic+AngularJS实现登录和注册带验证功能
2017/02/09 Javascript
node.js中express中间件body-parser的介绍与用法详解
2017/05/23 Javascript
浅谈NodeJs之数据库异常处理
2017/10/25 NodeJs
React 无状态组件(Stateless Component) 与高阶组件
2018/08/14 Javascript
nodejs中使用archive压缩文件的实现代码
2019/11/26 NodeJs
解决Vue大括号字符换行踩的坑
2020/11/09 Javascript
JavaScript手写数组的常用函数总结
2020/11/22 Javascript
[02:32]DOTA2亚洲邀请赛 VG战队巡礼
2015/02/03 DOTA
[05:23]DOTA2-DPC中国联赛2月1日Recap集锦
2021/03/11 DOTA
简单的Apache+FastCGI+Django配置指南
2015/07/22 Python
Python函数装饰器原理与用法详解
2019/08/16 Python
python3正则模块re的使用方法详解
2020/02/11 Python
pycharm无法导入本地模块的解决方式
2020/02/12 Python
python实现文法左递归的消除方法
2020/05/22 Python
浅谈keras中的keras.utils.to_categorical用法
2020/07/02 Python
Dr.Jart+美国官网:韩国药妆品牌
2019/01/18 全球购物
初级Java程序员面试题
2016/03/03 面试题
双创工作实施方案
2014/03/26 职场文书
微信小程序实现聊天室功能
2021/06/14 Javascript