keras导入weights方式


Posted in Python onJune 12, 2020

keras源码engine中toplogy.py定义了加载权重的函数:

load_weights(self, filepath, by_name=False)

其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16

VGG19/resnet50等,源码描述如下:

If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.

若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合用于改变了

模型的相关结构或增加了节点但利用了原网络的主体结构情况下使用,源码描述如下:

If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.

在进行边缘检测时,利用VGG网络的主体结构,网络中增加反卷积层,这时加载权重应该使用

model.load_weights(filepath,by_name=True)

补充知识:Keras下实现mnist手写数字

之前一直在用tensorflow,被同学推荐来用keras了,把之前文档中的mnist手写数字数据集拿来练手,

代码如下。

import struct
import numpy as np
import os
 
import keras
from keras.models import Sequential 
from keras.layers import Dense
from keras.optimizers import SGD
 
def load_mnist(path, kind):
  labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
  images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
  with open(labels_path, 'rb') as lbpath:
    magic, n = struct.unpack('>II', lbpath.read(8))
    labels = np.fromfile(lbpath, dtype=np.uint8)
  with open(images_path, 'rb') as imgpath:
    magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
    images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
  return images, labels
 
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
 
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
 
#define models
model = Sequential()
 
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) 
 
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
 
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
 
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
 
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] 
print('Training accuracy: %.2f%%' % (train_acc * 100))
 
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] 
print('Test accuracy: %.2f%%' % (test_acc * 100))

训练结果如下:

Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%

以上这篇keras导入weights方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python异步任务队列示例
Apr 01 Python
用Python操作字符串之rindex()方法的使用
May 19 Python
Python学习入门之区块链详解
Jul 25 Python
浅谈numpy库的常用基本操作方法
Jan 09 Python
Django中使用Celery的教程详解
Aug 24 Python
JSON文件及Python对JSON文件的读写操作
Oct 07 Python
python微元法计算函数曲线长度的方法
Nov 08 Python
Python实现的对一个数进行因式分解操作示例
Jun 27 Python
如何在VSCode上轻松舒适的配置Python的方法步骤
Oct 28 Python
python调用函数、类和文件操作简单实例总结
Nov 29 Python
python识别验证码图片实例详解
Feb 17 Python
Python使用内置函数setattr设置对象的属性值
Oct 16 Python
keras读取h5文件load_weights、load代码操作
Jun 12 #Python
Python matplotlib 绘制双Y轴曲线图的示例代码
Jun 12 #Python
keras的siamese(孪生网络)实现案例
Jun 12 #Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
You might like
探讨PHP JSON中文乱码的解决方法详解
2013/06/06 PHP
解析php中两种缩放图片的函数,为图片添加水印
2013/06/14 PHP
PHP进程同步代码实例
2015/02/12 PHP
php使用array_search函数实现数组查找的方法
2015/06/12 PHP
提交表单后 PHP获取提交内容的实现方法
2016/05/25 PHP
php实现的数组转xml案例分析
2019/09/28 PHP
重载toString实现JS HashMap分析
2011/03/13 Javascript
仅img元素创建后不添加到文档中会执行onload事件的解决方法
2011/07/31 Javascript
JS自调用匿名函数具体实现
2014/02/11 Javascript
js简单实现调整网页字体大小的方法
2016/07/23 Javascript
vue使用$emit时,父组件无法监听到子组件的事件实例
2018/02/26 Javascript
React diff算法的实现示例
2018/04/20 Javascript
JS温故而知新之变量提升和时间死区
2019/01/27 Javascript
简述pm2常用命令集合及配置文件说明
2019/05/30 Javascript
elementUi vue el-radio 监听选中变化的实例代码
2019/06/28 Javascript
JS+canvas五子棋人机对战实现步骤详解
2020/06/04 Javascript
Python画图学习入门教程
2016/07/01 Python
Python对字符串实现去重操作的方法示例
2017/08/11 Python
django的登录注册系统的示例代码
2018/05/14 Python
python 匹配url中是否存在IP地址的方法
2018/06/04 Python
将tensorflow的ckpt模型存储为npy的实例
2018/07/09 Python
python小程序基于Jupyter实现天气查询的方法
2020/03/27 Python
使用Python三角函数公式计算三角形的夹角案例
2020/04/15 Python
CSS3 input框的实现代码类似Google登录的动画效果
2020/08/04 HTML / CSS
Omio意大利:全欧洲低价大巴、火车和航班搜索和比价
2017/12/02 全球购物
送给他或她的礼物:FUN.com
2018/08/17 全球购物
求职信范文英文版
2014/01/05 职场文书
企业办公室岗位职责
2014/03/12 职场文书
《青山处处埋忠骨》教学反思
2014/04/22 职场文书
开展批评与自我批评发言材料
2014/05/15 职场文书
软件研发工程师岗位职责
2014/09/30 职场文书
2015年财务人员工作总结
2015/04/10 职场文书
教导处教学工作总结
2015/08/12 职场文书
2016年大学生就业指导课心得体会
2015/10/09 职场文书
幼儿园语言教学反思
2016/02/23 职场文书
mapstruct的用法之qualifiedByName示例详解
2022/04/06 Java/Android