keras导入weights方式


Posted in Python onJune 12, 2020

keras源码engine中toplogy.py定义了加载权重的函数:

load_weights(self, filepath, by_name=False)

其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16

VGG19/resnet50等,源码描述如下:

If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.

若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合用于改变了

模型的相关结构或增加了节点但利用了原网络的主体结构情况下使用,源码描述如下:

If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.

在进行边缘检测时,利用VGG网络的主体结构,网络中增加反卷积层,这时加载权重应该使用

model.load_weights(filepath,by_name=True)

补充知识:Keras下实现mnist手写数字

之前一直在用tensorflow,被同学推荐来用keras了,把之前文档中的mnist手写数字数据集拿来练手,

代码如下。

import struct
import numpy as np
import os
 
import keras
from keras.models import Sequential 
from keras.layers import Dense
from keras.optimizers import SGD
 
def load_mnist(path, kind):
  labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
  images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
  with open(labels_path, 'rb') as lbpath:
    magic, n = struct.unpack('>II', lbpath.read(8))
    labels = np.fromfile(lbpath, dtype=np.uint8)
  with open(images_path, 'rb') as imgpath:
    magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
    images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
  return images, labels
 
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
 
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
 
#define models
model = Sequential()
 
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) 
 
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
 
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
 
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
 
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] 
print('Training accuracy: %.2f%%' % (train_acc * 100))
 
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] 
print('Test accuracy: %.2f%%' % (test_acc * 100))

训练结果如下:

Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%

以上这篇keras导入weights方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
常用python数据类型转换函数总结
Mar 11 Python
python操作mysql数据库
Mar 05 Python
Python的mysql数据库的更新如何实现
Jul 31 Python
Python cookbook(数据结构与算法)找到最大或最小的N个元素实现方法示例
Feb 13 Python
VSCode下配置python调试运行环境的方法
Apr 06 Python
python3+PyQt5实现拖放功能
Apr 24 Python
python读取word文档,插入mysql数据库的示例代码
Nov 07 Python
对python修改xml文件的节点值方法详解
Dec 24 Python
Django如何防止定时任务并发浅析
May 14 Python
tensorflow 环境变量设置方式
Feb 06 Python
Python 随机生成测试数据的模块:faker基本使用方法详解
Apr 09 Python
如何使用python包中的sched事件调度器
Apr 30 Python
keras读取h5文件load_weights、load代码操作
Jun 12 #Python
Python matplotlib 绘制双Y轴曲线图的示例代码
Jun 12 #Python
keras的siamese(孪生网络)实现案例
Jun 12 #Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
You might like
用php来检测proxy
2006/10/09 PHP
小偷PHP+Html+缓存
2006/11/25 PHP
php array_merge_recursive 数组合并
2016/10/26 PHP
PHP二维数组实现去除重复项的方法【保留各个键值】
2017/12/21 PHP
准确获得页面、窗口高度及宽度的JS
2006/11/26 Javascript
基于jQuery的为attr添加id title等效果的实现代码
2011/04/20 Javascript
js控制frameSet示例
2013/09/10 Javascript
JQuery创建DOM节点的方法
2015/06/11 Javascript
form+iframe解决跨域上传文件的方法
2016/11/18 Javascript
AngularJS 在同一个界面启动多个ng-app应用模块详解
2016/12/20 Javascript
Linux使用Node.js建立访问静态网页的服务实例详解
2017/03/21 Javascript
angularjs实现上拉加载和下拉刷新数据功能
2017/06/12 Javascript
深入理解redux之compose的具体应用
2020/01/12 Javascript
vue-router重写push方法,解决相同路径跳转报错问题
2020/08/07 Javascript
原生js实现购物车
2020/09/23 Javascript
[51:44]2018DOTA2亚洲邀请赛 4.3 突围赛 Optic vs iG 第二场
2018/04/04 DOTA
[51:32]Optic vs Serenity 2018国际邀请赛淘汰赛BO3 第一场 8.22
2018/08/23 DOTA
[15:07]lgd_OG_m2_BP
2019/09/10 DOTA
Python中利用sqrt()方法进行平方根计算的教程
2015/05/15 Python
Python的迭代器和生成器
2015/07/29 Python
Python 中 Meta Classes详解
2016/02/13 Python
浅谈Pandas中map, applymap and apply的区别
2018/04/10 Python
Python实现去除图片中指定颜色的像素功能示例
2019/04/13 Python
Python 生成一个从0到n个数字的列表4种方法小结
2019/11/28 Python
文职个人求职信范文
2013/09/23 职场文书
校园餐饮创业计划书
2014/01/10 职场文书
乡镇干部十八大感言
2014/02/17 职场文书
贷款承诺书范文
2014/05/19 职场文书
如何写早恋检讨书
2014/09/10 职场文书
党员“四风”方面存在问题及整改措施
2014/09/24 职场文书
四风查摆问题及整改措施
2014/10/10 职场文书
幼儿园教师考核评语
2014/12/31 职场文书
公司出纳岗位职责
2015/03/31 职场文书
2015年幼儿园大班工作总结
2015/04/25 职场文书
win10+anaconda安装yolov5的方法及问题解决方案
2021/04/29 Python
Pytorch 如何实现LSTM时间序列预测
2021/05/17 Python