keras导入weights方式


Posted in Python onJune 12, 2020

keras源码engine中toplogy.py定义了加载权重的函数:

load_weights(self, filepath, by_name=False)

其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16

VGG19/resnet50等,源码描述如下:

If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.

若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合用于改变了

模型的相关结构或增加了节点但利用了原网络的主体结构情况下使用,源码描述如下:

If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.

在进行边缘检测时,利用VGG网络的主体结构,网络中增加反卷积层,这时加载权重应该使用

model.load_weights(filepath,by_name=True)

补充知识:Keras下实现mnist手写数字

之前一直在用tensorflow,被同学推荐来用keras了,把之前文档中的mnist手写数字数据集拿来练手,

代码如下。

import struct
import numpy as np
import os
 
import keras
from keras.models import Sequential 
from keras.layers import Dense
from keras.optimizers import SGD
 
def load_mnist(path, kind):
  labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
  images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
  with open(labels_path, 'rb') as lbpath:
    magic, n = struct.unpack('>II', lbpath.read(8))
    labels = np.fromfile(lbpath, dtype=np.uint8)
  with open(images_path, 'rb') as imgpath:
    magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
    images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
  return images, labels
 
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
 
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
 
#define models
model = Sequential()
 
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) 
 
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
 
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
 
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
 
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] 
print('Training accuracy: %.2f%%' % (train_acc * 100))
 
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] 
print('Test accuracy: %.2f%%' % (test_acc * 100))

训练结果如下:

Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%

以上这篇keras导入weights方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python高效编程技巧
Jan 07 Python
学习python的几条建议分享
Feb 10 Python
python查找第k小元素代码分享
Dec 18 Python
C#返回当前系统所有可用驱动器符号的方法
Apr 18 Python
Python使用dis模块把Python反编译为字节码的用法详解
Jun 14 Python
对Python Pexpect 模块的使用说明详解
Feb 14 Python
Python3 实现串口两进程同时读写
Jun 12 Python
通过PYTHON来实现图像分割详解
Jun 26 Python
django 消息框架 message使用详解
Jul 22 Python
Python中itertools的用法详解
Feb 07 Python
pandas和spark dataframe互相转换实例详解
Feb 18 Python
在pycharm创建scrapy项目的实现步骤
Dec 01 Python
keras读取h5文件load_weights、load代码操作
Jun 12 #Python
Python matplotlib 绘制双Y轴曲线图的示例代码
Jun 12 #Python
keras的siamese(孪生网络)实现案例
Jun 12 #Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
You might like
php过滤XSS攻击的函数
2013/11/12 PHP
php发送post请求的三种方法
2014/02/11 PHP
php利用反射实现插件机制的方法
2015/03/14 PHP
PHP5.3新特性小结
2016/02/14 PHP
PHP处理Ajax请求与Ajax跨域问题
2017/02/13 PHP
javascript让setInteval里的函数参数中的this指向特定的对象
2010/01/31 Javascript
js简易namespace管理器 实例代码
2013/06/21 Javascript
JavaScript中window、doucment、body的解释
2013/08/14 Javascript
鼠标滚轮改变图片大小的示例代码
2013/11/20 Javascript
Javascript与jQuery方法的隐藏与显示
2015/01/19 Javascript
js实现鼠标触发图片抖动效果的方法
2015/02/27 Javascript
js实现可折叠展开的手风琴菜单效果
2015/09/07 Javascript
快速解决Canvas.toDataURL 图片跨域的问题
2016/05/10 Javascript
BootStrap智能表单实战系列(七)验证的支持
2016/06/13 Javascript
jQuery Tags Input Plugin(添加/删除标签插件)详解
2016/06/20 Javascript
利用fecha进行JS日期处理
2016/11/21 Javascript
Javascript的this用法
2017/01/16 Javascript
vue2实现移动端上传、预览、压缩图片解决拍照旋转问题
2017/04/13 Javascript
vue实现动态显示与隐藏底部导航的方法分析
2019/02/11 Javascript
深入探索VueJS Scoped CSS 实现原理
2019/09/23 Javascript
Python命令启动Web服务器实例详解
2017/02/23 Python
python3.4用循环往mysql5.7中写数据并输出的实现方法
2017/06/20 Python
python去除拼音声调字母,替换为字母的方法
2018/11/28 Python
python单例模式原理与创建方法实例分析
2019/10/26 Python
解决python彩色螺旋线绘制引发的问题
2019/11/23 Python
生产车间实习自我鉴定
2013/09/23 职场文书
大学生党员承诺书
2014/05/20 职场文书
民族精神月活动总结
2014/08/28 职场文书
购房公证委托书(2014版)
2014/09/12 职场文书
怎样写离婚协议书
2015/01/26 职场文书
销售会议开幕词
2015/01/28 职场文书
趵突泉导游词
2015/02/03 职场文书
大一学生个人总结
2015/02/15 职场文书
管理失职检讨书
2015/05/05 职场文书
爱岗敬业事迹材料
2019/06/20 职场文书
python 制作一个gui界面的翻译工具
2021/05/14 Python