keras导入weights方式


Posted in Python onJune 12, 2020

keras源码engine中toplogy.py定义了加载权重的函数:

load_weights(self, filepath, by_name=False)

其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16

VGG19/resnet50等,源码描述如下:

If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.

若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合用于改变了

模型的相关结构或增加了节点但利用了原网络的主体结构情况下使用,源码描述如下:

If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.

在进行边缘检测时,利用VGG网络的主体结构,网络中增加反卷积层,这时加载权重应该使用

model.load_weights(filepath,by_name=True)

补充知识:Keras下实现mnist手写数字

之前一直在用tensorflow,被同学推荐来用keras了,把之前文档中的mnist手写数字数据集拿来练手,

代码如下。

import struct
import numpy as np
import os
 
import keras
from keras.models import Sequential 
from keras.layers import Dense
from keras.optimizers import SGD
 
def load_mnist(path, kind):
  labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
  images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
  with open(labels_path, 'rb') as lbpath:
    magic, n = struct.unpack('>II', lbpath.read(8))
    labels = np.fromfile(lbpath, dtype=np.uint8)
  with open(images_path, 'rb') as imgpath:
    magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
    images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
  return images, labels
 
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
 
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
 
#define models
model = Sequential()
 
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) 
 
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
 
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
 
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
 
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] 
print('Training accuracy: %.2f%%' % (train_acc * 100))
 
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] 
print('Test accuracy: %.2f%%' % (test_acc * 100))

训练结果如下:

Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%

以上这篇keras导入weights方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用any判断一个对象是否为空的方法
Nov 19 Python
Python实现Sqlite将字段当做索引进行查询的方法
Jul 21 Python
Python 12306抢火车票脚本
Feb 07 Python
numpy.delete删除一列或多列的方法
Apr 03 Python
Python3.4 splinter(模拟填写表单)使用方法
Oct 13 Python
详解python tkinter教程-事件绑定
Mar 28 Python
python 读取修改pcap包的例子
Jul 23 Python
python实现文件的分割与合并
Aug 29 Python
Python chardet库识别编码原理解析
Feb 18 Python
Python并发concurrent.futures和asyncio实例
May 04 Python
keras 回调函数Callbacks 断点ModelCheckpoint教程
Jun 18 Python
Python调用系统命令os.system()和os.popen()的实现
Dec 31 Python
keras读取h5文件load_weights、load代码操作
Jun 12 #Python
Python matplotlib 绘制双Y轴曲线图的示例代码
Jun 12 #Python
keras的siamese(孪生网络)实现案例
Jun 12 #Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
You might like
虚拟主机中对PHP的特殊设置
2006/10/09 PHP
php之CodeIgniter学习笔记
2013/06/17 PHP
thinkphp模板用法和内容输出实例
2014/11/28 PHP
php创建图像具体步骤
2017/03/13 PHP
实例讲解PHP表单
2020/06/10 PHP
PHP扩展安装方法步骤解析
2020/11/24 PHP
js wmp操作代码小结(音乐连播功能)
2008/11/08 Javascript
简洁短小的 JavaScript IE 浏览器判定代码
2010/03/21 Javascript
JavaScript Eval 函数使用
2010/03/23 Javascript
JS仿flash上传头像效果实现代码
2011/07/18 Javascript
同时使用n个window onload加载实例介绍
2013/04/25 Javascript
js实现完美兼容各大浏览器的人民币大小写相互转换
2015/10/29 Javascript
js实现百度搜索提示框
2017/02/05 Javascript
BootStrap Table 后台数据绑定、特殊列处理、排序功能
2017/05/27 Javascript
Angular中ng-options下拉数据默认值的设定方法
2017/06/21 Javascript
详解vue+css3做交互特效的方法
2017/11/20 Javascript
垃圾回收器的相关知识点总结
2018/05/13 Javascript
jQuery 点击获取验证码按钮及倒计时功能
2018/09/20 jQuery
python MysqlDb模块安装及其使用详解
2018/02/23 Python
利用python-docx模块写批量生日邀请函
2019/08/26 Python
使用Tkinter制作信息提示框
2020/02/18 Python
pycharm如何实现跨目录调用文件
2020/02/28 Python
彻底搞懂 python 中文乱码问题(深入分析)
2020/02/28 Python
python中np是做什么的
2020/07/21 Python
一款利用css3的鼠标经过动画显示详情特效的实例教程
2014/12/29 HTML / CSS
俄罗斯电子产品在线商店:UltraTrade
2020/01/30 全球购物
问卷调查计划书
2014/01/10 职场文书
英语简历自我评价
2014/01/26 职场文书
学校欢迎标语
2014/06/18 职场文书
2014迎接教师节演讲稿
2014/09/10 职场文书
初中生300字旷课检讨书
2014/11/19 职场文书
2014年汽车销售工作总结
2014/12/01 职场文书
奖励通知
2015/04/22 职场文书
暑期辅导班宣传单
2015/07/14 职场文书
2016五四青年节活动总结范文
2016/04/06 职场文书
你对自己的信用报告有过了解吗?
2019/07/09 职场文书