Keras保存模型并载入模型继续训练的实现


Posted in Python onFebruary 20, 2021

我们以MNIST手写数字识别为例

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
 
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 创建模型,输入784个神经元,输出10个神经元
model = Sequential([
    Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')
  ])
 
# 定义优化器
sgd = SGD(lr=0.2)
 
# 定义优化器,loss function,训练过程中计算准确率
model.compile(
  optimizer = sgd,
  loss = 'mse',
  metrics=['accuracy'],
)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存模型
model.save('model.h5')  # HDF5文件,pip install h5py

Keras保存模型并载入模型继续训练的实现

Keras保存模型并载入模型继续训练的实现

载入初次训练的模型,再训练

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.models import load_model
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 载入模型
model = load_model('model.h5')
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=2)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存参数,载入参数
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
# 保存网络结构,载入网络结构
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string)
 
print(json_string)

关于compile和load_model()的使用顺序

这一段落主要是为了解决我们fit、evaluate、predict之前还是之后使用compile。想要弄明白,首先我们要清楚compile在程序中是做什么的?都做了什么?

compile做什么?

compile定义了loss function损失函数、optimizer优化器和metrics度量。它与权重无关,也就是说compile并不会影响权重,不会影响之前训练的问题。

如果我们要训练模型或者评估模型evaluate,则需要compile,因为训练要使用损失函数和优化器,评估要使用度量方法;如果我们要预测,则没有必要compile模型。

是否需要多次编译?

除非我们要更改其中之一:损失函数、优化器 / 学习率、度量

又或者我们加载了尚未编译的模型。或者您的加载/保存方法没有考虑以前的编译。

再次compile的后果?

如果再次编译模型,将会丢失优化器状态.

这意味着您的训练在开始时会受到一点影响,直到调整学习率,动量等为止。但是绝对不会对重量造成损害(除非您的初始学习率如此之大,以至于第一次训练步骤疯狂地更改微调的权重)。

到此这篇关于Keras保存模型并载入模型继续训练的实现的文章就介绍到这了,更多相关Keras保存模型并加载模型内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python3基础之基本运算符概述
Aug 13 Python
Python实现简单的可逆加密程序实例
Mar 05 Python
python 转换 Javascript %u 字符串为python unicode的代码
Sep 06 Python
听歌识曲--用python实现一个音乐检索器的功能
Nov 15 Python
Python复制Word内容并使用格式设字体与大小实例代码
Jan 22 Python
Python3.5.3下配置opencv3.2.0的操作方法
Apr 02 Python
Python中的并发处理之asyncio包使用的详解
Apr 03 Python
Python读取英文文件并记录每个单词出现次数后降序输出示例
Jun 28 Python
python3获取当前目录的实现方法
Jul 29 Python
Python基于yield遍历多个可迭代对象
Mar 12 Python
Python多线程正确用法实例解析
May 30 Python
教你怎么用python爬取爱奇艺热门电影
May 20 Python
TensorFlow2.0使用keras训练模型的实现
Feb 20 #Python
tensorflow2.0教程之Keras快速入门
Feb 20 #Python
在Pycharm中安装Pandas库方法(简单易懂)
Feb 20 #Python
Python3爬虫RedisDump的安装步骤
Feb 20 #Python
python爬取2021猫眼票房字体加密实例
Feb 19 #Python
Python之Sklearn使用入门教程
Feb 19 #Python
Python爬虫UA伪装爬取的实例讲解
Feb 19 #Python
You might like
也谈截取首页新闻 - 范例
2006/10/09 PHP
兼容PHP和Java的des加密解密代码分享
2014/06/26 PHP
详解Yii2 定制表单输入字段的标签和样式
2017/01/04 PHP
laravel实现批量更新多条记录的方法示例
2017/10/22 PHP
PHP使用openssl扩展实现加解密方法示例
2020/02/20 PHP
js 判断文件类型并控制表单提交示例代码
2013/11/14 Javascript
jquery自动将form表单封装成json的具体实现
2014/03/17 Javascript
onmouseover事件和onmouseout事件全面理解
2016/08/15 Javascript
DOM操作原生js 的bug,使用jQuery 可以消除的解决方法
2016/09/04 Javascript
JQuery页面随滚动条动态加载效果的简单实现(推荐)
2017/02/08 Javascript
基于js中this和event 的区别(详解)
2017/10/24 Javascript
JavaScript中利用Array filter() 方法压缩稀疏数组
2018/02/24 Javascript
JS使用百度地图API自动获取地址和经纬度操作示例
2019/04/16 Javascript
layui扩展上传组件模拟进度条的方法
2019/09/23 Javascript
[03:53]2016国际邀请赛中国区预选赛第三日TOP10精彩集锦
2016/06/29 DOTA
Python3实现将文件归档到zip文件及从zip文件中读取数据的方法
2015/05/22 Python
Python实现递归遍历文件夹并删除文件
2016/04/18 Python
python实现实时监控文件的方法
2016/08/26 Python
Python命令启动Web服务器实例详解
2017/02/23 Python
Python3 循环语句(for、while、break、range等)
2017/11/20 Python
django 在原有表格添加或删除字段的实例
2018/05/27 Python
24式加速你的Python(小结)
2019/06/13 Python
python打开使用的方法
2019/09/30 Python
python 通过手机号识别出对应的微信性别(实例代码)
2019/12/22 Python
Numpy与Pytorch 矩阵操作方式
2019/12/27 Python
Python reversed函数及使用方法解析
2020/03/17 Python
python从Oracle读取数据生成图表
2020/10/14 Python
Python经典五人分鱼实例讲解
2021/01/04 Python
网络工程师面试(三木通信技术有限公司)
2013/06/05 面试题
工商学院毕业生自荐信
2013/11/12 职场文书
英语专业学生的自我评价
2013/12/30 职场文书
求职信需要的五点内容
2014/02/01 职场文书
人事专员工作职责
2014/02/22 职场文书
个人总结怎么写
2015/02/26 职场文书
大学生自荐书范文
2015/03/05 职场文书
思品教学工作总结
2015/08/10 职场文书