keras导入weights方式


Posted in Python onJune 12, 2020

keras源码engine中toplogy.py定义了加载权重的函数:

load_weights(self, filepath, by_name=False)

其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16

VGG19/resnet50等,源码描述如下:

If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.

若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合用于改变了

模型的相关结构或增加了节点但利用了原网络的主体结构情况下使用,源码描述如下:

If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.

在进行边缘检测时,利用VGG网络的主体结构,网络中增加反卷积层,这时加载权重应该使用

model.load_weights(filepath,by_name=True)

补充知识:Keras下实现mnist手写数字

之前一直在用tensorflow,被同学推荐来用keras了,把之前文档中的mnist手写数字数据集拿来练手,

代码如下。

import struct
import numpy as np
import os
 
import keras
from keras.models import Sequential 
from keras.layers import Dense
from keras.optimizers import SGD
 
def load_mnist(path, kind):
  labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
  images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
  with open(labels_path, 'rb') as lbpath:
    magic, n = struct.unpack('>II', lbpath.read(8))
    labels = np.fromfile(lbpath, dtype=np.uint8)
  with open(images_path, 'rb') as imgpath:
    magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
    images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
  return images, labels
 
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
 
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
 
#define models
model = Sequential()
 
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) 
 
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
 
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
 
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
 
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] 
print('Training accuracy: %.2f%%' % (train_acc * 100))
 
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] 
print('Test accuracy: %.2f%%' % (test_acc * 100))

训练结果如下:

Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%

以上这篇keras导入weights方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3序列化与反序列化用法实例
May 26 Python
利用Python实现Windows定时关机功能
Mar 21 Python
Python编程django实现同一个ip十分钟内只能注册一次
Nov 03 Python
python导出hive数据表的schema实例代码
Jan 22 Python
对python中矩阵相加函数sum()的使用详解
Jan 28 Python
python实现微信机器人: 登录微信、消息接收、自动回复功能
Apr 29 Python
python3.6 如何将list存入txt后再读出list的方法
Jul 02 Python
pytorch 图像预处理之减去均值,除以方差的实例
Jan 02 Python
np.random.seed() 的使用详解
Jan 14 Python
Matplotlib中rcParams使用方法
Jan 05 Python
详解Python中的Lock和Rlock
Jan 26 Python
python抢购软件/插件/脚本附完整源码
Mar 04 Python
keras读取h5文件load_weights、load代码操作
Jun 12 #Python
Python matplotlib 绘制双Y轴曲线图的示例代码
Jun 12 #Python
keras的siamese(孪生网络)实现案例
Jun 12 #Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
You might like
PHP截取中文字符串的问题
2006/07/12 PHP
php下过滤HTML代码的函数
2007/12/10 PHP
讲解WordPress中用于获取评论模板和搜索表单的PHP函数
2015/12/28 PHP
Mootools 1.2教程 设置和获取样式表属性
2009/09/15 Javascript
js封装的textarea操作方法集合(兼容很好)
2010/11/16 Javascript
JavaScript Title、alt提示(Tips)实现源码解读
2010/12/12 Javascript
Javascript的getYear、getFullYear、getUTCFullYear异同分享
2011/11/30 Javascript
游览器中javascript的执行过程(图文)
2012/05/20 Javascript
JavaScript基础知识之数据类型
2012/08/06 Javascript
用jQuery模拟select下拉框的简单示例代码
2014/01/26 Javascript
js实现多选项切换导航菜单的方法
2015/02/06 Javascript
js如何实现点击标签文字,文字在文本框出现
2015/08/05 Javascript
jQuery使用ajax方法解析返回的json数据功能示例
2017/01/10 Javascript
jQuery操作DOM_动力节点Java学院整理
2017/07/04 jQuery
利用JS制作万年历的方法
2017/08/16 Javascript
vue中实现图片和文件上传的示例代码
2018/03/16 Javascript
React中的render何时执行过程
2018/04/13 Javascript
一个简单的node.js界面实现方法
2018/06/01 Javascript
小程序Scroll-view上拉滚动刷新数据
2020/06/21 Javascript
Python中的元类编程入门指引
2015/04/15 Python
Python面向对象编程基础解析(一)
2017/10/26 Python
Django框架实现的普通登录案例【使用POST方法】
2019/05/15 Python
python向企业微信发送文字和图片消息的示例
2020/09/28 Python
常用的四种CSS透明属性介绍
2014/04/12 HTML / CSS
使用 css3 实现圆形进度条的示例
2017/07/05 HTML / CSS
Html5页面获取微信公众号的openid的方法
2020/05/12 HTML / CSS
英国皇家造币厂:The Royal Mint
2018/10/05 全球购物
行政人员工作职责
2013/12/05 职场文书
仓库文员岗位职责
2014/04/06 职场文书
中学生期中自我鉴定
2014/04/20 职场文书
班级学习计划书
2014/04/27 职场文书
学雷锋活动总结报告
2014/06/26 职场文书
房地产资料员岗位职责
2014/07/02 职场文书
机关作风整顿个人剖析材料
2014/10/06 职场文书
金融专业银行实习证明模板
2014/11/28 职场文书
正规借条模板
2015/05/26 职场文书