tensorflow2.0保存和恢复模型3种方法


Posted in Python onFebruary 03, 2020

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(
 filepath,
 overwrite=True,
 save_format=None
)

Arguments:

filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(
 filepath,
 by_name=False
)

实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')
 
# step7 删除模型
del model
 
# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')
 
# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(
 filepath,
 overwrite=True,
 include_optimizer=True,
 save_format=None
)

Arguments:

filepath: String, path to SavedModel or H5 file to save the model.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

include_optimizer: If True, save optimizer's state together.

save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
 
# step7 删除模型
del model # deletes the existing model
 
 
# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')
 
# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python解析中国天气网的天气数据
Mar 21 Python
Python实现简单登录验证
Apr 13 Python
Python学习小技巧之列表项的拼接
May 20 Python
基于Django filter中用contains和icontains的区别(详解)
Dec 12 Python
Python实现的井字棋(Tic Tac Toe)游戏示例
Jan 31 Python
基于python 爬虫爬到含空格的url的处理方法
May 11 Python
python3 实现一行输入,空格隔开的示例
Nov 14 Python
Django模板导入母版继承和自定义返回Html片段过程解析
Sep 18 Python
50行Python代码实现视频中物体颜色识别和跟踪(必须以红色为例)
Nov 20 Python
Python监控服务器实用工具psutil使用解析
Dec 19 Python
详解python定时简单爬取网页新闻存入数据库并发送邮件
Nov 27 Python
只需要100行Python代码就可以实现的贪吃蛇小游戏
May 27 Python
详解字符串在Python内部是如何省内存的
Feb 03 #Python
python自动化unittest yaml使用过程解析
Feb 03 #Python
Python类如何定义私有变量
Feb 03 #Python
python异常处理try except过程解析
Feb 03 #Python
利用Python脚本实现自动刷网课
Feb 03 #Python
tensorflow 限制显存大小的实现
Feb 03 #Python
基于tensorflow指定GPU运行及GPU资源分配的几种方式小结
Feb 03 #Python
You might like
PHP 裁剪图片成固定大小代码方法
2009/09/09 PHP
php 数组的一个悲剧?
2011/05/11 PHP
PHP闭包(Closure)使用详解
2013/05/02 PHP
php5.2以下版本无json_decode函数的解决方法
2014/05/25 PHP
配置eAccelerator和XCache扩展来加速PHP程序的执行
2015/12/22 PHP
Zend Framework基本页面布局分析
2016/03/19 PHP
使用 laravel sms 构建短信验证码发送校验功能
2017/11/06 PHP
强悍无比的WEB开发好助手FireBug(Firefox Plugin)
2007/01/16 Javascript
基于jquery+thickbox仿校内登录注册框
2010/06/07 Javascript
jQuery效果 slideToggle() 方法(在隐藏和显示之间切换)
2011/06/28 Javascript
JavaScript XML和string相互转化实现代码
2011/07/04 Javascript
jQuery图片预加载 等比缩放实现代码
2011/10/04 Javascript
JS实现定时自动关闭DIV层提示框的方法
2015/05/11 Javascript
JavaScript中数组的合并以及排序实现示例
2015/10/24 Javascript
跟我学习javascript的执行上下文
2015/11/18 Javascript
javascript获取系统当前时间的方法
2015/11/19 Javascript
JS实现n秒后自动跳转的两种方法
2020/11/30 Javascript
jQuery通用的全局遍历方法$.each()用法实例
2016/07/04 Javascript
15款最好的Bootstrap在线编辑器
2016/08/03 Javascript
JS实现加载和读取XML文件的方法详解
2017/04/24 Javascript
vue中格式化时间过滤器代码实例
2019/04/17 Javascript
vue+web端仿微信网页版聊天室功能
2019/04/30 Javascript
使用layui前端框架弹出form表单以及提交的示例
2019/10/25 Javascript
Vue实现点击导航栏当前标签后变色功能
2020/08/19 Javascript
[46:49]完美世界DOTA2联赛PWL S3 access vs Rebirth 第二场 12.19
2020/12/24 DOTA
python+ffmpeg视频并发直播压力测试
2018/03/06 Python
python 表格打印代码实例解析
2019/10/12 Python
Pandas操作CSV文件的读写实现方法
2019/11/13 Python
Restful_framework视图组件代码实例解析
2020/11/17 Python
Python Process创建进程的2种方法详解
2021/01/25 Python
英国领先的露营和露营车品牌之一:OLPRO
2019/08/06 全球购物
世界上最大的乐谱选择:Sheet Music Plus
2020/01/18 全球购物
高考标语大全
2014/06/05 职场文书
大学竞选班干部演讲稿
2014/08/21 职场文书
浅谈Vue的computed计算属性
2022/03/21 Vue.js
win11电脑关机鼠标灯还亮怎么解决? win11关机后鼠标灯还亮解决方法
2023/01/09 数码科技