keras导入weights方式


Posted in Python onJune 12, 2020

keras源码engine中toplogy.py定义了加载权重的函数:

load_weights(self, filepath, by_name=False)

其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16

VGG19/resnet50等,源码描述如下:

If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.

若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合用于改变了

模型的相关结构或增加了节点但利用了原网络的主体结构情况下使用,源码描述如下:

If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.

在进行边缘检测时,利用VGG网络的主体结构,网络中增加反卷积层,这时加载权重应该使用

model.load_weights(filepath,by_name=True)

补充知识:Keras下实现mnist手写数字

之前一直在用tensorflow,被同学推荐来用keras了,把之前文档中的mnist手写数字数据集拿来练手,

代码如下。

import struct
import numpy as np
import os
 
import keras
from keras.models import Sequential 
from keras.layers import Dense
from keras.optimizers import SGD
 
def load_mnist(path, kind):
  labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
  images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
  with open(labels_path, 'rb') as lbpath:
    magic, n = struct.unpack('>II', lbpath.read(8))
    labels = np.fromfile(lbpath, dtype=np.uint8)
  with open(images_path, 'rb') as imgpath:
    magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
    images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
  return images, labels
 
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
 
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
 
#define models
model = Sequential()
 
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) 
 
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
 
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
 
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
 
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] 
print('Training accuracy: %.2f%%' % (train_acc * 100))
 
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] 
print('Test accuracy: %.2f%%' % (test_acc * 100))

训练结果如下:

Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%

以上这篇keras导入weights方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现基于两张图片生成圆角图标效果的方法
Mar 26 Python
python中函数传参详解
Jul 03 Python
Python3多线程爬虫实例讲解代码
Jan 05 Python
python与sqlite3实现解密chrome cookie实例代码
Jan 20 Python
Tensorflow 利用tf.contrib.learn建立输入函数的方法
Feb 08 Python
pandas中的DataFrame按指定顺序输出所有列的方法
Apr 10 Python
Python打包模块wheel的使用方法与将python包发布到PyPI的方法详解
Feb 12 Python
解决python使用list()时总是报错的问题
May 05 Python
Python爬虫headers处理及网络超时问题解决方案
Jun 19 Python
什么是python的函数体
Jun 19 Python
python反扒机制的5种解决方法
Feb 06 Python
Python实现为PDF去除水印的示例代码
Apr 03 Python
keras读取h5文件load_weights、load代码操作
Jun 12 #Python
Python matplotlib 绘制双Y轴曲线图的示例代码
Jun 12 #Python
keras的siamese(孪生网络)实现案例
Jun 12 #Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
You might like
骨王战斗力在公会成员中排不进前五,却当选了会长,原因竟是这样
2020/03/02 日漫
无数据库的详细域名查询程序PHP版(1)
2006/10/09 PHP
php中批量替换文件名的实现代码
2011/07/20 PHP
新手学习PHP的一些基础知识分享
2011/07/27 PHP
解析PHP 5.5 新特性
2013/07/02 PHP
Zend Framework框架实现类似Google搜索分页效果
2016/11/25 PHP
得到文本框选中的文字,动态插入文字的js代码
2007/03/07 Javascript
SuperSlide2实现图片滚动特效
2014/06/20 Javascript
node.js中的dns.getServers方法使用说明
2014/12/08 Javascript
JS+CSS实现另类带提示效果的竖向导航菜单
2015/10/15 Javascript
jQuery获取元素父节点的方法
2016/06/21 Javascript
基于JS实现弹出一个隐藏的div窗口body页面变成灰色并且不可被编辑
2016/12/14 Javascript
简单实现jQuery级联菜单
2017/01/09 Javascript
jQuery插件FusionCharts绘制的2D双柱状图效果示例【附demo源码】
2017/05/13 jQuery
详解nodeJS之二进制buffer对象
2017/06/03 NodeJs
深究AngularJS之ui-router详解
2017/06/13 Javascript
基于jquery日历价格、库存等设置插件
2020/07/05 jQuery
利用SpringMVC过滤器解决vue跨域请求的问题
2018/02/10 Javascript
javaScript中"=="和"==="的区别详解
2018/03/16 Javascript
利用原生JavaScript实现造日历轮子实例代码
2019/05/08 Javascript
javascript 原型与原型链的理解及应用实例分析
2020/02/10 Javascript
Python复制文件操作实例详解
2015/11/10 Python
python getopt详解及简单实例
2016/12/30 Python
python实现录音小程序
2020/10/26 Python
Python 调用 zabbix api的方法示例
2019/01/06 Python
10 分钟快速入门 Python3的教程
2019/01/29 Python
python的移位操作实现详解
2019/08/21 Python
解决pyPdf和pyPdf2在合并pdf时出现异常的问题
2020/04/03 Python
python中sort sorted reverse reversed函数的区别说明
2020/05/11 Python
HTML5+JS实现俄罗斯方块原理及具体步骤
2013/11/29 HTML / CSS
Lookfantastic俄罗斯:欧洲在线化妆品零售商
2019/08/06 全球购物
模具设计与制造专业推荐信
2014/02/16 职场文书
酒店端午节促销方案
2014/02/18 职场文书
六一儿童节开幕词
2015/01/29 职场文书
企业财务人员岗位职责
2015/04/14 职场文书
保留意见审计报告
2015/06/05 职场文书