keras导入weights方式


Posted in Python onJune 12, 2020

keras源码engine中toplogy.py定义了加载权重的函数:

load_weights(self, filepath, by_name=False)

其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16

VGG19/resnet50等,源码描述如下:

If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.

若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合用于改变了

模型的相关结构或增加了节点但利用了原网络的主体结构情况下使用,源码描述如下:

If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.

在进行边缘检测时,利用VGG网络的主体结构,网络中增加反卷积层,这时加载权重应该使用

model.load_weights(filepath,by_name=True)

补充知识:Keras下实现mnist手写数字

之前一直在用tensorflow,被同学推荐来用keras了,把之前文档中的mnist手写数字数据集拿来练手,

代码如下。

import struct
import numpy as np
import os
 
import keras
from keras.models import Sequential 
from keras.layers import Dense
from keras.optimizers import SGD
 
def load_mnist(path, kind):
  labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
  images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
  with open(labels_path, 'rb') as lbpath:
    magic, n = struct.unpack('>II', lbpath.read(8))
    labels = np.fromfile(lbpath, dtype=np.uint8)
  with open(images_path, 'rb') as imgpath:
    magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
    images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
  return images, labels
 
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
 
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
 
#define models
model = Sequential()
 
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) 
 
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
 
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
 
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
 
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] 
print('Training accuracy: %.2f%%' % (train_acc * 100))
 
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] 
print('Test accuracy: %.2f%%' % (test_acc * 100))

训练结果如下:

Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%

以上这篇keras导入weights方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python time模块详解(常用函数实例讲解,非常好)
Apr 24 Python
python实现根据月份和日期得到星座的方法
Mar 27 Python
总结网络IO模型与select模型的Python实例讲解
Jun 27 Python
PyQt5每天必学之工具提示功能
Apr 19 Python
Django后台获取前端post上传的文件方法
May 28 Python
对Python中数组的几种使用方法总结
Jun 28 Python
使用pandas read_table读取csv文件的方法
Jul 04 Python
详解python:time模块用法
Mar 25 Python
selenium跳过webdriver检测并模拟登录淘宝
Jun 12 Python
python nmap实现端口扫描器教程
May 28 Python
python库skimage给灰度图像染色的方法示例
Apr 27 Python
基于Python 函数和方法的区别说明
Mar 24 Python
keras读取h5文件load_weights、load代码操作
Jun 12 #Python
Python matplotlib 绘制双Y轴曲线图的示例代码
Jun 12 #Python
keras的siamese(孪生网络)实现案例
Jun 12 #Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
You might like
不用GD库生成当前时间的PNG格式图象的程序
2006/10/09 PHP
php正则取img标记中任意属性(正则替换去掉或改变图片img标记中的任意属性)
2013/08/13 PHP
CI(CodeIgniter)框架介绍
2014/06/09 PHP
PHP的PDO事务与自动提交
2019/01/24 PHP
PHP匿名函数(闭包函数)详解
2019/03/22 PHP
你必须知道的JavaScript 中字符串连接的性能的一些问题
2013/05/07 Javascript
讨论html与javascript在浏览器中的加载顺序问题
2013/11/27 Javascript
js生成随机数之random函数随机示例
2013/12/20 Javascript
常用的jquery模板插件——jQuery Boilerplate介绍
2014/09/23 Javascript
nodejs中的fiber(纤程)库详解
2015/03/24 NodeJs
JS实现点击按钮控制Div变宽、增高及调整背景色的方法
2015/08/05 Javascript
PHP实现记录代码运行时间封装类实例教程
2017/05/08 Javascript
vue中阻止click事件冒泡,防止触发另一个事件的方法
2018/02/08 Javascript
webpack4+express+mongodb+vue实现增删改查的示例
2018/11/08 Javascript
详解vue-cli 2.0配置文件(小结)
2019/01/14 Javascript
新手快速上手webpack4打包工具的使用详解
2019/01/28 Javascript
jQuery Ajax async=>false异步改为同步时,解决导致浏览器假死的问题
2019/07/22 jQuery
小程序实现按下录音松开识别语音
2019/11/22 Javascript
vuex(vue状态管理)的特殊应用案例分享
2020/03/03 Javascript
微信小程序pinker组件使用实现自动相减日期
2020/05/07 Javascript
[52:29]DOTA2上海特级锦标赛主赛事日 - 2 胜者组第一轮#3Secret VS OG第三局
2016/03/03 DOTA
Python实现基于HTTP文件传输实例
2014/11/08 Python
举例讲解Python的Tornado框架实现数据可视化的教程
2015/05/02 Python
利用python模拟sql语句对员工表格进行增删改查
2017/07/05 Python
利用python打开摄像头及颜色检测方法
2018/08/03 Python
对python中类的继承与方法重写介绍
2019/01/20 Python
解决Pytorch训练过程中loss不下降的问题
2020/01/02 Python
python 判断一组数据是否符合正态分布
2020/09/23 Python
党员承诺书怎么写
2014/05/20 职场文书
校园演讲稿汇总
2014/05/21 职场文书
乡镇党委书记第三阶段个人整改措施
2014/09/16 职场文书
2015年大学生党员承诺书
2015/04/27 职场文书
企业廉洁教育心得体会
2016/01/20 职场文书
2016年社会管理综治宣传月活动总结
2016/03/16 职场文书
创业计划书之书店
2019/09/10 职场文书
Redis实现订单过期删除的方法步骤
2022/06/05 Redis