keras导入weights方式


Posted in Python onJune 12, 2020

keras源码engine中toplogy.py定义了加载权重的函数:

load_weights(self, filepath, by_name=False)

其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16

VGG19/resnet50等,源码描述如下:

If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.

若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合用于改变了

模型的相关结构或增加了节点但利用了原网络的主体结构情况下使用,源码描述如下:

If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.

在进行边缘检测时,利用VGG网络的主体结构,网络中增加反卷积层,这时加载权重应该使用

model.load_weights(filepath,by_name=True)

补充知识:Keras下实现mnist手写数字

之前一直在用tensorflow,被同学推荐来用keras了,把之前文档中的mnist手写数字数据集拿来练手,

代码如下。

import struct
import numpy as np
import os
 
import keras
from keras.models import Sequential 
from keras.layers import Dense
from keras.optimizers import SGD
 
def load_mnist(path, kind):
  labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
  images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
  with open(labels_path, 'rb') as lbpath:
    magic, n = struct.unpack('>II', lbpath.read(8))
    labels = np.fromfile(lbpath, dtype=np.uint8)
  with open(images_path, 'rb') as imgpath:
    magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
    images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
  return images, labels
 
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
 
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
 
#define models
model = Sequential()
 
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) 
 
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
 
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
 
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
 
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] 
print('Training accuracy: %.2f%%' % (train_acc * 100))
 
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] 
print('Test accuracy: %.2f%%' % (test_acc * 100))

训练结果如下:

Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%

以上这篇keras导入weights方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django中反向生成models.py的实例讲解
May 30 Python
解决python升级引起的pip执行错误的问题
Jun 12 Python
Python 中字符串拼接的多种方法
Jul 30 Python
程序员写Python时的5个坏习惯,你有几条?
Nov 26 Python
python matplotlib库绘制散点图例题解析
Aug 10 Python
python处理自动化任务之同时批量修改word里面的内容的方法
Aug 23 Python
Python解析json代码实例解析
Nov 25 Python
python中count函数简单用法
Jan 05 Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 Python
python如何快速生成时间戳
Jul 21 Python
Python return语句如何实现结果返回调用
Oct 15 Python
如何设置PyCharm中的Python代码模版(推荐)
Nov 20 Python
keras读取h5文件load_weights、load代码操作
Jun 12 #Python
Python matplotlib 绘制双Y轴曲线图的示例代码
Jun 12 #Python
keras的siamese(孪生网络)实现案例
Jun 12 #Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
You might like
php版小黄鸡simsimi聊天机器人接口分享
2014/01/26 PHP
ThinkPHP之用户注册登录留言完整实例
2014/07/22 PHP
PHP把MSSQL数据导入到MYSQL的方法
2014/12/27 PHP
php中fsockopen用法实例
2015/01/05 PHP
php连接oracle数据库的核心步骤
2016/05/26 PHP
php操纵mysqli数据库的实现方法
2016/09/18 PHP
php如何利用pecl安装mongodb扩展详解
2019/01/09 PHP
PHP7数组的底层实现示例
2019/08/25 PHP
ThinkPHP5.1的权限控制怎么写?分享一个AUTH权限控制
2021/03/09 PHP
JavaScript继承方式实例
2010/10/29 Javascript
JavaScript中使用正则匹配多条,且获取每条中的分组数据
2010/11/30 Javascript
Extjs中ComboBoxTree实现的下拉框树效果(自写)
2013/05/28 Javascript
浅析jquery的js图表组件highcharts
2014/03/06 Javascript
JavaScript检查弹出窗口是否被阻拦的方法技巧
2015/03/13 Javascript
使用AmplifyJS组件配合JavaScript进行编程的指南
2015/07/28 Javascript
easyui Draggable组件实现拖动效果
2015/08/19 Javascript
js判断输入字符串是否为空、空格、null的方法总结
2016/06/14 Javascript
原生JS实现图片懒加载(lazyload)实例
2017/06/13 Javascript
谈谈对vue响应式数据更新的误解
2017/08/01 Javascript
从零开始实现Vue简单的Toast插件
2018/12/03 Javascript
vue 实现Web端的定位功能 获取经纬度
2019/08/08 Javascript
vue实现百度搜索功能
2020/12/28 Javascript
python进阶_浅谈面向对象进阶
2017/08/17 Python
Keras - GPU ID 和显存占用设定步骤
2020/06/22 Python
Selenium alert 弹窗处理的示例代码
2020/08/06 Python
安德玛加拿大官网:Under Armour加拿大
2019/10/02 全球购物
介绍一下如何利用路径遍历进行攻击及如何防范
2014/01/19 面试题
学雷锋宣传标语
2014/06/25 职场文书
质监局领导班子对照检查材料思想汇报
2014/09/27 职场文书
警察正风肃纪剖析材料
2014/10/16 职场文书
大学生暑期社会实践证明范本
2014/10/24 职场文书
农村党支部承诺书
2015/04/30 职场文书
2016年党员学习廉政准则心得体会
2016/01/20 职场文书
Python matplotlib可视化之绘制韦恩图
2022/02/24 Python
vue elementUI批量上传文件
2022/04/26 Vue.js
5个实用的JavaScript新特性
2022/06/16 Javascript