Keras保存模型并载入模型继续训练的实现


Posted in Python onFebruary 20, 2021

我们以MNIST手写数字识别为例

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
 
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 创建模型,输入784个神经元,输出10个神经元
model = Sequential([
    Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')
  ])
 
# 定义优化器
sgd = SGD(lr=0.2)
 
# 定义优化器,loss function,训练过程中计算准确率
model.compile(
  optimizer = sgd,
  loss = 'mse',
  metrics=['accuracy'],
)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存模型
model.save('model.h5')  # HDF5文件,pip install h5py

Keras保存模型并载入模型继续训练的实现

Keras保存模型并载入模型继续训练的实现

载入初次训练的模型,再训练

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.models import load_model
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 载入模型
model = load_model('model.h5')
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=2)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存参数,载入参数
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
# 保存网络结构,载入网络结构
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string)
 
print(json_string)

关于compile和load_model()的使用顺序

这一段落主要是为了解决我们fit、evaluate、predict之前还是之后使用compile。想要弄明白,首先我们要清楚compile在程序中是做什么的?都做了什么?

compile做什么?

compile定义了loss function损失函数、optimizer优化器和metrics度量。它与权重无关,也就是说compile并不会影响权重,不会影响之前训练的问题。

如果我们要训练模型或者评估模型evaluate,则需要compile,因为训练要使用损失函数和优化器,评估要使用度量方法;如果我们要预测,则没有必要compile模型。

是否需要多次编译?

除非我们要更改其中之一:损失函数、优化器 / 学习率、度量

又或者我们加载了尚未编译的模型。或者您的加载/保存方法没有考虑以前的编译。

再次compile的后果?

如果再次编译模型,将会丢失优化器状态.

这意味着您的训练在开始时会受到一点影响,直到调整学习率,动量等为止。但是绝对不会对重量造成损害(除非您的初始学习率如此之大,以至于第一次训练步骤疯狂地更改微调的权重)。

到此这篇关于Keras保存模型并载入模型继续训练的实现的文章就介绍到这了,更多相关Keras保存模型并加载模型内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
简单介绍利用TK在Python下进行GUI编程的教程
Apr 13 Python
Python中asyncore异步模块的用法及实现httpclient的实例
Jun 28 Python
python中子类继承父类的__init__方法实例
Dec 15 Python
python使用opencv在Windows下调用摄像头实现解析
Nov 26 Python
Python计算指定日期是今年的第几天(三种方法)
Mar 26 Python
Pycharm插件(Grep Console)自定义规则输出颜色日志的方法
May 27 Python
Python爬虫如何应对Cloudflare邮箱加密
Jun 24 Python
python脚本使用阿里云slb对恶意攻击进行封堵的实现
Feb 04 Python
python实现b站直播自动发送弹幕功能
Feb 20 Python
详解Python常用的魔法方法
Jun 03 Python
Python Pandas模块实现数据的统计分析的方法
Jun 24 Python
Python中的tkinter库简单案例详解
Jan 22 Python
TensorFlow2.0使用keras训练模型的实现
Feb 20 #Python
tensorflow2.0教程之Keras快速入门
Feb 20 #Python
在Pycharm中安装Pandas库方法(简单易懂)
Feb 20 #Python
Python3爬虫RedisDump的安装步骤
Feb 20 #Python
python爬取2021猫眼票房字体加密实例
Feb 19 #Python
Python之Sklearn使用入门教程
Feb 19 #Python
Python爬虫UA伪装爬取的实例讲解
Feb 19 #Python
You might like
PHP显示今天、今月、上月、今年的起点/终点时间戳的代码
2011/05/25 PHP
php pthreads多线程的安装与使用
2016/01/19 PHP
Laravel 5.4.36中session没有保存成功问题的解决
2018/02/19 PHP
js返回上一页并刷新代码整理
2012/12/21 Javascript
jQuery实现选中弹出窗口选择框内容后赋值给文本框的方法
2015/11/23 Javascript
盘点javascript 正则表达式中 中括号的【坑】
2016/03/16 Javascript
jQuery 操作input中radio的技巧
2016/07/18 Javascript
使用node.js搭建服务器
2017/05/20 Javascript
基于vue 实现token验证的实例代码
2017/12/14 Javascript
jquery实现左右轮播切换效果
2018/01/01 jQuery
javascript原生封装一个淡入淡出效果的函数测试实例代码
2018/03/19 Javascript
微信小程序用户拒绝授权的处理方法详解
2019/09/20 Javascript
JS使用H5实现图片预览功能
2019/09/30 Javascript
谈谈JavaScript中的函数
2020/09/08 Javascript
Python基本数据结构与用法详解【列表、元组、集合、字典】
2019/03/23 Python
django框架使用orm实现批量更新数据的方法
2019/06/21 Python
int在python中的含义以及用法
2019/06/27 Python
Python读取配置文件(config.ini)以及写入配置文件
2020/04/08 Python
Python日志logging模块功能与用法详解
2020/04/09 Python
python3用PyPDF2解析pdf文件,用正则匹配数据方式
2020/05/12 Python
Alba Moda德国网上商店:意大利时尚女装销售
2016/11/14 全球购物
FORZIERI福喜利中国官网:奢侈品购物梦工厂
2019/05/03 全球购物
采购内勤岗位职责
2013/12/10 职场文书
新法人代表任命书
2014/06/06 职场文书
五月的鲜花活动方案
2014/08/21 职场文书
优秀教师单行材料
2014/12/16 职场文书
经费申请报告范文
2015/05/18 职场文书
亮剑观后感600字
2015/06/05 职场文书
回门宴新娘答谢词
2015/09/29 职场文书
幼师自荐信范文(2016推荐篇)
2016/01/28 职场文书
2016年万圣节活动个人总结
2016/04/05 职场文书
导游词之丽江普济寺
2019/10/22 职场文书
Python实现文本文件拆分写入到多个文本文件的方法
2021/04/18 Python
教你使用Pandas直接核算Excel中快递费用
2021/05/12 Python
如何理解及使用Python闭包
2021/06/01 Python
MongoDB orm框架的注意事项及简单使用
2021/06/20 MongoDB