keras导入weights方式


Posted in Python onJune 12, 2020

keras源码engine中toplogy.py定义了加载权重的函数:

load_weights(self, filepath, by_name=False)

其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16

VGG19/resnet50等,源码描述如下:

If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.

若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合用于改变了

模型的相关结构或增加了节点但利用了原网络的主体结构情况下使用,源码描述如下:

If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.

在进行边缘检测时,利用VGG网络的主体结构,网络中增加反卷积层,这时加载权重应该使用

model.load_weights(filepath,by_name=True)

补充知识:Keras下实现mnist手写数字

之前一直在用tensorflow,被同学推荐来用keras了,把之前文档中的mnist手写数字数据集拿来练手,

代码如下。

import struct
import numpy as np
import os
 
import keras
from keras.models import Sequential 
from keras.layers import Dense
from keras.optimizers import SGD
 
def load_mnist(path, kind):
  labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
  images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
  with open(labels_path, 'rb') as lbpath:
    magic, n = struct.unpack('>II', lbpath.read(8))
    labels = np.fromfile(lbpath, dtype=np.uint8)
  with open(images_path, 'rb') as imgpath:
    magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
    images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
  return images, labels
 
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
 
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
 
#define models
model = Sequential()
 
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) 
 
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
 
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
 
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
 
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] 
print('Training accuracy: %.2f%%' % (train_acc * 100))
 
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] 
print('Test accuracy: %.2f%%' % (test_acc * 100))

训练结果如下:

Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%

以上这篇keras导入weights方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python批量修改文件后缀示例代码分享
Dec 24 Python
python实现井字棋游戏
Mar 30 Python
python实现实时监控文件的方法
Aug 26 Python
Python只用40行代码编写的计算器实例
May 10 Python
Python爬虫实例_城市公交网络站点数据的爬取方法
Jan 10 Python
python绘图模块matplotlib示例详解
Jul 26 Python
三个python爬虫项目实例代码
Dec 28 Python
python中for in的用法详解
Apr 17 Python
Python基于DB-API操作MySQL数据库过程解析
Apr 23 Python
Python爬虫爬取有道实现翻译功能
Nov 27 Python
Python如何使用logging为Flask增加logid
Mar 30 Python
Python实现制作销售数据可视化看板详解
Nov 27 Python
keras读取h5文件load_weights、load代码操作
Jun 12 #Python
Python matplotlib 绘制双Y轴曲线图的示例代码
Jun 12 #Python
keras的siamese(孪生网络)实现案例
Jun 12 #Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
You might like
YB217、YB235、YB400浅听
2021/03/02 无线电
用文本文件制作留言板提示(下)
2006/10/09 PHP
PHP获取当前日期所在星期(月份)的开始日期与结束日期(实现代码)
2013/06/18 PHP
限制ckeditor上传图片文件大小的方法
2013/11/15 PHP
php内核解析:PHP中的哈希表
2014/01/30 PHP
详解WordPress开发中wp_title()函数的用法
2016/01/07 PHP
yii2控制器Controller Ajax操作示例
2016/07/23 PHP
简单的JS多重继承示例
2008/03/13 Javascript
详解AngularJS中的依赖注入机制
2015/06/17 Javascript
Jquery中使用show()与hide()方法动画显示和隐藏图片
2015/10/08 Javascript
3种js实现string的substring方法
2015/11/09 Javascript
input file上传 图片预览功能实例代码
2016/10/25 Javascript
jquery ajaxfileupload异步上传插件使用详解
2017/02/08 Javascript
jQuery插件HighCharts绘制的基本折线图效果示例【附demo源码下载】
2017/03/07 Javascript
js中的闭包实例展示
2018/11/01 Javascript
简单两步使用node发送qq邮件的方法
2019/03/01 Javascript
微信公众平台获取access_token的方法步骤
2019/03/29 Javascript
[02:19]DOTA2女子战队FOX视频专访:希望更多美眉一起加入
2013/10/15 DOTA
python网络编程之UDP通信实例(含服务器端、客户端、UDP广播例子)
2014/04/25 Python
Python列表(list)常用操作方法小结
2015/02/02 Python
使用SAE部署Python运行环境的教程
2015/05/05 Python
Python 字典与字符串的互转实例
2017/01/13 Python
python中pandas.DataFrame对行与列求和及添加新行与列示例
2017/03/12 Python
python实现公司年会抽奖程序
2019/01/22 Python
详解Python3中ceil()函数用法
2019/02/19 Python
Django之form组件自动校验数据实现
2020/01/14 Python
Python调用jar包方法实现过程解析
2020/08/11 Python
I.T中国官网:精选时尚设计师单品网购平台
2018/03/26 全球购物
人力资源部门的主要职能
2014/02/22 职场文书
《爱如茉莉》教后反思
2014/04/12 职场文书
应届毕业生求职信
2014/05/26 职场文书
教代会闭幕词
2015/01/28 职场文书
培训感想范文
2015/08/07 职场文书
小学中队委竞选稿
2015/11/20 职场文书
Python实现的扫码工具居然这么好用!
2021/06/07 Python
详解Go语言中配置文件使用与日志配置
2022/06/01 Golang