Keras保存模型并载入模型继续训练的实现


Posted in Python onFebruary 20, 2021

我们以MNIST手写数字识别为例

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
 
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 创建模型,输入784个神经元,输出10个神经元
model = Sequential([
    Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')
  ])
 
# 定义优化器
sgd = SGD(lr=0.2)
 
# 定义优化器,loss function,训练过程中计算准确率
model.compile(
  optimizer = sgd,
  loss = 'mse',
  metrics=['accuracy'],
)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存模型
model.save('model.h5')  # HDF5文件,pip install h5py

Keras保存模型并载入模型继续训练的实现

Keras保存模型并载入模型继续训练的实现

载入初次训练的模型,再训练

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.models import load_model
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 载入模型
model = load_model('model.h5')
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=2)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存参数,载入参数
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
# 保存网络结构,载入网络结构
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string)
 
print(json_string)

关于compile和load_model()的使用顺序

这一段落主要是为了解决我们fit、evaluate、predict之前还是之后使用compile。想要弄明白,首先我们要清楚compile在程序中是做什么的?都做了什么?

compile做什么?

compile定义了loss function损失函数、optimizer优化器和metrics度量。它与权重无关,也就是说compile并不会影响权重,不会影响之前训练的问题。

如果我们要训练模型或者评估模型evaluate,则需要compile,因为训练要使用损失函数和优化器,评估要使用度量方法;如果我们要预测,则没有必要compile模型。

是否需要多次编译?

除非我们要更改其中之一:损失函数、优化器 / 学习率、度量

又或者我们加载了尚未编译的模型。或者您的加载/保存方法没有考虑以前的编译。

再次compile的后果?

如果再次编译模型,将会丢失优化器状态.

这意味着您的训练在开始时会受到一点影响,直到调整学习率,动量等为止。但是绝对不会对重量造成损害(除非您的初始学习率如此之大,以至于第一次训练步骤疯狂地更改微调的权重)。

到此这篇关于Keras保存模型并载入模型继续训练的实现的文章就介绍到这了,更多相关Keras保存模型并加载模型内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python将ip地址转换成整数的方法
Mar 17 Python
Python输出9*9乘法表的方法
May 25 Python
Python中with及contextlib的用法详解
Jun 08 Python
python 并发编程 阻塞IO模型原理解析
Aug 20 Python
pygame实现贪吃蛇游戏(下)
Oct 29 Python
Django使用消息提示简单的弹出个对话框实例
Nov 15 Python
Python的对象传递与Copy函数使用详解
Dec 26 Python
浅析python标准库中的glob
Mar 13 Python
借助Paramiko通过Python实现linux远程登陆及sftp的操作
Mar 16 Python
Python xpath表达式如何实现数据处理
Jun 13 Python
python爬虫用scrapy获取影片的实例分析
Nov 23 Python
matplotlib绘制正余弦曲线图的实现
Feb 22 Python
TensorFlow2.0使用keras训练模型的实现
Feb 20 #Python
tensorflow2.0教程之Keras快速入门
Feb 20 #Python
在Pycharm中安装Pandas库方法(简单易懂)
Feb 20 #Python
Python3爬虫RedisDump的安装步骤
Feb 20 #Python
python爬取2021猫眼票房字体加密实例
Feb 19 #Python
Python之Sklearn使用入门教程
Feb 19 #Python
Python爬虫UA伪装爬取的实例讲解
Feb 19 #Python
You might like
PHP 内存缓存加速功能memcached安装与用法
2009/09/03 PHP
PHP获取163、gmail、126等邮箱联系人地址【已测试2009.10.10】
2009/10/11 PHP
PHP 时间日期操作实战
2011/08/26 PHP
ajax在joomla中的原生态应用代码
2012/07/19 PHP
解决php表单重复提交实现方法
2015/09/29 PHP
WordPress中获取所使用的模板的页面ID的简单方法
2015/12/31 PHP
php自定义中文字符串截取函数substr_for_gb2312及substr_for_utf8示例
2016/05/28 PHP
javascript编程起步(第四课)
2007/02/27 Javascript
分享Javascript中最常用的55个经典小技巧
2013/11/29 Javascript
js以及jquery实现手风琴效果
2020/04/17 Javascript
jQuery实现使用sort方法对json数据排序的方法
2018/04/17 jQuery
node链接mongodb数据库的方法详解【阿里云服务器环境ubuntu】
2019/03/07 Javascript
详解vue-cli中使用rem,vue自适应
2019/05/06 Javascript
谈谈IntersectionObserver懒加载的具体使用
2019/10/15 Javascript
React 父子组件通信的实现方法
2019/12/05 Javascript
解决js中的setInterval清空定时器不管用问题
2020/11/17 Javascript
Python使用scrapy采集时伪装成HTTP/1.1的方法
2015/04/08 Python
详解在Python程序中使用Cookie的教程
2015/04/30 Python
一张图带我们入门Python基础教程
2017/02/05 Python
Python 装饰器实现DRY(不重复代码)原则
2018/03/05 Python
人生苦短我用python python如何快速入门?
2018/03/12 Python
情人节快乐! python绘制漂亮玫瑰
2020/08/18 Python
python字符串循环左移
2019/03/08 Python
python中正则表达式与模式匹配
2019/05/07 Python
python队列Queue的详解
2019/05/10 Python
ubuntu 16.04下python版本切换的方法
2019/06/14 Python
将HTML5 Canvas的内容保存为图片借助toDataURL实现
2013/05/20 HTML / CSS
法国最大电子商务平台:Cdiscount
2018/03/13 全球购物
中国领先的汽车保养服务平台:途虎养车
2019/10/18 全球购物
沙特阿拉伯家用电器和电子产品购物网站:Sheta and Saif
2020/04/03 全球购物
Nobody Denim官网:购买高级女士牛仔裤
2021/03/15 全球购物
c/c++某大公司的两道笔试题
2014/02/02 面试题
高中语文教学反思
2014/01/16 职场文书
2014年最新离婚协议书范本
2014/10/11 职场文书
致百米运动员广播稿5篇
2014/10/13 职场文书
家访教师心得体会
2016/01/23 职场文书