tensorflow2.0保存和恢复模型3种方法


Posted in Python onFebruary 03, 2020

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(
 filepath,
 overwrite=True,
 save_format=None
)

Arguments:

filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(
 filepath,
 by_name=False
)

实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')
 
# step7 删除模型
del model
 
# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')
 
# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(
 filepath,
 overwrite=True,
 include_optimizer=True,
 save_format=None
)

Arguments:

filepath: String, path to SavedModel or H5 file to save the model.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

include_optimizer: If True, save optimizer's state together.

save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
 
# step7 删除模型
del model # deletes the existing model
 
 
# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')
 
# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python脚本暴力破解栅栏密码
Oct 19 Python
在Python中通过threading模块定义和调用线程的方法
Jul 12 Python
Python之自动获取公网IP的实例讲解
Oct 01 Python
详解Python3 中hasattr()、getattr()、setattr()、delattr()函数及示例代码数
Apr 18 Python
程序员写Python时的5个坏习惯,你有几条?
Nov 26 Python
Python设计模式之命令模式原理与用法实例分析
Jan 11 Python
python设置环境变量的作用和实例
Jul 09 Python
python SVM 线性分类模型的实现
Jul 19 Python
Python pandas用法最全整理
Aug 04 Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 Python
PyTorch-GPU加速实例
Jun 23 Python
Python中对象的比较操作==和is区别详析
Feb 12 Python
详解字符串在Python内部是如何省内存的
Feb 03 #Python
python自动化unittest yaml使用过程解析
Feb 03 #Python
Python类如何定义私有变量
Feb 03 #Python
python异常处理try except过程解析
Feb 03 #Python
利用Python脚本实现自动刷网课
Feb 03 #Python
tensorflow 限制显存大小的实现
Feb 03 #Python
基于tensorflow指定GPU运行及GPU资源分配的几种方式小结
Feb 03 #Python
You might like
php 静态变量与自定义常量的使用方法
2010/01/26 PHP
php中的常用魔术方法总结
2013/08/02 PHP
php实现的Captcha验证码类实例
2014/09/22 PHP
js prototype截取字符串函数
2010/04/01 Javascript
自动最大化窗口的Javascript代码
2013/05/22 Javascript
JS小游戏之仙剑翻牌源码详解
2014/09/25 Javascript
JavaScript中实现依赖注入的思路分享
2015/01/15 Javascript
使用DNode实现php和nodejs之间通信的简单实例
2015/07/06 NodeJs
解析JavaScript的ES6版本中的解构赋值
2015/07/28 Javascript
jQuery垂直多级导航菜单代码分享
2015/08/18 Javascript
jQuery+CSS3实现3D立方体旋转效果
2015/11/10 Javascript
Jquery promise实现一张一张加载图片
2015/11/13 Javascript
基于javascript实现listbox左右移动
2016/01/29 Javascript
js本地图片预览实现代码
2016/10/09 Javascript
JS实现HTML标签转义及反转义
2020/04/14 Javascript
vue如何根据网站路由判断页面主题色详解
2018/11/02 Javascript
微信小程序保存多张图片的实现方法
2019/03/05 Javascript
angular 实现下拉列表组件的示例代码
2019/03/09 Javascript
[01:02]DOTA2上海特锦赛SHOWOPEN
2016/03/25 DOTA
python概率计算器实例分析
2015/03/25 Python
Python编程实现双链表,栈,队列及二叉树的方法示例
2017/11/01 Python
Python简单实现的代理服务器端口映射功能示例
2018/04/08 Python
使用Numpy读取CSV文件,并进行行列删除的操作方法
2018/07/04 Python
Python浮点数四舍五入问题的分析与解决方法
2019/11/19 Python
python3格式化字符串 f-string的高级用法(推荐)
2020/03/04 Python
Python如何实现单例模式
2016/06/03 面试题
咖啡店自主创业商业计划书
2014/01/22 职场文书
个人近期表现材料
2014/02/11 职场文书
服装采购员岗位职责
2014/03/15 职场文书
企业金融服务方案
2014/06/03 职场文书
大班亲子运动会方案
2014/06/10 职场文书
科学发展观活动总结
2014/08/28 职场文书
党章培训心得体会
2014/09/04 职场文书
财务工作个人总结
2015/02/27 职场文书
小时代观后感
2015/06/10 职场文书
环境卫生整治简报
2015/07/20 职场文书