keras导入weights方式


Posted in Python onJune 12, 2020

keras源码engine中toplogy.py定义了加载权重的函数:

load_weights(self, filepath, by_name=False)

其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16

VGG19/resnet50等,源码描述如下:

If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.

若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合用于改变了

模型的相关结构或增加了节点但利用了原网络的主体结构情况下使用,源码描述如下:

If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.

在进行边缘检测时,利用VGG网络的主体结构,网络中增加反卷积层,这时加载权重应该使用

model.load_weights(filepath,by_name=True)

补充知识:Keras下实现mnist手写数字

之前一直在用tensorflow,被同学推荐来用keras了,把之前文档中的mnist手写数字数据集拿来练手,

代码如下。

import struct
import numpy as np
import os
 
import keras
from keras.models import Sequential 
from keras.layers import Dense
from keras.optimizers import SGD
 
def load_mnist(path, kind):
  labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
  images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
  with open(labels_path, 'rb') as lbpath:
    magic, n = struct.unpack('>II', lbpath.read(8))
    labels = np.fromfile(lbpath, dtype=np.uint8)
  with open(images_path, 'rb') as imgpath:
    magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
    images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
  return images, labels
 
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
 
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
 
#define models
model = Sequential()
 
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) 
 
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
 
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
 
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
 
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] 
print('Training accuracy: %.2f%%' % (train_acc * 100))
 
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] 
print('Test accuracy: %.2f%%' % (test_acc * 100))

训练结果如下:

Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%

以上这篇keras导入weights方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用 Selenium 实现网页截图实例
Jul 18 Python
Python中函数的用法实例教程
Sep 08 Python
零基础写python爬虫之爬虫框架Scrapy安装配置
Nov 06 Python
在Django的上下文中设置变量的方法
Jul 20 Python
Python模块包中__init__.py文件功能分析
Jun 14 Python
Python实现简单的HttpServer服务器示例
Sep 25 Python
浅谈Python使用Bottle来提供一个简单的web服务
Dec 27 Python
Python通过属性手段实现只允许调用一次的示例讲解
Apr 21 Python
Python爬虫使用脚本登录Github并查看信息
Jul 16 Python
django利用request id便于定位及给日志加上request_id
Aug 26 Python
Win10下用Anaconda安装TensorFlow(图文教程)
Jun 18 Python
Python自动发送和收取邮件的方法
Aug 12 Python
keras读取h5文件load_weights、load代码操作
Jun 12 #Python
Python matplotlib 绘制双Y轴曲线图的示例代码
Jun 12 #Python
keras的siamese(孪生网络)实现案例
Jun 12 #Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
You might like
配置PHP使之能同时支持GIF和JPEG
2006/10/09 PHP
PHP FOR MYSQL 代码生成助手(根据Mysql里的字段自动生成类文件的)
2011/07/23 PHP
php另类上传图片的方法(PHP用Socket上传图片)
2013/10/30 PHP
单台服务器的PHP进程之间实现共享内存的方法
2014/06/13 PHP
php中数字、字符与对象判断函数用法实例
2014/11/26 PHP
微信小程序 消息推送php服务器验证实例详解
2017/03/30 PHP
CSS3画一个阴阳八卦图
2021/03/09 HTML / CSS
初学JavaScript第二章
2008/09/30 Javascript
jQuery autocomplete插件修改
2009/04/17 Javascript
jquery text()要注意啦
2009/10/30 Javascript
JS获得QQ号码的昵称,头像,生日的简单实例
2013/12/04 Javascript
jQueryMobile之Helloworld与页面切换的方法
2015/02/04 Javascript
第一次接触JS require.js模块化工具
2016/04/17 Javascript
快速实现JS图片懒加载(可视区域加载)示例代码
2017/01/04 Javascript
为vue-router懒加载时下载js的过程中添加loading提示避免无响应问题
2018/04/03 Javascript
vue树形结构获取键值的方法示例
2018/06/21 Javascript
详解使用React制作一个模态框
2019/03/14 Javascript
基于Vue插入视频的2种方法小结
2019/04/02 Javascript
js 下拉菜单点击旁边收起实现(踩坑记)
2019/09/29 Javascript
ant design vue 表格table 默认勾选几项的操作
2020/10/31 Javascript
element-ui点击查看大图的方法示例
2020/12/14 Javascript
python实现360皮肤按钮控件示例
2014/02/21 Python
python教程之用py2exe将PY文件转成EXE文件
2014/06/12 Python
python 寻找优化使成本函数最小的最优解的方法
2017/12/28 Python
Python 转换RGB颜色值的示例代码
2019/10/13 Python
Python利用多线程同步锁实现多窗口订票系统(推荐)
2019/12/22 Python
canvas简单连线动画的实现代码
2020/02/04 HTML / CSS
澳大利亚玩具剧场:Toy Playhouse
2019/03/03 全球购物
值类型与引用类型有什么不同?请举例说明?并分别列举几种相应的数据类型
2015/10/24 面试题
《泉水》教学反思
2014/04/11 职场文书
群众路线剖析材料怎么写
2014/10/09 职场文书
实习班主任自我评价
2015/03/11 职场文书
趣味运动会新闻稿
2015/07/17 职场文书
六五普法先进个人主要事迹材料
2015/11/03 职场文书
浅谈tf.train.Saver()与tf.train.import_meta_graph的要点
2021/05/26 Python
Mysql InnoDB 的内存逻辑架构
2022/05/06 MySQL