Keras保存模型并载入模型继续训练的实现


Posted in Python onFebruary 20, 2021

我们以MNIST手写数字识别为例

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
 
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 创建模型,输入784个神经元,输出10个神经元
model = Sequential([
    Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')
  ])
 
# 定义优化器
sgd = SGD(lr=0.2)
 
# 定义优化器,loss function,训练过程中计算准确率
model.compile(
  optimizer = sgd,
  loss = 'mse',
  metrics=['accuracy'],
)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存模型
model.save('model.h5')  # HDF5文件,pip install h5py

Keras保存模型并载入模型继续训练的实现

Keras保存模型并载入模型继续训练的实现

载入初次训练的模型,再训练

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.models import load_model
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 载入模型
model = load_model('model.h5')
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=2)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存参数,载入参数
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
# 保存网络结构,载入网络结构
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string)
 
print(json_string)

关于compile和load_model()的使用顺序

这一段落主要是为了解决我们fit、evaluate、predict之前还是之后使用compile。想要弄明白,首先我们要清楚compile在程序中是做什么的?都做了什么?

compile做什么?

compile定义了loss function损失函数、optimizer优化器和metrics度量。它与权重无关,也就是说compile并不会影响权重,不会影响之前训练的问题。

如果我们要训练模型或者评估模型evaluate,则需要compile,因为训练要使用损失函数和优化器,评估要使用度量方法;如果我们要预测,则没有必要compile模型。

是否需要多次编译?

除非我们要更改其中之一:损失函数、优化器 / 学习率、度量

又或者我们加载了尚未编译的模型。或者您的加载/保存方法没有考虑以前的编译。

再次compile的后果?

如果再次编译模型,将会丢失优化器状态.

这意味着您的训练在开始时会受到一点影响,直到调整学习率,动量等为止。但是绝对不会对重量造成损害(除非您的初始学习率如此之大,以至于第一次训练步骤疯狂地更改微调的权重)。

到此这篇关于Keras保存模型并载入模型继续训练的实现的文章就介绍到这了,更多相关Keras保存模型并加载模型内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python基础教程之自定义函数介绍
Aug 29 Python
Python中函数eval和ast.literal_eval的区别详解
Aug 10 Python
一文总结学习Python的14张思维导图
Oct 17 Python
python导入csv文件出现SyntaxError问题分析
Dec 15 Python
python实现寻找最长回文子序列的方法
Jun 02 Python
Python高级编程之继承问题详解(super与mro)
Nov 19 Python
基于SpringBoot构造器注入循环依赖及解决方式
Apr 26 Python
Django ORM实现按天获取数据去重求和例子
May 18 Python
python变量的作用域是什么
May 26 Python
Python基于正则表达式实现计算器功能
Jul 13 Python
利用python如何实现猫捉老鼠小游戏
Dec 04 Python
Pytest之测试命名规则的使用
Apr 16 Python
TensorFlow2.0使用keras训练模型的实现
Feb 20 #Python
tensorflow2.0教程之Keras快速入门
Feb 20 #Python
在Pycharm中安装Pandas库方法(简单易懂)
Feb 20 #Python
Python3爬虫RedisDump的安装步骤
Feb 20 #Python
python爬取2021猫眼票房字体加密实例
Feb 19 #Python
Python之Sklearn使用入门教程
Feb 19 #Python
Python爬虫UA伪装爬取的实例讲解
Feb 19 #Python
You might like
采用header定义为文件然后readfile下载(隐藏下载地址)
2014/01/31 PHP
些很实用且必用的小脚本代码
2006/06/26 Javascript
jQuery中Form相关知识汇总
2015/01/06 Javascript
javascript中几个容易混淆的概念总结
2015/04/14 Javascript
js实现延时加载Flash的方法
2015/11/26 Javascript
jquery 点击元素后,滚动条滚动至该元素位置的方法
2016/08/05 Javascript
Vue弹出菜单功能的实现代码
2018/09/12 Javascript
在Vue项目中使用Typescript的实现
2019/12/19 Javascript
es6 super关键字的理解与应用实例分析
2020/02/15 Javascript
微信小程序实现点击导航标签滚动定位到对应位置
2020/11/19 Javascript
python下函数参数的传递(参数带星号的说明)
2010/09/19 Python
python+django加载静态网页模板解析
2017/12/12 Python
python机器学习实战之树回归详解
2017/12/20 Python
对Tensorflow中的变量初始化函数详解
2018/07/27 Python
使用Selenium破解新浪微博的四宫格验证码
2018/10/19 Python
pyqt弹出新对话框,以及关闭对话框获取数据的实例
2019/06/18 Python
python实现人工蜂群算法
2020/09/18 Python
Python在centos7.6上安装python3.9的详细教程(默认python版本为2.7.5)
2020/10/15 Python
pycharm + django跨域无提示的解决方法
2020/12/06 Python
如何利用python生成MD5并去重
2020/12/07 Python
python 如何把docker-compose.yaml导入到数据库相关条目里
2021/01/15 Python
与UNIX有关的几个名词
2015/09/17 面试题
体育专业学生自我评价范文
2014/01/17 职场文书
优秀党员主要事迹
2014/01/19 职场文书
会计专业自我鉴定
2014/02/10 职场文书
承诺书怎么写
2014/03/26 职场文书
终止劳动合同协议书
2014/04/14 职场文书
电子商务系毕业生自荐信
2014/05/29 职场文书
学校2014重阳节活动策划方案
2014/09/16 职场文书
2014最新预备党员思想汇报范文:中国梦,我的梦
2014/10/25 职场文书
学习党章的体会
2014/11/07 职场文书
解除劳动合同证明书模板
2014/11/20 职场文书
党员个人总结自评
2015/02/14 职场文书
小孩不笨观后感
2015/06/03 职场文书
超市主管竞聘书
2015/09/15 职场文书
TV动画《八十龟酱观察日记》第四季宣传PV公布
2022/04/06 日漫