keras读取训练好的模型参数并把参数赋值给其它模型详解


Posted in Python onJune 15, 2020

介绍

本博文中的代码,实现的是加载训练好的模型model_halcon_resenet.h5,并把该模型的参数赋值给两个不同的新的model。

函数式模型

官网上给出的调用一个训练好模型,并输出任意层的feature。

model = Model(inputs=base_model.input, outputs=base_model.get_layer(‘block4_pool').output)

但是这有一个问题,就是新的model,如果输入inputs和训练好的model的inputs大小不同呢?比如我想建立一个输入是600x600x3的新model,但是训练好的model输入是200x200x3,而这时我又想调用训练好模型的卷积核参数,这时该怎么办呢?

其实想一下,用训练好的模型参数,即使输入的尺寸不同,但是这些模型参数仍然可以处理计算,只是输出的feature map大小不同。那到底怎么赋值呢?其实很简单

在定义新的model时,新的model层在定义时,需要加上名字,而这个名字就是训练好的模型的每层名字。如下代码所示:

inputs=Input(shape=(400,500,3))
X=Conv2D(32, (3, 3),name=“conv2d_1”)(inputs)
X=BatchNormalization(name=“batch_normalization_1”)(X)
X=Activation(‘relu',name=“activation_1”)(X)

最后通过以下代码即可建立一个新的模型并拥有训练好模型的参数:

model=Model(inputs=inputs, outputs=X)
model.load_weights(‘model_halcon_resenet.h5', by_name=True)

源代码

from keras.models import load_model
from keras.preprocessing import image
from keras.applications.vgg19 import preprocess_input
from keras.models import Model
import numpy as np
from keras.layers import Conv2D, MaxPooling2D,merge
from keras.layers import BatchNormalization,Activation
from keras.layers import Input, Dense
from PIL import Image
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten,Input
from keras.layers import Conv2D, MaxPooling2D,merge,AveragePooling2D,GlobalAveragePooling2D
from keras.layers import BatchNormalization,Activation
from sklearn.model_selection import train_test_split
from keras.applications.densenet import DenseNet169, DenseNet121
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.inception_v3 import InceptionV3
from keras.optimizers import SGD
from keras import regularizers
from keras.models import Model
import tensorflow as tf
from PIL import Image
from keras.callbacks import TensorBoard
import os
import cv2
from keras import backend as K
from model import focal_loss
import keras.losses

#ReadMe 该代码是参考fast rcnn系列,先对整幅图像提取特征feature map,然后从原图对应位置上映射到feature map,并对feature map进行
# 切片,从而提取对应某个位置上的特征,并把该特征送进后面的识别网络进行分类识别。
keras.losses.focal_loss = focal_loss#这句代码是为了引入定义的loss
base_model=load_model('model_halcon_resenet.h5')
base_model.summary()

inputs=Input(shape=(400,500,3))
X=Conv2D(32, (3, 3),name="conv2d_1")(inputs)
X=BatchNormalization(name="batch_normalization_1")(X)
X=Activation('relu',name="activation_1")(X)
#第一个残差模块
X_1=Conv2D(32, (3, 3),padding='same',name="conv2d_2")(X)
X_1=BatchNormalization(name="batch_normalization_2")(X_1)
X_1= Activation('relu',name="activation_2")(X_1)
X_1 = Conv2D(32, (3, 3),padding='same',name="conv2d_3")(X_1)
X_1 = BatchNormalization(name="batch_normalization_3")(X_1)
merge_data = merge([X_1, X], mode='sum',name="merge_1")
X = Activation('relu',name="activation_3")(merge_data)
#第一个残差模块结束
X=MaxPooling2D(pool_size=(2, 2),strides=(2,2),name="max_pooling2d_1")(X)
X=Conv2D(64, (3, 3),kernel_regularizer=regularizers.l2(0.01),name="conv2d_4")(X)
X=BatchNormalization(name="batch_normalization_4")(X)
X=Activation('relu',name="activation_4")(X)
#第二个残差模块
X_2=Conv2D(64, (3, 3),padding='same',name="conv2d_5")(X)
X_2=BatchNormalization(name="batch_normalization_5")(X_2)
X_2= Activation('relu',name="activation_5")(X_2)
X_2 = Conv2D(64, (3, 3),padding='same',name="conv2d_6")(X_2)
X_2 = BatchNormalization(name="batch_normalization_6")(X_2)
merge_data = merge([X_2, X], mode='sum',name="merge_2")
X = Activation('relu',name="activation_6")(merge_data)
#第二个残差模块结束
X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2),name="max_pooling2d_2")(X)
X=Conv2D(64, (3, 3),name="conv2d_7")(X)
X=BatchNormalization(name="batch_normalization_7")(X)
X=Activation('relu',name="activation_7")(X)
X=MaxPooling2D(pool_size=(2, 2),strides=(2,2),name="max_pooling2d_3")(X)
#第三个残差模块开始
X_3=Conv2D(64, (3, 3),padding='same',name="conv2d_8")(X)
X_3=BatchNormalization(name="batch_normalization_8")(X_3)
X_3= Activation('relu',name="activation_8")(X_3)
X_3 = Conv2D(64, (3, 3),padding='same',name="conv2d_9")(X_3)
X_3 = BatchNormalization(name="batch_normalization_9")(X_3)
merge_data = merge([X_3, X], mode='sum',name="merge_3")
X = Activation('relu',name="activation_9")(merge_data)
#第三个残差模块结束
X=Conv2D(32, (3, 3),kernel_regularizer=regularizers.l2(0.01),name="conv2d_10")(X)
X=BatchNormalization(name="batch_normalization_10")(X)
X=Activation('relu',name="activation_10")(X)
#第四个残差模块开始
X_4=Conv2D(32, (3, 3),padding='same',name="conv2d_11")(X)
X_4=BatchNormalization(name="batch_normalization_11")(X_4)
X_4= Activation('relu',name="activation_11")(X_4)
X_4 = Conv2D(32, (3, 3),padding='same',name="conv2d_12")(X_4)
X_4 = BatchNormalization(name="batch_normalization_12")(X_4)
merge_data = merge([X_4, X], mode='sum',name="merge_4")
X = Activation('relu',name="activation_12")(merge_data)
#第四个残差模块结束
X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2),name="max_pooling2d_4")(X)
X = Conv2D(64, (3, 3),name="conv2d_13")(X)
X = BatchNormalization(name="batch_normalization_13")(X)
X = Activation('relu',name="activation_13")(X)
#第五个残差模块开始
X_5=Conv2D(64, (3, 3),padding='same',name="conv2d_14")(X)
X_5=BatchNormalization(name="batch_normalization_14")(X_5)
X_5= Activation('relu',name="activation_14")(X_5)
X_5 = Conv2D(64, (3, 3),padding='same',name="conv2d_15")(X_5)
X_5 = BatchNormalization(name="batch_normalization_15")(X_5)
merge_data = merge([X_5, X], mode='sum',name="merge_5")
X = Activation('relu',name="activation_15")(merge_data)
#第五个残差模块结束
model=Model(inputs=inputs, outputs=X)
model.load_weights('model_halcon_resenet.h5', by_name=True)
#读取指定图像数据
image_dir='C:/Users/18301/Desktop/blister/new/blister_mixed_11.png'
img = image.load_img(image_dir, target_size=(400, 500))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
#利用第一个模型预测出特征数据,并对特征数据进行切片
feature_map=model.predict(x)
T=np.array(feature_map)
f_1=T[:,16:21,0:10,:]
print(f_1.shape)
print(feature_map.shape)
#第一个模型没有问题
#定义第二个模型
inputs_sec=Input(shape=(1,5,10,64))
X_= Flatten(name="flatten_1")(inputs_sec)
X_ = Dense(256, activation='relu',name="dense_1")(X_)
X_ = Dropout(0.5,name="dropout_1")(X_)
predictions = Dense(6, activation='softmax',name="dense_2")(X_)
model_sec=Model(inputs=inputs_sec, outputs=predictions)
model_sec.load_weights('model_halcon_resenet.h5', by_name=True)
#第二个模型定义结束
model_sec.summary()
#开始对整幅图像进行切片,并记录坐标位置
pic=cv2.imread(image_dir)
cor_list=[]
name_list=['blank','green_blank','red_blank','yellow','yellow_balnk','yellow_blue']
font = cv2.FONT_HERSHEY_SIMPLEX
for i in range(3):
 for j in range(5):
 if(i==2):
  cut_feature = T[:, 4 * j:4 * j + 5, 17:27, :]
  data = np.expand_dims(cut_feature, axis=0)
  result = model_sec.predict(data)
  print(result)
  result_data=result[0].tolist()
  #如果置信度过低,则舍弃
  # if(max(result_data)<=0.7):
  # continue
  index_num = result_data.index(max(result_data))
  name=name_list[index_num]
  cor_list = [i * 160 + 6, j * 80] # 每个切片数据,映射到原图上,检测框对应的左上角坐标
  x=cor_list[0]
  y=cor_list[1]
  cv2.rectangle(pic, (160 * i + 6, 80 * j), ((i + 1) * 160 + 6, 80 * (j+ 1)), (0, 255, 0), 2)
  cv2.putText(pic, name, (x + 40, y + 40), font, 0.5, (0, 0, 255), 1)
 else:
  cut_feature = T[:, 4 * j:4 * j + 5, 9 * i:9 * i + 10, :]
  data = np.expand_dims(cut_feature, axis=0)
  result = model_sec.predict(data)
  print(result)
  result_data = result[0].tolist()
  #如果置信度过低,则舍弃
  # if (max(result_data) <= 0.7):
  # continue
  index_num = result_data.index(max(result_data))
  name = name_list[index_num]
  cor_list = [i * 160 + 6, j * 80] # 每个切片数据,映射到原图上,检测框对应的左上角坐标
  x = cor_list[0]
  y = cor_list[1]
  cv2.rectangle(pic, (160 * i + 6, 80 * j), ((i + 1) * 160 + 6, 80 * (j + 1)), (0, 255, 0), 2)
  cv2.putText(pic, name, (x + 40, y + 40), font, 0.5, (0, 0, 255), 1)

cv2.imshow('pic',pic)
cv2.waitKey(0)
cv2.destroyAllWindows()
# data= np.expand_dims(f_1, axis=0)
# result=model_sec.predict(data)
# print(result)
#第二个模型可以完全预测,没有问题

补充知识:加载训练好的模型参数,但是权重一直变化

keras读取训练好的模型参数并把参数赋值给其它模型详解

变量初始化会导致权重发生变化,去掉就好了。

以上这篇keras读取训练好的模型参数并把参数赋值给其它模型详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
安装Python的web.py框架并从hello world开始编程
Apr 25 Python
Python中的lstrip()方法使用简介
May 19 Python
python目录与文件名操作例子
Aug 28 Python
Python编程判断一个正整数是否为素数的方法
Apr 14 Python
python中文分词,使用结巴分词对python进行分词(实例讲解)
Nov 14 Python
python爬虫之模拟登陆csdn的实例代码
May 18 Python
Python使用pymongo模块操作MongoDB的方法示例
Jul 20 Python
python实现文件助手中查看微信撤回消息
Apr 29 Python
python读取ini配置的类封装代码实例
Jan 08 Python
使用IPython或Spyder将省略号表示的内容完整输出
Apr 20 Python
基于OpenCV的网络实时视频流传输的实现
Nov 15 Python
python处理json数据文件
Apr 11 Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
简单了解Python多态与属性运行原理
Jun 15 #Python
Python类super()及私有属性原理解析
Jun 15 #Python
Keras 实现加载预训练模型并冻结网络的层
Jun 15 #Python
Python StringIO及BytesIO包使用方法解析
Jun 15 #Python
You might like
搜索引擎技术核心揭密
2006/10/09 PHP
PHP操作xml代码
2010/06/17 PHP
php去掉URL网址中带有PHPSESSID的配置方法
2014/07/08 PHP
Smarty环境配置与使用入门教程
2016/05/11 PHP
PHP封装的数据库保存session功能类
2016/07/11 PHP
PHP图形计数器程序显示网站用户浏览量
2016/07/20 PHP
javascript 写类方式之六
2009/07/05 Javascript
javascript 从if else 到 switch case 再到抽象
2010/07/17 Javascript
Jquery 跨域访问 Lightswitch OData Service的方法
2013/09/11 Javascript
javascript获得网页窗口实际大小的示例代码
2013/09/21 Javascript
javascript实现禁止鼠标滚轮事件
2015/07/24 Javascript
jQuery图片前后对比插件beforeAfter用法示例【附demo源码下载】
2016/09/20 Javascript
详解微信小程序开发之下拉刷新 上拉加载
2016/11/24 Javascript
Angular ui.bootstrap.pagination分页
2017/01/20 Javascript
bootstrap multiselect下拉列表功能
2017/08/22 Javascript
vue路由组件按需加载的几种方法小结
2018/07/12 Javascript
vue组件数据传递、父子组件数据获取,slot,router路由功能示例
2019/03/19 Javascript
[58:12]Ti4第二日主赛事败者组 LGD vs iG 3
2014/07/21 DOTA
使用Python编写爬虫的基本模块及框架使用指南
2016/01/20 Python
Python利用operator模块实现对象的多级排序详解
2017/05/09 Python
Python编程实现双击更新所有已安装python模块的方法
2017/06/05 Python
Python实现的简单读写csv文件操作示例
2018/07/12 Python
为什么Python中没有&quot;a++&quot;这种写法
2018/11/27 Python
pycharm 解除默认unittest模式的方法
2018/11/30 Python
python 函数中的内置函数及用法详解
2019/07/02 Python
Python csv文件记录流程代码解析
2020/07/16 Python
介绍一下linux的文件权限
2012/02/15 面试题
建议书怎么写
2014/03/12 职场文书
个人授权委托书范本
2014/04/03 职场文书
保护环境建议书100字
2014/05/13 职场文书
市场调查策划方案
2014/06/10 职场文书
MBA推荐信怎么写
2015/03/25 职场文书
2015年医生个人工作总结
2015/04/25 职场文书
反邪教观后感
2015/06/11 职场文书
go mod 安装依赖 unkown revision问题的解决方案
2021/05/06 Golang
python编写五子棋游戏
2021/05/25 Python