Keras保存模型并载入模型继续训练的实现


Posted in Python onFebruary 20, 2021

我们以MNIST手写数字识别为例

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
 
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 创建模型,输入784个神经元,输出10个神经元
model = Sequential([
    Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')
  ])
 
# 定义优化器
sgd = SGD(lr=0.2)
 
# 定义优化器,loss function,训练过程中计算准确率
model.compile(
  optimizer = sgd,
  loss = 'mse',
  metrics=['accuracy'],
)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存模型
model.save('model.h5')  # HDF5文件,pip install h5py

Keras保存模型并载入模型继续训练的实现

Keras保存模型并载入模型继续训练的实现

载入初次训练的模型,再训练

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.models import load_model
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 载入模型
model = load_model('model.h5')
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=2)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存参数,载入参数
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
# 保存网络结构,载入网络结构
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string)
 
print(json_string)

关于compile和load_model()的使用顺序

这一段落主要是为了解决我们fit、evaluate、predict之前还是之后使用compile。想要弄明白,首先我们要清楚compile在程序中是做什么的?都做了什么?

compile做什么?

compile定义了loss function损失函数、optimizer优化器和metrics度量。它与权重无关,也就是说compile并不会影响权重,不会影响之前训练的问题。

如果我们要训练模型或者评估模型evaluate,则需要compile,因为训练要使用损失函数和优化器,评估要使用度量方法;如果我们要预测,则没有必要compile模型。

是否需要多次编译?

除非我们要更改其中之一:损失函数、优化器 / 学习率、度量

又或者我们加载了尚未编译的模型。或者您的加载/保存方法没有考虑以前的编译。

再次compile的后果?

如果再次编译模型,将会丢失优化器状态.

这意味着您的训练在开始时会受到一点影响,直到调整学习率,动量等为止。但是绝对不会对重量造成损害(除非您的初始学习率如此之大,以至于第一次训练步骤疯狂地更改微调的权重)。

到此这篇关于Keras保存模型并载入模型继续训练的实现的文章就介绍到这了,更多相关Keras保存模型并加载模型内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
简单介绍Python中的floor()方法
May 15 Python
举例讲解Python编程中对线程锁的使用
Jul 12 Python
Python实现按特定格式对文件进行读写的方法示例
Nov 30 Python
python生成器,可迭代对象,迭代器区别和联系
Feb 04 Python
pandas ix &iloc &loc的区别
Jan 10 Python
Python一行代码实现快速排序的方法
Apr 30 Python
python中for循环把字符串或者字典添加到列表的方法
Jul 20 Python
Python进度条的制作代码实例
Aug 31 Python
Python pyautogui模块实现鼠标键盘自动化方法详解
Feb 17 Python
python实现扫雷游戏
Mar 03 Python
Python爬虫之Selenium实现键盘事件
Dec 04 Python
python代码实现扫码关注公众号登录的实战
Nov 01 Python
TensorFlow2.0使用keras训练模型的实现
Feb 20 #Python
tensorflow2.0教程之Keras快速入门
Feb 20 #Python
在Pycharm中安装Pandas库方法(简单易懂)
Feb 20 #Python
Python3爬虫RedisDump的安装步骤
Feb 20 #Python
python爬取2021猫眼票房字体加密实例
Feb 19 #Python
Python之Sklearn使用入门教程
Feb 19 #Python
Python爬虫UA伪装爬取的实例讲解
Feb 19 #Python
You might like
php面向对象值单例模式
2016/05/03 PHP
CentOS 7.2 下编译安装PHP7.0.10+MySQL5.7.14+Nginx1.10.1的方法详解(mini版本)
2016/09/01 PHP
php实现的网页版剪刀石头布游戏示例
2016/11/25 PHP
PHP删除数组中指定下标的元素方法
2018/02/03 PHP
JS 图片缩放效果代码
2010/06/09 Javascript
jquery图片切换实例分析
2015/04/15 Javascript
js实现不重复导入的方法
2016/03/02 Javascript
jquery ajax局部加载方法详解(实现代码)
2016/05/12 Javascript
详解Angular 4.x NgIf 的用法
2017/05/22 Javascript
解决jquery appaend元素中id绑定事件失效的问题
2017/09/12 jQuery
Vue动态创建注册component的实例代码
2019/06/14 Javascript
Vue-Ant Design Vue-普通及自定义校验实例
2020/10/24 Javascript
[01:25]DOTA2自定义游戏灵园鬼域等你踏足
2015/10/30 DOTA
Python查询阿里巴巴关键字排名的方法
2015/07/08 Python
python 捕获shell脚本的输出结果实例
2017/01/04 Python
Python中生成Epoch的方法
2017/04/26 Python
Python hashlib模块用法实例分析
2018/06/12 Python
Django添加KindEditor富文本编辑器的使用
2018/10/24 Python
基于Modernizr 让网站进行优雅降级的分析
2013/04/21 HTML / CSS
HTML5未来发展趋势
2016/02/01 HTML / CSS
Html5应用程序缓存(Cache manifest)
2018/06/04 HTML / CSS
威尔逊皮革:Wilsons Leather
2018/12/07 全球购物
泰国时尚电商:POMELO Fashion
2020/03/11 全球购物
params有什么用
2016/03/01 面试题
酒店服务实习自我鉴定
2013/09/22 职场文书
计算机专业职业生涯规划范文
2014/01/19 职场文书
乔迁之喜主持词
2014/03/27 职场文书
科长竞争上岗演讲稿
2014/05/12 职场文书
给校长的建议书400字
2014/05/15 职场文书
工厂门卫的岗位职责
2014/07/27 职场文书
2014县政府领导班子对照检查材料思想汇报
2014/09/25 职场文书
大三学年自我鉴定范文(3篇)
2014/09/28 职场文书
长城导游词400字
2015/01/30 职场文书
2015年办公室人员工作总结
2015/05/15 职场文书
MySQL中的布尔值,怎么存储false或true
2021/06/04 MySQL
vscode中使用npm安装babel的方法
2021/08/02 Javascript