tensorflow2.0保存和恢复模型3种方法


Posted in Python onFebruary 03, 2020

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(
 filepath,
 overwrite=True,
 save_format=None
)

Arguments:

filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(
 filepath,
 by_name=False
)

实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')
 
# step7 删除模型
del model
 
# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')
 
# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(
 filepath,
 overwrite=True,
 include_optimizer=True,
 save_format=None
)

Arguments:

filepath: String, path to SavedModel or H5 file to save the model.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

include_optimizer: If True, save optimizer's state together.

save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
 
# step7 删除模型
del model # deletes the existing model
 
 
# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')
 
# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python读取mp3中ID3信息的方法
Mar 05 Python
python复制文件的方法实例详解
May 22 Python
python 全文检索引擎详解
Apr 25 Python
Python3使用PyQt5制作简单的画板/手写板实例
Oct 19 Python
Python+OpenCV感兴趣区域ROI提取方法
Jan 10 Python
Python 数据库操作 SQLAlchemy的示例代码
Feb 18 Python
python爬取内容存入Excel实例
Feb 20 Python
详解Python装饰器
Mar 25 Python
后端开发使用pycharm的技巧(推荐)
Mar 27 Python
python 一维二维插值实例
Apr 22 Python
python用Configobj模块读取配置文件
Sep 26 Python
Python 下载Bing壁纸的示例
Sep 29 Python
详解字符串在Python内部是如何省内存的
Feb 03 #Python
python自动化unittest yaml使用过程解析
Feb 03 #Python
Python类如何定义私有变量
Feb 03 #Python
python异常处理try except过程解析
Feb 03 #Python
利用Python脚本实现自动刷网课
Feb 03 #Python
tensorflow 限制显存大小的实现
Feb 03 #Python
基于tensorflow指定GPU运行及GPU资源分配的几种方式小结
Feb 03 #Python
You might like
libmysql.dll与php.ini是否真的要拷贝到c:\windows目录下呢
2010/03/15 PHP
常用PHP数组排序函数归纳
2016/08/08 PHP
php利用imagemagick实现复古老照片效果实例
2017/02/16 PHP
利用PHP判断是否是连乘数字串的方法示例
2017/07/03 PHP
一个tab标签切换效果代码
2009/03/27 Javascript
javascript 运算数的求值顺序
2011/08/23 Javascript
基于jQuery的360图片展示实现代码
2012/06/14 Javascript
基于jquery创建的一个图片、视频缓冲的效果样式插件
2012/08/28 Javascript
js onkeypress与onkeydown 事件区别详细说明
2012/12/13 Javascript
使用jQuery5分钟快速搞定双色表格的简单实例
2016/08/08 Javascript
关于jquery中动态增加select,事件无效的快速解决方法
2016/08/29 Javascript
javascript中call,apply,bind函数用法示例
2016/12/19 Javascript
详解react-webpack2-热模块替换[HMR]
2017/08/03 Javascript
基于vue的短信验证码倒计时demo
2017/09/13 Javascript
利用js给datalist或select动态添加option选项的方法
2018/01/25 Javascript
AngularJS双向数据绑定原理之$watch、$apply和$digest的应用
2018/01/30 Javascript
JavaScript 五大常见函数
2018/03/23 Javascript
jQuery实现的回车触发按钮事件功能示例
2018/03/25 jQuery
详解Vue底部导航栏组件
2019/05/02 Javascript
Python标准库defaultdict模块使用示例
2015/04/28 Python
让python在hadoop上跑起来
2016/01/27 Python
python中os模块详解
2016/10/14 Python
python遍历序列enumerate函数浅析
2017/10/17 Python
Python3.6实现连接mysql或mariadb的方法分析
2018/05/18 Python
python线程定时器Timer实现原理解析
2019/11/30 Python
pycharm不能运行.py文件的解决方法
2020/02/12 Python
详解css position 5种不同的值的用法
2019/07/30 HTML / CSS
详解css3 flex弹性盒自动铺满写法
2020/09/17 HTML / CSS
HTML5 微格式和相关的属性名称
2010/02/10 HTML / CSS
CSS3 画基本图形,圆形、椭圆形、三角形等
2016/09/20 HTML / CSS
精彩自我鉴定
2014/01/16 职场文书
小公司融资,商业计划书的8切记
2019/07/15 职场文书
SQL实现LeetCode(180.连续的数字)
2021/08/04 MySQL
js 数组 fill() 填充方法
2021/11/02 Javascript
JavaScript选择器函数querySelector和querySelectorAll
2021/11/27 Javascript
css3带你实现3D转换效果
2022/02/24 HTML / CSS