keras读取训练好的模型参数并把参数赋值给其它模型详解


Posted in Python onJune 15, 2020

介绍

本博文中的代码,实现的是加载训练好的模型model_halcon_resenet.h5,并把该模型的参数赋值给两个不同的新的model。

函数式模型

官网上给出的调用一个训练好模型,并输出任意层的feature。

model = Model(inputs=base_model.input, outputs=base_model.get_layer(‘block4_pool').output)

但是这有一个问题,就是新的model,如果输入inputs和训练好的model的inputs大小不同呢?比如我想建立一个输入是600x600x3的新model,但是训练好的model输入是200x200x3,而这时我又想调用训练好模型的卷积核参数,这时该怎么办呢?

其实想一下,用训练好的模型参数,即使输入的尺寸不同,但是这些模型参数仍然可以处理计算,只是输出的feature map大小不同。那到底怎么赋值呢?其实很简单

在定义新的model时,新的model层在定义时,需要加上名字,而这个名字就是训练好的模型的每层名字。如下代码所示:

inputs=Input(shape=(400,500,3))
X=Conv2D(32, (3, 3),name=“conv2d_1”)(inputs)
X=BatchNormalization(name=“batch_normalization_1”)(X)
X=Activation(‘relu',name=“activation_1”)(X)

最后通过以下代码即可建立一个新的模型并拥有训练好模型的参数:

model=Model(inputs=inputs, outputs=X)
model.load_weights(‘model_halcon_resenet.h5', by_name=True)

源代码

from keras.models import load_model
from keras.preprocessing import image
from keras.applications.vgg19 import preprocess_input
from keras.models import Model
import numpy as np
from keras.layers import Conv2D, MaxPooling2D,merge
from keras.layers import BatchNormalization,Activation
from keras.layers import Input, Dense
from PIL import Image
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten,Input
from keras.layers import Conv2D, MaxPooling2D,merge,AveragePooling2D,GlobalAveragePooling2D
from keras.layers import BatchNormalization,Activation
from sklearn.model_selection import train_test_split
from keras.applications.densenet import DenseNet169, DenseNet121
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.inception_v3 import InceptionV3
from keras.optimizers import SGD
from keras import regularizers
from keras.models import Model
import tensorflow as tf
from PIL import Image
from keras.callbacks import TensorBoard
import os
import cv2
from keras import backend as K
from model import focal_loss
import keras.losses

#ReadMe 该代码是参考fast rcnn系列,先对整幅图像提取特征feature map,然后从原图对应位置上映射到feature map,并对feature map进行
# 切片,从而提取对应某个位置上的特征,并把该特征送进后面的识别网络进行分类识别。
keras.losses.focal_loss = focal_loss#这句代码是为了引入定义的loss
base_model=load_model('model_halcon_resenet.h5')
base_model.summary()

inputs=Input(shape=(400,500,3))
X=Conv2D(32, (3, 3),name="conv2d_1")(inputs)
X=BatchNormalization(name="batch_normalization_1")(X)
X=Activation('relu',name="activation_1")(X)
#第一个残差模块
X_1=Conv2D(32, (3, 3),padding='same',name="conv2d_2")(X)
X_1=BatchNormalization(name="batch_normalization_2")(X_1)
X_1= Activation('relu',name="activation_2")(X_1)
X_1 = Conv2D(32, (3, 3),padding='same',name="conv2d_3")(X_1)
X_1 = BatchNormalization(name="batch_normalization_3")(X_1)
merge_data = merge([X_1, X], mode='sum',name="merge_1")
X = Activation('relu',name="activation_3")(merge_data)
#第一个残差模块结束
X=MaxPooling2D(pool_size=(2, 2),strides=(2,2),name="max_pooling2d_1")(X)
X=Conv2D(64, (3, 3),kernel_regularizer=regularizers.l2(0.01),name="conv2d_4")(X)
X=BatchNormalization(name="batch_normalization_4")(X)
X=Activation('relu',name="activation_4")(X)
#第二个残差模块
X_2=Conv2D(64, (3, 3),padding='same',name="conv2d_5")(X)
X_2=BatchNormalization(name="batch_normalization_5")(X_2)
X_2= Activation('relu',name="activation_5")(X_2)
X_2 = Conv2D(64, (3, 3),padding='same',name="conv2d_6")(X_2)
X_2 = BatchNormalization(name="batch_normalization_6")(X_2)
merge_data = merge([X_2, X], mode='sum',name="merge_2")
X = Activation('relu',name="activation_6")(merge_data)
#第二个残差模块结束
X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2),name="max_pooling2d_2")(X)
X=Conv2D(64, (3, 3),name="conv2d_7")(X)
X=BatchNormalization(name="batch_normalization_7")(X)
X=Activation('relu',name="activation_7")(X)
X=MaxPooling2D(pool_size=(2, 2),strides=(2,2),name="max_pooling2d_3")(X)
#第三个残差模块开始
X_3=Conv2D(64, (3, 3),padding='same',name="conv2d_8")(X)
X_3=BatchNormalization(name="batch_normalization_8")(X_3)
X_3= Activation('relu',name="activation_8")(X_3)
X_3 = Conv2D(64, (3, 3),padding='same',name="conv2d_9")(X_3)
X_3 = BatchNormalization(name="batch_normalization_9")(X_3)
merge_data = merge([X_3, X], mode='sum',name="merge_3")
X = Activation('relu',name="activation_9")(merge_data)
#第三个残差模块结束
X=Conv2D(32, (3, 3),kernel_regularizer=regularizers.l2(0.01),name="conv2d_10")(X)
X=BatchNormalization(name="batch_normalization_10")(X)
X=Activation('relu',name="activation_10")(X)
#第四个残差模块开始
X_4=Conv2D(32, (3, 3),padding='same',name="conv2d_11")(X)
X_4=BatchNormalization(name="batch_normalization_11")(X_4)
X_4= Activation('relu',name="activation_11")(X_4)
X_4 = Conv2D(32, (3, 3),padding='same',name="conv2d_12")(X_4)
X_4 = BatchNormalization(name="batch_normalization_12")(X_4)
merge_data = merge([X_4, X], mode='sum',name="merge_4")
X = Activation('relu',name="activation_12")(merge_data)
#第四个残差模块结束
X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2),name="max_pooling2d_4")(X)
X = Conv2D(64, (3, 3),name="conv2d_13")(X)
X = BatchNormalization(name="batch_normalization_13")(X)
X = Activation('relu',name="activation_13")(X)
#第五个残差模块开始
X_5=Conv2D(64, (3, 3),padding='same',name="conv2d_14")(X)
X_5=BatchNormalization(name="batch_normalization_14")(X_5)
X_5= Activation('relu',name="activation_14")(X_5)
X_5 = Conv2D(64, (3, 3),padding='same',name="conv2d_15")(X_5)
X_5 = BatchNormalization(name="batch_normalization_15")(X_5)
merge_data = merge([X_5, X], mode='sum',name="merge_5")
X = Activation('relu',name="activation_15")(merge_data)
#第五个残差模块结束
model=Model(inputs=inputs, outputs=X)
model.load_weights('model_halcon_resenet.h5', by_name=True)
#读取指定图像数据
image_dir='C:/Users/18301/Desktop/blister/new/blister_mixed_11.png'
img = image.load_img(image_dir, target_size=(400, 500))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
#利用第一个模型预测出特征数据,并对特征数据进行切片
feature_map=model.predict(x)
T=np.array(feature_map)
f_1=T[:,16:21,0:10,:]
print(f_1.shape)
print(feature_map.shape)
#第一个模型没有问题
#定义第二个模型
inputs_sec=Input(shape=(1,5,10,64))
X_= Flatten(name="flatten_1")(inputs_sec)
X_ = Dense(256, activation='relu',name="dense_1")(X_)
X_ = Dropout(0.5,name="dropout_1")(X_)
predictions = Dense(6, activation='softmax',name="dense_2")(X_)
model_sec=Model(inputs=inputs_sec, outputs=predictions)
model_sec.load_weights('model_halcon_resenet.h5', by_name=True)
#第二个模型定义结束
model_sec.summary()
#开始对整幅图像进行切片,并记录坐标位置
pic=cv2.imread(image_dir)
cor_list=[]
name_list=['blank','green_blank','red_blank','yellow','yellow_balnk','yellow_blue']
font = cv2.FONT_HERSHEY_SIMPLEX
for i in range(3):
 for j in range(5):
 if(i==2):
  cut_feature = T[:, 4 * j:4 * j + 5, 17:27, :]
  data = np.expand_dims(cut_feature, axis=0)
  result = model_sec.predict(data)
  print(result)
  result_data=result[0].tolist()
  #如果置信度过低,则舍弃
  # if(max(result_data)<=0.7):
  # continue
  index_num = result_data.index(max(result_data))
  name=name_list[index_num]
  cor_list = [i * 160 + 6, j * 80] # 每个切片数据,映射到原图上,检测框对应的左上角坐标
  x=cor_list[0]
  y=cor_list[1]
  cv2.rectangle(pic, (160 * i + 6, 80 * j), ((i + 1) * 160 + 6, 80 * (j+ 1)), (0, 255, 0), 2)
  cv2.putText(pic, name, (x + 40, y + 40), font, 0.5, (0, 0, 255), 1)
 else:
  cut_feature = T[:, 4 * j:4 * j + 5, 9 * i:9 * i + 10, :]
  data = np.expand_dims(cut_feature, axis=0)
  result = model_sec.predict(data)
  print(result)
  result_data = result[0].tolist()
  #如果置信度过低,则舍弃
  # if (max(result_data) <= 0.7):
  # continue
  index_num = result_data.index(max(result_data))
  name = name_list[index_num]
  cor_list = [i * 160 + 6, j * 80] # 每个切片数据,映射到原图上,检测框对应的左上角坐标
  x = cor_list[0]
  y = cor_list[1]
  cv2.rectangle(pic, (160 * i + 6, 80 * j), ((i + 1) * 160 + 6, 80 * (j + 1)), (0, 255, 0), 2)
  cv2.putText(pic, name, (x + 40, y + 40), font, 0.5, (0, 0, 255), 1)

cv2.imshow('pic',pic)
cv2.waitKey(0)
cv2.destroyAllWindows()
# data= np.expand_dims(f_1, axis=0)
# result=model_sec.predict(data)
# print(result)
#第二个模型可以完全预测,没有问题

补充知识:加载训练好的模型参数,但是权重一直变化

keras读取训练好的模型参数并把参数赋值给其它模型详解

变量初始化会导致权重发生变化,去掉就好了。

以上这篇keras读取训练好的模型参数并把参数赋值给其它模型详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python脚本实现下载合并SAE日志
Feb 10 Python
Python 登录网站详解及实例
Apr 11 Python
python安装教程 Pycharm安装详细教程
May 02 Python
Python3生成手写体数字方法
Jan 30 Python
Python3.5 Pandas模块之DataFrame用法实例分析
Apr 23 Python
Django操作session 的方法
Mar 09 Python
django model的update时auto_now不被更新的原因及解决方式
Apr 01 Python
Python装饰器的应用场景代码总结
Apr 10 Python
在TensorFlow中实现矩阵维度扩展
May 22 Python
Python爬虫谷歌Chrome F12抓包过程原理解析
Jun 04 Python
基于python实现删除指定文件类型
Jul 21 Python
Pytorch 使用tensor特定条件判断索引
Apr 08 Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
简单了解Python多态与属性运行原理
Jun 15 #Python
Python类super()及私有属性原理解析
Jun 15 #Python
Keras 实现加载预训练模型并冻结网络的层
Jun 15 #Python
Python StringIO及BytesIO包使用方法解析
Jun 15 #Python
You might like
php中运用http调用的GET和POST方法示例
2014/09/29 PHP
PHP实现适用于自定义的验证码类
2016/06/15 PHP
javascript实现时间格式输出FormatDate函数
2015/01/13 Javascript
javascript中的Base64、UTF8编码与解码详解
2015/03/18 Javascript
js获取url传值的方法
2015/12/18 Javascript
jQuery hover事件简单实现同时绑定2个方法
2016/06/07 Javascript
自动化测试读写64位操作系统的注册表
2016/08/15 Javascript
JS实现PC手机端和嵌入式滑动拼图验证码三种效果
2017/02/15 Javascript
ES6新特性之函数的扩展实例详解
2017/04/01 Javascript
微信小程序教程系列之视图层的条件渲染(10)
2017/04/19 Javascript
jquery实现图片跟随鼠标的实例
2017/10/17 jQuery
Node.js模块全局安装路径配置方法
2018/05/17 Javascript
VUE 3D轮播图封装实现方法
2018/07/03 Javascript
微信小程序云开发实现增删改查功能
2019/05/17 Javascript
VUE 实现动态给对象增加属性,并触发视图更新操作示例
2019/11/29 Javascript
Python help()函数用法详解
2014/03/11 Python
Python读写Excel文件方法介绍
2014/11/22 Python
python flask实现分页的示例代码
2018/08/02 Python
Pycharm+Scrapy安装并且初始化项目的方法
2019/01/15 Python
Pandas之DataFrame对象的列和索引之间的转化
2019/06/25 Python
Pyinstaller 打包exe教程及问题解决
2019/08/16 Python
wxPython实现文本框基础组件
2019/11/18 Python
CSS3 重置iphone浏览器按钮input,select等表单元素的默认样式
2014/10/11 HTML / CSS
如何让IE9以下版本(ie6/7/8)认识html5元素
2013/04/01 HTML / CSS
详解HTML5中的picture元素响应式处理图片
2018/01/03 HTML / CSS
英国玛莎百货澳大利亚:Marks & Spencer Australia
2019/08/30 全球购物
美国传奇滑手Paul Rodriguez创办的街头滑板品牌:Primitive Skateboarding
2019/10/29 全球购物
美国知名眼镜网站:Target Optical
2020/04/04 全球购物
芭比波朗加拿大官方网站:Bobbi Brown Cosmetics CA
2020/11/05 全球购物
法语专业求职信
2014/07/20 职场文书
采购员岗位职责范本
2015/04/07 职场文书
跑吧孩子观后感
2015/06/10 职场文书
2015年学校管理工作总结
2015/07/20 职场文书
大学生村官工作心得体会
2016/01/23 职场文书
2019朋友新婚祝福语精选
2019/10/10 职场文书
Nginx反向代理学习实例教程
2021/10/24 Servers