tensorflow2.0保存和恢复模型3种方法


Posted in Python onFebruary 03, 2020

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(
 filepath,
 overwrite=True,
 save_format=None
)

Arguments:

filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(
 filepath,
 by_name=False
)

实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')
 
# step7 删除模型
del model
 
# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')
 
# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(
 filepath,
 overwrite=True,
 include_optimizer=True,
 save_format=None
)

Arguments:

filepath: String, path to SavedModel or H5 file to save the model.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

include_optimizer: If True, save optimizer's state together.

save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
 
# step7 删除模型
del model # deletes the existing model
 
 
# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')
 
# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中list循环语句用法实例
Nov 10 Python
在Python中操作时间之mktime()方法的使用教程
May 22 Python
Python3使用requests发闪存的方法
May 11 Python
Tensorflow卷积神经网络实例
May 24 Python
Python使用numpy模块创建数组操作示例
Jun 20 Python
pandas 读取各种格式文件的方法
Jun 22 Python
python3.8下载及安装步骤详解
Jan 15 Python
Python函数式编程实例详解
Jan 17 Python
tensorflow指定GPU与动态分配GPU memory设置
Feb 03 Python
python和js交互调用的方法
Jun 23 Python
利用Pycharm连接服务器的全过程记录
Jul 01 Python
python区块链实现简版工作量证明
May 25 Python
详解字符串在Python内部是如何省内存的
Feb 03 #Python
python自动化unittest yaml使用过程解析
Feb 03 #Python
Python类如何定义私有变量
Feb 03 #Python
python异常处理try except过程解析
Feb 03 #Python
利用Python脚本实现自动刷网课
Feb 03 #Python
tensorflow 限制显存大小的实现
Feb 03 #Python
基于tensorflow指定GPU运行及GPU资源分配的几种方式小结
Feb 03 #Python
You might like
php mysql索引问题
2008/06/07 PHP
在SAE上搭建最新wordpress的方法
2014/12/21 PHP
两种php去除二维数组的重复项方法
2015/11/04 PHP
如何使用php等比例缩放图片
2016/10/12 PHP
laravel高级的Join语法详解以及使用Join多个条件
2019/10/16 PHP
使用javascript访问XML数据的实例
2006/12/27 Javascript
用ASP将SQL搜索出来的内容导出为TXT的代码
2007/07/27 Javascript
Javascript 更新 JavaScript 数组的 uniq 方法
2008/01/23 Javascript
JavaScript flash复制库类 Zero Clipboard
2011/01/17 Javascript
利用javascript数组长度循环数组内所有元素
2013/12/27 Javascript
jQuery中:animated选择器用法实例
2014/12/29 Javascript
jQuery选择器querySelector的使用指南
2015/01/23 Javascript
javascript弹出拖动窗口
2015/08/11 Javascript
webpack+vue.js实现组件化详解
2016/10/12 Javascript
jquery实现图片平滑滚动详解
2017/03/22 jQuery
细说webpack源码之compile流程-rules参数处理技巧(2)
2017/12/26 Javascript
[01:02]DOTA2辉夜杯决赛日 CDEC.Y对阵VG赛前花絮
2015/12/27 DOTA
[50:24]VGJ.S vs Pain 2018国际邀请赛小组赛BO2 第二场 8.17
2018/08/20 DOTA
分析在Python中何种情况下需要使用断言
2015/04/01 Python
Python的Django框架中的数据过滤功能
2015/07/17 Python
Python处理CSV与List的转换方法
2018/04/19 Python
python中reader的next用法
2018/07/24 Python
pyqt远程批量执行Linux命令程序的方法
2019/02/14 Python
使用Python正则表达式操作文本数据的方法
2019/05/14 Python
基于Python实现ComicReaper漫画自动爬取脚本过程解析
2019/11/11 Python
使用python实现希尔、计数、基数基础排序的代码
2019/12/25 Python
pygame用blit()实现动画效果的示例代码
2020/05/28 Python
详解Python 循环嵌套
2020/07/09 Python
SmartBuyGlasses意大利:购买太阳镜、眼镜和隐形眼镜
2018/11/20 全球购物
如何启动时不需输入用户名与密码
2014/05/09 面试题
群众路线教育党课主持词
2014/04/01 职场文书
市场部经理岗位职责
2014/04/10 职场文书
小学班级口号
2014/06/09 职场文书
房屋维修申请报告
2015/05/18 职场文书
2016年教师党员承诺书范文
2016/03/24 职场文书
win10识别不了U盘怎么办 win10系统读取U盘失败的解决办法
2022/08/05 数码科技