Keras保存模型并载入模型继续训练的实现


Posted in Python onFebruary 20, 2021

我们以MNIST手写数字识别为例

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
 
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 创建模型,输入784个神经元,输出10个神经元
model = Sequential([
    Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')
  ])
 
# 定义优化器
sgd = SGD(lr=0.2)
 
# 定义优化器,loss function,训练过程中计算准确率
model.compile(
  optimizer = sgd,
  loss = 'mse',
  metrics=['accuracy'],
)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存模型
model.save('model.h5')  # HDF5文件,pip install h5py

Keras保存模型并载入模型继续训练的实现

Keras保存模型并载入模型继续训练的实现

载入初次训练的模型,再训练

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.models import load_model
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 载入模型
model = load_model('model.h5')
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=2)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存参数,载入参数
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
# 保存网络结构,载入网络结构
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string)
 
print(json_string)

关于compile和load_model()的使用顺序

这一段落主要是为了解决我们fit、evaluate、predict之前还是之后使用compile。想要弄明白,首先我们要清楚compile在程序中是做什么的?都做了什么?

compile做什么?

compile定义了loss function损失函数、optimizer优化器和metrics度量。它与权重无关,也就是说compile并不会影响权重,不会影响之前训练的问题。

如果我们要训练模型或者评估模型evaluate,则需要compile,因为训练要使用损失函数和优化器,评估要使用度量方法;如果我们要预测,则没有必要compile模型。

是否需要多次编译?

除非我们要更改其中之一:损失函数、优化器 / 学习率、度量

又或者我们加载了尚未编译的模型。或者您的加载/保存方法没有考虑以前的编译。

再次compile的后果?

如果再次编译模型,将会丢失优化器状态.

这意味着您的训练在开始时会受到一点影响,直到调整学习率,动量等为止。但是绝对不会对重量造成损害(除非您的初始学习率如此之大,以至于第一次训练步骤疯狂地更改微调的权重)。

到此这篇关于Keras保存模型并载入模型继续训练的实现的文章就介绍到这了,更多相关Keras保存模型并加载模型内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
使用python分析git log日志示例
Feb 27 Python
python网络编程之TCP通信实例和socketserver框架使用例子
Apr 25 Python
在Django框架中设置语言偏好的教程
Jul 27 Python
利用Python命令行传递实例化对象的方法
Nov 02 Python
Python实现爬虫设置代理IP和伪装成浏览器的方法分享
May 07 Python
jenkins配置python脚本定时任务过程图解
Oct 29 Python
Python FTP文件定时自动下载实现过程解析
Nov 12 Python
PyCharm 2020.1版安装破解注册码永久激活(激活到2089年)
Sep 24 Python
Python通过递归函数输出嵌套列表元素
Oct 15 Python
Python中全局变量和局部变量的理解与区别
Feb 07 Python
告别网页搜索!教你用python实现一款属于自己的翻译词典软件
Jun 03 Python
Pandas 稀疏数据结构的实现
Jul 25 Python
TensorFlow2.0使用keras训练模型的实现
Feb 20 #Python
tensorflow2.0教程之Keras快速入门
Feb 20 #Python
在Pycharm中安装Pandas库方法(简单易懂)
Feb 20 #Python
Python3爬虫RedisDump的安装步骤
Feb 20 #Python
python爬取2021猫眼票房字体加密实例
Feb 19 #Python
Python之Sklearn使用入门教程
Feb 19 #Python
Python爬虫UA伪装爬取的实例讲解
Feb 19 #Python
You might like
php session_start()出错原因分析及解决方法
2013/10/28 PHP
thinkPHP商城公告功能开发问题分析
2016/12/01 PHP
在laravel-admin中列表中禁止某行编辑、删除的方法
2019/10/03 PHP
Laravel validate error处理,ajax,json示例
2019/10/25 PHP
javascript 另一种图片滚动切换效果思路
2012/04/20 Javascript
js实现网页随机切换背景图片的方法
2014/11/01 Javascript
手机浏览器 后退按钮强制刷新页面方法总结
2016/10/09 Javascript
如何使用jquery实现文字上下滚动效果
2016/10/12 Javascript
Angular2入门--架构总览
2017/03/29 Javascript
关于bootstrap日期转化,bootstrap-editable的简单使用,bootstrap-fileinput的使用详解
2017/05/12 Javascript
详解React开发必不可少的eslint配置
2018/02/05 Javascript
JavaScript 五大常见函数
2018/03/23 Javascript
实例讲解JavaScript预编译流程
2019/01/24 Javascript
Vue Cli 3项目使用融云IM实现聊天功能的方法
2019/04/19 Javascript
详解vue的双向绑定原理及实现
2019/05/05 Javascript
详解基于Vue/React项目的移动端适配方案
2019/08/23 Javascript
JavaScript数组去重实现方法小结
2020/01/17 Javascript
Vue中keep-alive组件作用详解
2020/02/04 Javascript
Python深度优先算法生成迷宫
2018/01/22 Python
pandas把dataframe转成Series,改变列中值的类型方法
2018/04/10 Python
浅析python中numpy包中的argsort函数的使用
2018/08/30 Python
详解python分布式进程
2018/10/08 Python
python argparse传入布尔参数false不生效的解决
2020/04/20 Python
python多进程使用函数封装实例
2020/05/02 Python
python模拟点击在ios中实现的实例讲解
2020/11/26 Python
详解Canvas 跨域脱坑实践
2018/11/07 HTML / CSS
美国眼镜网站:EyeBuyDirect
2017/04/13 全球购物
地球鞋加拿大官网:Earth Shoes Canada
2020/11/17 全球购物
如果NULL和0作为空指针常数是等价的,那我到底该用哪一个
2014/09/16 面试题
用Python写一个for循环的例子
2016/07/19 面试题
六年级数学教学反思
2014/02/03 职场文书
大学生军训感想
2014/02/16 职场文书
最新离婚协议书范本
2014/08/19 职场文书
助学金申请书该怎么写?
2019/07/16 职场文书
pytorch常用数据类型所占字节数对照表一览
2021/05/17 Python
人民币符号
2022/02/17 杂记