Keras保存模型并载入模型继续训练的实现


Posted in Python onFebruary 20, 2021

我们以MNIST手写数字识别为例

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
 
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 创建模型,输入784个神经元,输出10个神经元
model = Sequential([
    Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')
  ])
 
# 定义优化器
sgd = SGD(lr=0.2)
 
# 定义优化器,loss function,训练过程中计算准确率
model.compile(
  optimizer = sgd,
  loss = 'mse',
  metrics=['accuracy'],
)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存模型
model.save('model.h5')  # HDF5文件,pip install h5py

Keras保存模型并载入模型继续训练的实现

Keras保存模型并载入模型继续训练的实现

载入初次训练的模型,再训练

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.models import load_model
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 载入模型
model = load_model('model.h5')
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=2)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存参数,载入参数
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
# 保存网络结构,载入网络结构
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string)
 
print(json_string)

关于compile和load_model()的使用顺序

这一段落主要是为了解决我们fit、evaluate、predict之前还是之后使用compile。想要弄明白,首先我们要清楚compile在程序中是做什么的?都做了什么?

compile做什么?

compile定义了loss function损失函数、optimizer优化器和metrics度量。它与权重无关,也就是说compile并不会影响权重,不会影响之前训练的问题。

如果我们要训练模型或者评估模型evaluate,则需要compile,因为训练要使用损失函数和优化器,评估要使用度量方法;如果我们要预测,则没有必要compile模型。

是否需要多次编译?

除非我们要更改其中之一:损失函数、优化器 / 学习率、度量

又或者我们加载了尚未编译的模型。或者您的加载/保存方法没有考虑以前的编译。

再次compile的后果?

如果再次编译模型,将会丢失优化器状态.

这意味着您的训练在开始时会受到一点影响,直到调整学习率,动量等为止。但是绝对不会对重量造成损害(除非您的初始学习率如此之大,以至于第一次训练步骤疯狂地更改微调的权重)。

到此这篇关于Keras保存模型并载入模型继续训练的实现的文章就介绍到这了,更多相关Keras保存模型并加载模型内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Django框架下在URLconf中指定视图缓存的方法
Jul 23 Python
python+django快速实现文件上传
Oct 24 Python
Python基于matplotlib绘制栈式直方图的方法示例
Aug 09 Python
简单学习Python多进程Multiprocessing
Aug 29 Python
Python request设置HTTPS代理代码解析
Feb 12 Python
python的pip安装以及使用教程
Sep 18 Python
浅谈python的深浅拷贝以及fromkeys的用法
Mar 08 Python
Python Pandas实现数据分组求平均值并填充nan的示例
Jul 04 Python
Python之虚拟环境virtualenv,pipreqs生成项目依赖第三方包的方法
Jul 23 Python
Python搭建代理IP池实现接口设置与整体调度
Oct 27 Python
pyecharts绘制中国2020肺炎疫情地图的实例代码
Feb 12 Python
scrapy与selenium结合爬取数据(爬取动态网站)的示例代码
Sep 28 Python
TensorFlow2.0使用keras训练模型的实现
Feb 20 #Python
tensorflow2.0教程之Keras快速入门
Feb 20 #Python
在Pycharm中安装Pandas库方法(简单易懂)
Feb 20 #Python
Python3爬虫RedisDump的安装步骤
Feb 20 #Python
python爬取2021猫眼票房字体加密实例
Feb 19 #Python
Python之Sklearn使用入门教程
Feb 19 #Python
Python爬虫UA伪装爬取的实例讲解
Feb 19 #Python
You might like
php allow_url_include的应用和解释
2010/04/22 PHP
PHP持久连接mysql_pconnect()函数使用介绍
2012/02/05 PHP
PHP实现搜索地理位置及计算两点地理位置间距离的实例
2016/01/08 PHP
向fckeditor编辑器插入指定代码的方法
2007/05/25 Javascript
JQuery为textarea添加maxlength属性的代码
2010/04/07 Javascript
JavaScript的继承的封装介绍
2013/10/15 Javascript
js取消单选按钮选中示例代码
2013/11/14 Javascript
详解JavaScript 中的 replace 方法
2016/01/01 Javascript
jQuery获取选中单选按钮radio的值
2016/12/27 Javascript
微信小程序-获得用户输入内容
2017/02/13 Javascript
js中数组插入、删除元素操作的方法
2017/02/15 Javascript
Vue.js 2.0中select级联下拉框实例
2017/03/06 Javascript
详解如何优雅地在React项目中使用Redux
2017/12/28 Javascript
JS实现前端页面的搜索功能
2018/06/12 Javascript
在vue中读取本地Json文件的方法
2018/09/06 Javascript
jQuery插件实现的日历功能示例【附源码下载】
2018/09/07 jQuery
微信自定义分享链接信息(标题,图片和内容)实现过程详解
2019/09/04 Javascript
Vue插件之滑动验证码用法详解
2020/04/05 Javascript
Vue3不支持Filters过滤器的问题
2020/09/24 Javascript
Python Tkinter GUI编程入门介绍
2015/03/10 Python
Python 爬虫爬取指定博客的所有文章
2016/02/17 Python
在win和Linux系统中python命令行运行的不同
2016/07/03 Python
Python Gluon参数和模块命名操作教程
2019/12/18 Python
PyTorch学习:动态图和静态图的例子
2020/01/06 Python
Python用5行代码实现批量抠图的示例代码
2020/04/14 Python
PyCharm 2020.2 安装详细教程
2020/09/25 Python
戴森美国官网:Dyson美国
2016/09/11 全球购物
Staples加拿大官方网站:办公用品一站式采购
2016/09/25 全球购物
什么是虚拟内存?虚拟内存有什么优势?
2016/02/09 面试题
个人自我评价和职业目标
2014/01/24 职场文书
食品销售计划书
2014/04/26 职场文书
教师演讲稿开场白
2014/08/25 职场文书
小班下学期幼儿评语
2014/12/30 职场文书
学困生转化工作总结
2015/08/13 职场文书
nginx如何将http访问的网站改成https访问
2021/03/31 Servers
Java并发编程之详解CyclicBarrier线程同步
2021/06/23 Java/Android