Tensorflow加载Vgg预训练模型操作


Posted in Python onMay 26, 2020

很多深度神经网络模型需要加载预训练过的Vgg参数,比如说:风格迁移、目标检测、图像标注等计算机视觉中常见的任务。那么到底如何加载Vgg模型呢?Vgg文件的参数到底有何意义呢?加载后的模型该如何使用呢?

本文将以Vgg19为例子,详细说明Tensorflow如何加载Vgg预训练模型。

实验环境

GTX1050-ti, cuda9.0

Window10, Tensorflow 1.12

展示Vgg19构造

import tensorflow as tf
 
import numpy as np
import scipy.io
 
data_path = 'model/vgg19.mat' # data_path指下载下来的Vgg19预训练模型的文件地址
 
# 读取Vgg19文件
data = scipy.io.loadmat(data_path)
# 打印Vgg19的数据类型及其组成
print("type: ", type(data))
print("data.keys: ", data.keys())
 
# 得到对应卷积核的矩阵
weights = data['layers'][0]
# 定义Vgg19的组成
layers = (
 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
 
 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
 
 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
 
 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
 
 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
 'relu5_3', 'conv5_4', 'relu5_4'
)
 
# 打印Vgg19不同卷积层所对应的维度
for i, name in enumerate(layers):
 kind = name[:4]
 if kind == 'conv':
  print("%s: %s" % (name, weights[i][0][0][2][0][0].shape))
 elif kind == 'relu':
  print(name)
 elif kind == 'pool':
  print(name)
 
 
代码输出结果如下:
type: <class 'dict'>
data.keys: dict_keys(['__header__', '__version__', '__globals__', 'layers', 'meta'])
 
conv1_1: (3, 3, 3, 64)
relu1_1
conv1_2: (3, 3, 64, 64)
relu1_2
pool1
conv2_1: (3, 3, 64, 128)
relu2_1
conv2_2: (3, 3, 128, 128)
relu2_2
pool2
conv3_1: (3, 3, 128, 256)
relu3_1
conv3_2: (3, 3, 256, 256)
relu3_2
conv3_3: (3, 3, 256, 256)
relu3_3
conv3_4: (3, 3, 256, 256)
relu3_4
pool3
conv4_1: (3, 3, 256, 512)
relu4_1
conv4_2: (3, 3, 512, 512)
relu4_2
conv4_3: (3, 3, 512, 512)
relu4_3
conv4_4: (3, 3, 512, 512)
relu4_4
pool4
conv5_1: (3, 3, 512, 512)
relu5_1
conv5_2: (3, 3, 512, 512)
relu5_2
conv5_3: (3, 3, 512, 512)
relu5_3
conv5_4: (3, 3, 512, 512)
relu5_4

那么Vgg19真实的网络结构是怎么样子的呢,如下图所示:

Tensorflow加载Vgg预训练模型操作

在本文,主要讨论卷积模块,大家通过对比可以发现,我们打印出来的Vgg19结构及其卷积核的构造的确如论文中给出的Vgg19结构一致。

构建Vgg19模型

def _conv_layer(input, weights, bias):
 conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, 1, 1, 1),
   padding='SAME')
 return tf.nn.bias_add(conv, bias)
 
def _pool_layer(input):
 return tf.nn.max_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1),
   padding='SAME')
 
class VGG19:
 layers = (
  'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
 
  'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
 
  'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
  'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
 
  'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
  'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
 
  'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
  'relu5_3', 'conv5_4', 'relu5_4'
 )
 
 def __init__(self, data_path):
  data = scipy.io.loadmat(data_path)
 
  self.weights = data['layers'][0]
 
 def feed_forward(self, input_image, scope=None):
  # 定义net用来保存模型每一步输出的特征图
  net = {}
  current = input_image
 
  with tf.variable_scope(scope):
   for i, name in enumerate(self.layers):
    kind = name[:4]
    if kind == 'conv':
     kernels = self.weights[i][0][0][2][0][0]
     bias = self.weights[i][0][0][2][0][1]
 
     kernels = np.transpose(kernels, (1, 0, 2, 3))
     bias = bias.reshape(-1)
 
     current = _conv_layer(current, kernels, bias)
    elif kind == 'relu':
     current = tf.nn.relu(current)
    elif kind == 'pool':
     current = _pool_layer(current)
    # 在每一步都保存当前输出的特征图
    net[name] = current
 
  return net

在上面的代码中,我们定义了一个Vgg19的类别专门用来加载Vgg19模型,并且将每一层卷积得到的特征图保存到net中,最后返回这个net,用于代码后续的处理。

测试Vgg19模型

在给出Vgg19的构造模型后,我们下一步就是如何用它,我们的思路如下:

加载本地图片

定义Vgg19模型,传入本地图片

得到返回每一层的特征图

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)
 
代码结果如下所示:
{'conv1_1': <tf.Tensor 'vgg19_1/BiasAdd:0' shape=(1, ?, ?, 64) dtype=float32>,
 'relu1_1': <tf.Tensor 'vgg19_1/Relu:0' shape=(1, ?, ?, 64) dtype=float32>,
 'conv1_2': <tf.Tensor 'vgg19_1/BiasAdd_1:0' shape=(1, ?, ?, 64) dtype=float32>,
 'relu1_2': <tf.Tensor 'vgg19_1/Relu_1:0' shape=(1, ?, ?, 64) dtype=float32>,
 'pool1': <tf.Tensor 'vgg19_1/MaxPool:0' shape=(1, ?, ?, 64) dtype=float32>,
 'conv2_1': <tf.Tensor 'vgg19_1/BiasAdd_2:0' shape=(1, ?, ?, 128) dtype=float32>,
 'relu2_1': <tf.Tensor 'vgg19_1/Relu_2:0' shape=(1, ?, ?, 128) dtype=float32>,
 'conv2_2': <tf.Tensor 'vgg19_1/BiasAdd_3:0' shape=(1, ?, ?, 128) dtype=float32>,
 'relu2_2': <tf.Tensor 'vgg19_1/Relu_3:0' shape=(1, ?, ?, 128) dtype=float32>,
 'pool2': <tf.Tensor 'vgg19_1/MaxPool_1:0' shape=(1, ?, ?, 128) dtype=float32>,
 'conv3_1': <tf.Tensor 'vgg19_1/BiasAdd_4:0' shape=(1, ?, ?, 256) dtype=float32>,
 'relu3_1': <tf.Tensor 'vgg19_1/Relu_4:0' shape=(1, ?, ?, 256) dtype=float32>,
 'conv3_2': <tf.Tensor 'vgg19_1/BiasAdd_5:0' shape=(1, ?, ?, 256) dtype=float32>,
 'relu3_2': <tf.Tensor 'vgg19_1/Relu_5:0' shape=(1, ?, ?, 256) dtype=float32>,
 'conv3_3': <tf.Tensor 'vgg19_1/BiasAdd_6:0' shape=(1, ?, ?, 256) dtype=float32>,
 'relu3_3': <tf.Tensor 'vgg19_1/Relu_6:0' shape=(1, ?, ?, 256) dtype=float32>,
 'conv3_4': <tf.Tensor 'vgg19_1/BiasAdd_7:0' shape=(1, ?, ?, 256) dtype=float32>,
 'relu3_4': <tf.Tensor 'vgg19_1/Relu_7:0' shape=(1, ?, ?, 256) dtype=float32>,
 'pool3': <tf.Tensor 'vgg19_1/MaxPool_2:0' shape=(1, ?, ?, 256) dtype=float32>,
 'conv4_1': <tf.Tensor 'vgg19_1/BiasAdd_8:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu4_1': <tf.Tensor 'vgg19_1/Relu_8:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv4_2': <tf.Tensor 'vgg19_1/BiasAdd_9:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu4_2': <tf.Tensor 'vgg19_1/Relu_9:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv4_3': <tf.Tensor 'vgg19_1/BiasAdd_10:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu4_3': <tf.Tensor 'vgg19_1/Relu_10:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv4_4': <tf.Tensor 'vgg19_1/BiasAdd_11:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu4_4': <tf.Tensor 'vgg19_1/Relu_11:0' shape=(1, ?, ?, 512) dtype=float32>,
 'pool4': <tf.Tensor 'vgg19_1/MaxPool_3:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv5_1': <tf.Tensor 'vgg19_1/BiasAdd_12:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu5_1': <tf.Tensor 'vgg19_1/Relu_12:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv5_2': <tf.Tensor 'vgg19_1/BiasAdd_13:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu5_2': <tf.Tensor 'vgg19_1/Relu_13:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv5_3': <tf.Tensor 'vgg19_1/BiasAdd_14:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu5_3': <tf.Tensor 'vgg19_1/Relu_14:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv5_4': <tf.Tensor 'vgg19_1/BiasAdd_15:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu5_4': <tf.Tensor 'vgg19_1/Relu_15:0' shape=(1, ?, ?, 512) dtype=float32>}

本文提供的测试代码是完成正确的,已经避免了很多使用Vgg19预训练模型的坑操作,比如:给图片添加维度,转换读取图片的的格式等,为什么这么做的详细原因可参考我的另一篇博客:Tensorflow加载Vgg预训练模型的几个注意事项。

到这里,如何使用tensorflow读取Vgg19模型结束了,若是大家有其他疑惑,可在评论区留言,会定时回答。

好了,以上就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中使用pyhook实现键盘监控的例子
Jul 18 Python
分析在Python中何种情况下需要使用断言
Apr 01 Python
Scrapy-redis爬虫分布式爬取的分析和实现
Feb 07 Python
Python中字符串格式化str.format的详细介绍
Feb 17 Python
python 连接sqlite及简单操作
Jun 30 Python
浅谈Python处理PDF的方法
Nov 10 Python
TensorFlow深度学习之卷积神经网络CNN
Mar 09 Python
对python自动生成接口测试的示例讲解
Nov 30 Python
Python基于Tkinter模块实现的弹球小游戏
Dec 27 Python
利用Python进行图像的加法,图像混合(附代码)
Jul 14 Python
python爬虫 execjs安装配置及使用
Jul 30 Python
聊聊pytorch测试的时候为何要加上model.eval()
May 23 Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
You might like
用PHP和MySQL保存和输出图片
2006/10/09 PHP
php文件服务实现虚拟挂载其他目录示例
2014/04/17 PHP
简介WordPress中用于获取首页和站点链接的PHP函数
2015/12/17 PHP
Laravel多用户认证系统示例详解
2018/03/13 PHP
DHTML Slide Show script图片轮换
2008/03/03 Javascript
JS两种定义方式的区别、内部原理
2013/11/21 Javascript
JS模仿编辑器实时改变文本框宽度和高度大小的方法
2015/08/17 Javascript
javascript精确统计网站访问量实例代码
2015/12/19 Javascript
在js里怎么实现Xcode里的callFuncN方法(详解)
2016/11/05 Javascript
vue-router路由简单案例介绍
2017/02/21 Javascript
bootstrap折叠调用collapse()后data-parent不生效的快速解决办法
2017/02/23 Javascript
jQuery中extend函数简单用法示例
2017/10/11 jQuery
利用vue + koa2 + mockjs模拟数据的方法教程
2017/11/22 Javascript
浅谈js获取ModelAndView值的问题
2018/03/28 Javascript
layui 实现表格某一列显示图标
2019/09/19 Javascript
js实现轮播图效果 z-index实现轮播图
2020/01/17 Javascript
[51:28]EG vs Mineski 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/16 DOTA
Python入门篇之正则表达式
2014/10/20 Python
python中itertools模块zip_longest函数详解
2018/06/12 Python
python遍历文件夹,指定遍历深度与忽略目录的方法
2018/07/11 Python
Python3编码问题 Unicode utf-8 bytes互转方法
2018/10/26 Python
python实现切割url得到域名、协议、主机名等各个字段的例子
2019/07/25 Python
TensorFlow查看输入节点和输出节点名称方式
2020/01/04 Python
python MultipartEncoder传输zip文件实例
2020/04/07 Python
解决windows下python3使用multiprocessing.Pool出现的问题
2020/04/08 Python
如何基于线程池提升request模块效率
2020/04/18 Python
python 日志模块logging的使用场景及示例
2021/01/04 Python
详解Django关于StreamingHttpResponse与FileResponse文件下载的最优方法
2021/01/07 Python
在C#中如何实现多态
2014/07/02 面试题
什么是GWT的Entry Point
2013/08/16 面试题
一年级家长会邀请函
2014/01/25 职场文书
《悯农》教学反思
2014/04/28 职场文书
建设幸福中国演讲稿
2014/09/11 职场文书
学生打架检讨书
2014/10/20 职场文书
总经理助理岗位职责
2015/01/31 职场文书
心灵点滴观后感
2015/06/02 职场文书