Keras保存模型并载入模型继续训练的实现


Posted in Python onFebruary 20, 2021

我们以MNIST手写数字识别为例

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
 
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 创建模型,输入784个神经元,输出10个神经元
model = Sequential([
    Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')
  ])
 
# 定义优化器
sgd = SGD(lr=0.2)
 
# 定义优化器,loss function,训练过程中计算准确率
model.compile(
  optimizer = sgd,
  loss = 'mse',
  metrics=['accuracy'],
)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存模型
model.save('model.h5')  # HDF5文件,pip install h5py

Keras保存模型并载入模型继续训练的实现

Keras保存模型并载入模型继续训练的实现

载入初次训练的模型,再训练

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.models import load_model
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 载入模型
model = load_model('model.h5')
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=2)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存参数,载入参数
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
# 保存网络结构,载入网络结构
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string)
 
print(json_string)

关于compile和load_model()的使用顺序

这一段落主要是为了解决我们fit、evaluate、predict之前还是之后使用compile。想要弄明白,首先我们要清楚compile在程序中是做什么的?都做了什么?

compile做什么?

compile定义了loss function损失函数、optimizer优化器和metrics度量。它与权重无关,也就是说compile并不会影响权重,不会影响之前训练的问题。

如果我们要训练模型或者评估模型evaluate,则需要compile,因为训练要使用损失函数和优化器,评估要使用度量方法;如果我们要预测,则没有必要compile模型。

是否需要多次编译?

除非我们要更改其中之一:损失函数、优化器 / 学习率、度量

又或者我们加载了尚未编译的模型。或者您的加载/保存方法没有考虑以前的编译。

再次compile的后果?

如果再次编译模型,将会丢失优化器状态.

这意味着您的训练在开始时会受到一点影响,直到调整学习率,动量等为止。但是绝对不会对重量造成损害(除非您的初始学习率如此之大,以至于第一次训练步骤疯狂地更改微调的权重)。

到此这篇关于Keras保存模型并载入模型继续训练的实现的文章就介绍到这了,更多相关Keras保存模型并加载模型内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
一个计算身份证号码校验位的Python小程序
Aug 15 Python
python中的多重继承实例讲解
Sep 28 Python
Python自动发邮件脚本
Mar 31 Python
使用python实现个性化词云的方法
Jun 16 Python
python学习教程之Numpy和Pandas的使用
Sep 11 Python
python3实现二叉树的遍历与递归算法解析(小结)
Jul 03 Python
Python使用Pandas库实现MySQL数据库的读写
Jul 06 Python
python UDP(udp)协议发送和接收的实例
Jul 22 Python
python os.fork() 循环输出方法
Aug 08 Python
Python的垃圾回收机制详解
Aug 28 Python
python 比较字典value的最大值的几种方法
Apr 17 Python
python3排序的实例方法
Oct 20 Python
TensorFlow2.0使用keras训练模型的实现
Feb 20 #Python
tensorflow2.0教程之Keras快速入门
Feb 20 #Python
在Pycharm中安装Pandas库方法(简单易懂)
Feb 20 #Python
Python3爬虫RedisDump的安装步骤
Feb 20 #Python
python爬取2021猫眼票房字体加密实例
Feb 19 #Python
Python之Sklearn使用入门教程
Feb 19 #Python
Python爬虫UA伪装爬取的实例讲解
Feb 19 #Python
You might like
php下正则来匹配dede模板标签的代码
2010/08/21 PHP
javascript:FF/Chrome与IE动态加载元素的区别说明
2014/01/26 Javascript
jQuery on()方法使用技巧详解
2015/04/16 Javascript
百度地图API之百度地图退拽标记点获取经纬度的实现代码
2017/01/12 Javascript
jQuery Ajax前后端使用JSON进行交互示例
2017/03/17 Javascript
Angular2学习教程之TemplateRef和ViewContainerRef详解
2017/05/25 Javascript
Angular实现的table表格排序功能完整示例
2017/12/22 Javascript
vue-cli脚手架config目录下index.js配置文件的方法
2018/03/13 Javascript
Vue无限滑动周选择日期的组件的示例代码
2018/07/18 Javascript
Vue axios全局拦截 get请求、post请求、配置请求的实例代码
2018/11/28 Javascript
对Layer弹窗使用及返回数据接收的实例详解
2019/09/26 Javascript
JavaScript设计模型Iterator实例解析
2020/01/22 Javascript
[09:13]2014DOTA2国际邀请赛 中国区预选赛coser表演
2014/05/23 DOTA
Python制作简易注册登录系统
2016/12/15 Python
Python实现公历(阳历)转农历(阴历)的方法示例
2017/08/22 Python
Python实现朴素贝叶斯分类器的方法详解
2018/07/04 Python
使用python接入微信聊天机器人
2020/03/31 Python
用Python徒手撸一个股票回测框架搭建【推荐】
2019/08/05 Python
python 实现让字典的value 成为列表
2019/12/16 Python
windows10 pycharm下安装pyltp库和加载模型实现语义角色标注的示例代码
2020/05/07 Python
澳大利亚自然和有机的健康美容产品一站式商店:Ziani Beauty
2017/12/28 全球购物
Puma印度官网:德国运动品牌
2019/10/06 全球购物
俄语地区最大的中国商品在线购物网站之一:Umka Mall
2019/11/03 全球购物
乌克兰品牌化妆品和香水在线商店:Bomond
2020/01/14 全球购物
个人自我鉴定范文
2013/10/04 职场文书
我的中国梦演讲稿初中篇
2014/08/19 职场文书
小学生推普周国旗下讲话稿
2014/09/21 职场文书
学习党的群众路线教育实践活动剖析材料
2014/10/13 职场文书
学院党委班子四风问题自查报告及整改措施
2014/10/25 职场文书
小学生毕业评语
2014/12/26 职场文书
2015年度个人思想工作总结
2015/04/08 职场文书
售后服务质量承诺书
2015/04/29 职场文书
劳动模范获奖感言
2015/07/31 职场文书
2019升学宴主持词范本5篇
2019/10/09 职场文书
JS中如何优雅的使用async await详解
2021/10/05 Javascript
Mysql忘记密码解决方法
2022/02/12 MySQL