Tensorflow加载Vgg预训练模型操作


Posted in Python onMay 26, 2020

很多深度神经网络模型需要加载预训练过的Vgg参数,比如说:风格迁移、目标检测、图像标注等计算机视觉中常见的任务。那么到底如何加载Vgg模型呢?Vgg文件的参数到底有何意义呢?加载后的模型该如何使用呢?

本文将以Vgg19为例子,详细说明Tensorflow如何加载Vgg预训练模型。

实验环境

GTX1050-ti, cuda9.0

Window10, Tensorflow 1.12

展示Vgg19构造

import tensorflow as tf
 
import numpy as np
import scipy.io
 
data_path = 'model/vgg19.mat' # data_path指下载下来的Vgg19预训练模型的文件地址
 
# 读取Vgg19文件
data = scipy.io.loadmat(data_path)
# 打印Vgg19的数据类型及其组成
print("type: ", type(data))
print("data.keys: ", data.keys())
 
# 得到对应卷积核的矩阵
weights = data['layers'][0]
# 定义Vgg19的组成
layers = (
 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
 
 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
 
 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
 
 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
 
 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
 'relu5_3', 'conv5_4', 'relu5_4'
)
 
# 打印Vgg19不同卷积层所对应的维度
for i, name in enumerate(layers):
 kind = name[:4]
 if kind == 'conv':
  print("%s: %s" % (name, weights[i][0][0][2][0][0].shape))
 elif kind == 'relu':
  print(name)
 elif kind == 'pool':
  print(name)
 
 
代码输出结果如下:
type: <class 'dict'>
data.keys: dict_keys(['__header__', '__version__', '__globals__', 'layers', 'meta'])
 
conv1_1: (3, 3, 3, 64)
relu1_1
conv1_2: (3, 3, 64, 64)
relu1_2
pool1
conv2_1: (3, 3, 64, 128)
relu2_1
conv2_2: (3, 3, 128, 128)
relu2_2
pool2
conv3_1: (3, 3, 128, 256)
relu3_1
conv3_2: (3, 3, 256, 256)
relu3_2
conv3_3: (3, 3, 256, 256)
relu3_3
conv3_4: (3, 3, 256, 256)
relu3_4
pool3
conv4_1: (3, 3, 256, 512)
relu4_1
conv4_2: (3, 3, 512, 512)
relu4_2
conv4_3: (3, 3, 512, 512)
relu4_3
conv4_4: (3, 3, 512, 512)
relu4_4
pool4
conv5_1: (3, 3, 512, 512)
relu5_1
conv5_2: (3, 3, 512, 512)
relu5_2
conv5_3: (3, 3, 512, 512)
relu5_3
conv5_4: (3, 3, 512, 512)
relu5_4

那么Vgg19真实的网络结构是怎么样子的呢,如下图所示:

Tensorflow加载Vgg预训练模型操作

在本文,主要讨论卷积模块,大家通过对比可以发现,我们打印出来的Vgg19结构及其卷积核的构造的确如论文中给出的Vgg19结构一致。

构建Vgg19模型

def _conv_layer(input, weights, bias):
 conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, 1, 1, 1),
   padding='SAME')
 return tf.nn.bias_add(conv, bias)
 
def _pool_layer(input):
 return tf.nn.max_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1),
   padding='SAME')
 
class VGG19:
 layers = (
  'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
 
  'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
 
  'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
  'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
 
  'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
  'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
 
  'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
  'relu5_3', 'conv5_4', 'relu5_4'
 )
 
 def __init__(self, data_path):
  data = scipy.io.loadmat(data_path)
 
  self.weights = data['layers'][0]
 
 def feed_forward(self, input_image, scope=None):
  # 定义net用来保存模型每一步输出的特征图
  net = {}
  current = input_image
 
  with tf.variable_scope(scope):
   for i, name in enumerate(self.layers):
    kind = name[:4]
    if kind == 'conv':
     kernels = self.weights[i][0][0][2][0][0]
     bias = self.weights[i][0][0][2][0][1]
 
     kernels = np.transpose(kernels, (1, 0, 2, 3))
     bias = bias.reshape(-1)
 
     current = _conv_layer(current, kernels, bias)
    elif kind == 'relu':
     current = tf.nn.relu(current)
    elif kind == 'pool':
     current = _pool_layer(current)
    # 在每一步都保存当前输出的特征图
    net[name] = current
 
  return net

在上面的代码中,我们定义了一个Vgg19的类别专门用来加载Vgg19模型,并且将每一层卷积得到的特征图保存到net中,最后返回这个net,用于代码后续的处理。

测试Vgg19模型

在给出Vgg19的构造模型后,我们下一步就是如何用它,我们的思路如下:

加载本地图片

定义Vgg19模型,传入本地图片

得到返回每一层的特征图

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)
 
代码结果如下所示:
{'conv1_1': <tf.Tensor 'vgg19_1/BiasAdd:0' shape=(1, ?, ?, 64) dtype=float32>,
 'relu1_1': <tf.Tensor 'vgg19_1/Relu:0' shape=(1, ?, ?, 64) dtype=float32>,
 'conv1_2': <tf.Tensor 'vgg19_1/BiasAdd_1:0' shape=(1, ?, ?, 64) dtype=float32>,
 'relu1_2': <tf.Tensor 'vgg19_1/Relu_1:0' shape=(1, ?, ?, 64) dtype=float32>,
 'pool1': <tf.Tensor 'vgg19_1/MaxPool:0' shape=(1, ?, ?, 64) dtype=float32>,
 'conv2_1': <tf.Tensor 'vgg19_1/BiasAdd_2:0' shape=(1, ?, ?, 128) dtype=float32>,
 'relu2_1': <tf.Tensor 'vgg19_1/Relu_2:0' shape=(1, ?, ?, 128) dtype=float32>,
 'conv2_2': <tf.Tensor 'vgg19_1/BiasAdd_3:0' shape=(1, ?, ?, 128) dtype=float32>,
 'relu2_2': <tf.Tensor 'vgg19_1/Relu_3:0' shape=(1, ?, ?, 128) dtype=float32>,
 'pool2': <tf.Tensor 'vgg19_1/MaxPool_1:0' shape=(1, ?, ?, 128) dtype=float32>,
 'conv3_1': <tf.Tensor 'vgg19_1/BiasAdd_4:0' shape=(1, ?, ?, 256) dtype=float32>,
 'relu3_1': <tf.Tensor 'vgg19_1/Relu_4:0' shape=(1, ?, ?, 256) dtype=float32>,
 'conv3_2': <tf.Tensor 'vgg19_1/BiasAdd_5:0' shape=(1, ?, ?, 256) dtype=float32>,
 'relu3_2': <tf.Tensor 'vgg19_1/Relu_5:0' shape=(1, ?, ?, 256) dtype=float32>,
 'conv3_3': <tf.Tensor 'vgg19_1/BiasAdd_6:0' shape=(1, ?, ?, 256) dtype=float32>,
 'relu3_3': <tf.Tensor 'vgg19_1/Relu_6:0' shape=(1, ?, ?, 256) dtype=float32>,
 'conv3_4': <tf.Tensor 'vgg19_1/BiasAdd_7:0' shape=(1, ?, ?, 256) dtype=float32>,
 'relu3_4': <tf.Tensor 'vgg19_1/Relu_7:0' shape=(1, ?, ?, 256) dtype=float32>,
 'pool3': <tf.Tensor 'vgg19_1/MaxPool_2:0' shape=(1, ?, ?, 256) dtype=float32>,
 'conv4_1': <tf.Tensor 'vgg19_1/BiasAdd_8:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu4_1': <tf.Tensor 'vgg19_1/Relu_8:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv4_2': <tf.Tensor 'vgg19_1/BiasAdd_9:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu4_2': <tf.Tensor 'vgg19_1/Relu_9:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv4_3': <tf.Tensor 'vgg19_1/BiasAdd_10:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu4_3': <tf.Tensor 'vgg19_1/Relu_10:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv4_4': <tf.Tensor 'vgg19_1/BiasAdd_11:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu4_4': <tf.Tensor 'vgg19_1/Relu_11:0' shape=(1, ?, ?, 512) dtype=float32>,
 'pool4': <tf.Tensor 'vgg19_1/MaxPool_3:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv5_1': <tf.Tensor 'vgg19_1/BiasAdd_12:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu5_1': <tf.Tensor 'vgg19_1/Relu_12:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv5_2': <tf.Tensor 'vgg19_1/BiasAdd_13:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu5_2': <tf.Tensor 'vgg19_1/Relu_13:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv5_3': <tf.Tensor 'vgg19_1/BiasAdd_14:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu5_3': <tf.Tensor 'vgg19_1/Relu_14:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv5_4': <tf.Tensor 'vgg19_1/BiasAdd_15:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu5_4': <tf.Tensor 'vgg19_1/Relu_15:0' shape=(1, ?, ?, 512) dtype=float32>}

本文提供的测试代码是完成正确的,已经避免了很多使用Vgg19预训练模型的坑操作,比如:给图片添加维度,转换读取图片的的格式等,为什么这么做的详细原因可参考我的另一篇博客:Tensorflow加载Vgg预训练模型的几个注意事项。

到这里,如何使用tensorflow读取Vgg19模型结束了,若是大家有其他疑惑,可在评论区留言,会定时回答。

好了,以上就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
利用Python的装饰器解决Bottle框架中用户验证问题
Apr 24 Python
python返回昨天日期的方法
May 13 Python
Python通过matplotlib画双层饼图及环形图简单示例
Dec 15 Python
python编程实现12306的一个小爬虫实例
Dec 27 Python
Python利用递归实现文件的复制方法
Oct 27 Python
python3 使用Opencv打开USB摄像头,配置1080P分辨率的操作
Dec 11 Python
python3读取csv文件任意行列代码实例
Jan 13 Python
tensorflow 实现从checkpoint中获取graph信息
Feb 10 Python
python实现逆滤波与维纳滤波示例
Feb 26 Python
Tensorflow卷积实现原理+手写python代码实现卷积教程
May 22 Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 Python
python3排序的实例方法
Oct 20 Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
You might like
第十二节--类的自动加载
2006/11/16 PHP
基于PHP magic_quotes_gpc的使用方法详解
2013/06/24 PHP
php实现通过ftp上传文件
2015/06/19 PHP
详解PHP中的 input属性(隐藏 只读 限制)
2017/08/14 PHP
设定php简写功能的方法
2019/11/28 PHP
javascript 跨浏览器开发经验总结(五) js 事件
2010/05/19 Javascript
js下利用控制器载入对应脚本
2010/07/17 Javascript
JavaScript下通过的XMLHttpRequest发送请求的代码
2011/06/28 Javascript
jquery验证表单中的单选与多选实例
2013/08/18 Javascript
jquery 实现窗口的最大化不论什么情况
2013/09/03 Javascript
jquery ajax jsonp跨域调用实例代码
2013/12/11 Javascript
angular简介和其特点介绍
2015/01/29 Javascript
js随机生成字母数字组合的字符串 随机动画数字
2015/09/02 Javascript
jQuery validate插件实现ajax验证重复的2种方法
2016/01/22 Javascript
原生JS获取元素集合的子元素宽度实例
2016/12/14 Javascript
Vue.js自定义指令的用法与实例解析
2017/01/18 Javascript
React Native 集成jpush-react-native的示例代码
2017/08/16 Javascript
AngularJS实时获取并显示密码的方法
2018/02/06 Javascript
用Cordova打包Vue项目的方法步骤
2019/02/02 Javascript
[47:04]EG vs RNG 2019国际邀请赛小组赛 BO2 第二场 8.16
2019/08/18 DOTA
Python结巴中文分词工具使用过程中遇到的问题及解决方法
2017/04/15 Python
浅谈django的render函数的参数问题
2018/10/16 Python
pymysql 开启调试模式的实现
2019/09/24 Python
在python中计算ssim的方法(与Matlab结果一致)
2019/12/19 Python
tensorflow指定GPU与动态分配GPU memory设置
2020/02/03 Python
CSS3中的元素过渡属性transition示例详解
2016/11/30 HTML / CSS
CSS Grid布局教程之什么是网格布局
2014/12/30 HTML / CSS
前厅部经理岗位职责范文
2014/02/04 职场文书
少先队活动总结
2014/08/29 职场文书
保密工作整改报告
2014/11/06 职场文书
宾馆前台接待岗位职责
2015/04/02 职场文书
公安干警正风肃纪心得体会
2016/01/15 职场文书
导游词之河北邯郸
2019/09/12 职场文书
Pytest中skip skipif跳过用例详解
2021/06/30 Python
Mysql中where与on的区别及何时使用详析
2021/08/04 MySQL
25张裸眼3D图片,带你重温童年的记忆,感受3D的魅力
2022/02/06 杂记