tensorflow2.0保存和恢复模型3种方法


Posted in Python onFebruary 03, 2020

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(
 filepath,
 overwrite=True,
 save_format=None
)

Arguments:

filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(
 filepath,
 by_name=False
)

实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')
 
# step7 删除模型
del model
 
# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')
 
# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(
 filepath,
 overwrite=True,
 include_optimizer=True,
 save_format=None
)

Arguments:

filepath: String, path to SavedModel or H5 file to save the model.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

include_optimizer: If True, save optimizer's state together.

save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
 
 
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
 
 
# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()
 
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])
 
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )
 
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
 
# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
 
# step7 删除模型
del model # deletes the existing model
 
 
# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')
 
# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用PyHook监听鼠标和键盘事件实例
Jul 18 Python
wxPython事件驱动实例详解
Sep 28 Python
简单介绍Python的轻便web框架Bottle
Apr 08 Python
Pandas 对Dataframe结构排序的实现方法
Apr 10 Python
Scrapy框架使用的基本知识
Oct 21 Python
Python函数返回不定数量的值方法
Jan 22 Python
Django-Model数据库操作(增删改查、连表结构)详解
Jul 17 Python
Python完成哈夫曼树编码过程及原理详解
Jul 29 Python
python hashlib加密实现代码
Oct 17 Python
python队列原理及实现方法示例
Nov 27 Python
python输出数组中指定元素的所有索引示例
Dec 06 Python
详解Tensorflow不同版本要求与CUDA及CUDNN版本对应关系
Aug 04 Python
详解字符串在Python内部是如何省内存的
Feb 03 #Python
python自动化unittest yaml使用过程解析
Feb 03 #Python
Python类如何定义私有变量
Feb 03 #Python
python异常处理try except过程解析
Feb 03 #Python
利用Python脚本实现自动刷网课
Feb 03 #Python
tensorflow 限制显存大小的实现
Feb 03 #Python
基于tensorflow指定GPU运行及GPU资源分配的几种方式小结
Feb 03 #Python
You might like
成本8450万,票房仅2亿,口碑两极分化,又一部DC电影扑街了
2020/04/09 欧美动漫
php生成年月日下载列表的方法
2015/04/24 PHP
微信支付开发维权通知实例
2016/07/12 PHP
详解PHP实现定时任务的五种方法
2016/07/25 PHP
PHP生成及获取JSON文件的方法
2016/08/23 PHP
PHP中的函数声明与使用详解
2017/05/27 PHP
PHP实现的mysql主从数据库状态检测功能示例
2017/07/20 PHP
禁止F5等快捷键的JS代码
2007/03/06 Javascript
Javascript String.replace的妙用
2009/09/08 Javascript
Extjs 继承Ext.data.Store不起作用原因分析及解决
2013/04/15 Javascript
JQuery实现防止退格键返回的方法
2015/02/12 Javascript
文字垂直滚动之javascript代码
2015/07/29 Javascript
js获取浏览器高度 窗口高度 元素尺寸 偏移属性的方法
2016/11/21 Javascript
原生js实现类似fullpage的单页/全屏滚动
2017/01/22 Javascript
windows下vue-cli及webpack搭建安装环境
2017/04/25 Javascript
Angular中$broadcast和$emit的使用方法详解
2017/05/22 Javascript
微信小程序如何获取用户信息
2018/01/26 Javascript
Webpack中雪碧图插件使用详解
2018/05/25 Javascript
详解从NodeJS搭建中间层再谈前后端分离
2018/11/13 NodeJs
vue使用Font Awesome的方法步骤
2019/02/26 Javascript
详解vuex数据传输的两种方式及this.$store undefined的解决办法
2019/08/26 Javascript
windows实现npm和cnpm安装步骤
2019/10/24 Javascript
在Vue中使用mockjs代码实例
2020/11/25 Vue.js
为python设置socket代理的方法
2015/01/14 Python
python连接MySQL数据库实例分析
2015/05/12 Python
实例讲解Python中SocketServer模块处理网络请求的用法
2016/06/28 Python
详解flask入门模板引擎
2018/07/18 Python
Python Numpy:找到list中的np.nan值方法
2018/10/30 Python
欧洲第一中国智能手机和平板电脑网上商店:CECT-SHOP
2018/01/08 全球购物
进程的查看和调度分别使用什么命令
2015/03/25 面试题
妇女干部培训方案
2014/05/12 职场文书
基层工作经验证明样本
2014/11/16 职场文书
护士年终考核评语
2014/12/31 职场文书
昆虫记读书笔记
2015/06/26 职场文书
JVM上高性能数据格式库包Apache Arrow入门和架构详解(Gkatziouras)
2021/05/26 Servers
MySQL中order by的使用详情
2021/11/17 MySQL