keras导入weights方式


Posted in Python onJune 12, 2020

keras源码engine中toplogy.py定义了加载权重的函数:

load_weights(self, filepath, by_name=False)

其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16

VGG19/resnet50等,源码描述如下:

If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.

若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合用于改变了

模型的相关结构或增加了节点但利用了原网络的主体结构情况下使用,源码描述如下:

If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.

在进行边缘检测时,利用VGG网络的主体结构,网络中增加反卷积层,这时加载权重应该使用

model.load_weights(filepath,by_name=True)

补充知识:Keras下实现mnist手写数字

之前一直在用tensorflow,被同学推荐来用keras了,把之前文档中的mnist手写数字数据集拿来练手,

代码如下。

import struct
import numpy as np
import os
 
import keras
from keras.models import Sequential 
from keras.layers import Dense
from keras.optimizers import SGD
 
def load_mnist(path, kind):
  labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
  images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
  with open(labels_path, 'rb') as lbpath:
    magic, n = struct.unpack('>II', lbpath.read(8))
    labels = np.fromfile(lbpath, dtype=np.uint8)
  with open(images_path, 'rb') as imgpath:
    magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
    images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
  return images, labels
 
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
 
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
 
#define models
model = Sequential()
 
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) 
 
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
 
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
 
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
 
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] 
print('Training accuracy: %.2f%%' % (train_acc * 100))
 
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] 
print('Test accuracy: %.2f%%' % (test_acc * 100))

训练结果如下:

Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%

以上这篇keras导入weights方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python代码制作configure文件示例
Jul 28 Python
Python3.2中Print函数用法实例详解
May 19 Python
纯用NumPy实现神经网络的示例代码
Oct 24 Python
Python 中导入csv数据的三种方法
Nov 01 Python
对web.py设置favicon.ico的方法详解
Dec 04 Python
python去重,一个由dict组成的list的去重示例
Jan 21 Python
浅析python 中大括号中括号小括号的区分
Jul 29 Python
Python实现PyPDF2处理PDF文件的方法示例
Sep 25 Python
浅谈Keras中shuffle和validation_split的顺序
Jun 19 Python
超级实用的8个Python列表技巧
Aug 24 Python
python 如何在 Matplotlib 中绘制垂直线
Apr 02 Python
python多次执行绘制条形图
Apr 20 Python
keras读取h5文件load_weights、load代码操作
Jun 12 #Python
Python matplotlib 绘制双Y轴曲线图的示例代码
Jun 12 #Python
keras的siamese(孪生网络)实现案例
Jun 12 #Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
You might like
PHP优于Node.js的五大理由分享
2012/09/15 PHP
PHP使用PDO连接ACCESS数据库
2015/03/05 PHP
php使用指定编码导出mysql数据到csv文件的方法
2015/03/31 PHP
PHP+redis实现的购物车单例类示例
2019/02/02 PHP
破除网页鼠标右键被禁用的绝招大全
2006/12/27 Javascript
jquery 防止表单重复提交代码
2010/01/21 Javascript
JQuery 将元素显示在屏幕的中央的代码
2010/02/27 Javascript
Extjs 继承Ext.data.Store不起作用原因分析及解决
2013/04/15 Javascript
JavaScript实现cookie的写入、读取、删除功能
2015/11/05 Javascript
快速学习AngularJs HTTP响应拦截器
2015/12/31 Javascript
原生js获取元素样式的简单方法
2016/08/06 Javascript
微信小程序 常见问题总结(4058,40013)及解决办法
2017/01/11 Javascript
详解webpack 入门总结和实践(按需异步加载,css单独打包,生成多个入口文件)
2017/06/20 Javascript
JavaScript之class继承_动力节点Java学院整理
2017/07/03 Javascript
arcgis for js栅格图层叠加(Raster Layer)问题
2017/11/22 Javascript
关于Angularjs中自定义指令一些有价值的细节和技巧小结
2018/04/22 Javascript
ionic grid(栅格)九宫格制作详解
2018/06/30 Javascript
JS实现点击按钮随机生成可拖动的不同颜色块示例
2019/01/30 Javascript
Node.js HTTP服务器中的文件、图片上传的方法
2019/09/23 Javascript
[00:30]塑造者的传承礼包-戴泽“暗影之焰”套装展示视频
2014/04/04 DOTA
Pycharm编辑器技巧之自动导入模块详解
2017/07/18 Python
Python 打印中文字符的三种方法
2018/08/14 Python
python将txt等文件中的数据读为numpy数组的方法
2018/12/22 Python
Python3爬虫RedisDump的安装步骤
2021/02/20 Python
阻止移动设备(手机、pad)浏览器双击放大网页的方法
2014/06/03 HTML / CSS
3种方式实现瀑布流布局小结
2019/09/05 HTML / CSS
我们在web应用开发过程中经常遇到输出某种编码的字符,如iso8859-1等,如何输出一个某种编码的字符串?
2014/03/30 面试题
2014年两会学习心得体会
2014/03/17 职场文书
学校募捐倡议书
2014/05/14 职场文书
大学生学习面向未来的赶考思想汇报
2014/09/12 职场文书
2015年教学管理工作总结
2015/05/20 职场文书
学校运动会感想
2015/08/10 职场文书
送给教师们,到底该如何写好教学反思?
2019/07/02 职场文书
Python 用户输入和while循环的操作
2021/05/23 Python
Python实现Hash算法
2022/03/18 Python
大脑的记忆过程在做数据压缩,不同图形也有共同的记忆格式
2022/04/29 数码科技