Tensorflow加载Vgg预训练模型操作


Posted in Python onMay 26, 2020

很多深度神经网络模型需要加载预训练过的Vgg参数,比如说:风格迁移、目标检测、图像标注等计算机视觉中常见的任务。那么到底如何加载Vgg模型呢?Vgg文件的参数到底有何意义呢?加载后的模型该如何使用呢?

本文将以Vgg19为例子,详细说明Tensorflow如何加载Vgg预训练模型。

实验环境

GTX1050-ti, cuda9.0

Window10, Tensorflow 1.12

展示Vgg19构造

import tensorflow as tf
 
import numpy as np
import scipy.io
 
data_path = 'model/vgg19.mat' # data_path指下载下来的Vgg19预训练模型的文件地址
 
# 读取Vgg19文件
data = scipy.io.loadmat(data_path)
# 打印Vgg19的数据类型及其组成
print("type: ", type(data))
print("data.keys: ", data.keys())
 
# 得到对应卷积核的矩阵
weights = data['layers'][0]
# 定义Vgg19的组成
layers = (
 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
 
 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
 
 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
 
 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
 
 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
 'relu5_3', 'conv5_4', 'relu5_4'
)
 
# 打印Vgg19不同卷积层所对应的维度
for i, name in enumerate(layers):
 kind = name[:4]
 if kind == 'conv':
  print("%s: %s" % (name, weights[i][0][0][2][0][0].shape))
 elif kind == 'relu':
  print(name)
 elif kind == 'pool':
  print(name)
 
 
代码输出结果如下:
type: <class 'dict'>
data.keys: dict_keys(['__header__', '__version__', '__globals__', 'layers', 'meta'])
 
conv1_1: (3, 3, 3, 64)
relu1_1
conv1_2: (3, 3, 64, 64)
relu1_2
pool1
conv2_1: (3, 3, 64, 128)
relu2_1
conv2_2: (3, 3, 128, 128)
relu2_2
pool2
conv3_1: (3, 3, 128, 256)
relu3_1
conv3_2: (3, 3, 256, 256)
relu3_2
conv3_3: (3, 3, 256, 256)
relu3_3
conv3_4: (3, 3, 256, 256)
relu3_4
pool3
conv4_1: (3, 3, 256, 512)
relu4_1
conv4_2: (3, 3, 512, 512)
relu4_2
conv4_3: (3, 3, 512, 512)
relu4_3
conv4_4: (3, 3, 512, 512)
relu4_4
pool4
conv5_1: (3, 3, 512, 512)
relu5_1
conv5_2: (3, 3, 512, 512)
relu5_2
conv5_3: (3, 3, 512, 512)
relu5_3
conv5_4: (3, 3, 512, 512)
relu5_4

那么Vgg19真实的网络结构是怎么样子的呢,如下图所示:

Tensorflow加载Vgg预训练模型操作

在本文,主要讨论卷积模块,大家通过对比可以发现,我们打印出来的Vgg19结构及其卷积核的构造的确如论文中给出的Vgg19结构一致。

构建Vgg19模型

def _conv_layer(input, weights, bias):
 conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, 1, 1, 1),
   padding='SAME')
 return tf.nn.bias_add(conv, bias)
 
def _pool_layer(input):
 return tf.nn.max_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1),
   padding='SAME')
 
class VGG19:
 layers = (
  'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
 
  'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
 
  'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
  'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
 
  'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
  'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
 
  'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
  'relu5_3', 'conv5_4', 'relu5_4'
 )
 
 def __init__(self, data_path):
  data = scipy.io.loadmat(data_path)
 
  self.weights = data['layers'][0]
 
 def feed_forward(self, input_image, scope=None):
  # 定义net用来保存模型每一步输出的特征图
  net = {}
  current = input_image
 
  with tf.variable_scope(scope):
   for i, name in enumerate(self.layers):
    kind = name[:4]
    if kind == 'conv':
     kernels = self.weights[i][0][0][2][0][0]
     bias = self.weights[i][0][0][2][0][1]
 
     kernels = np.transpose(kernels, (1, 0, 2, 3))
     bias = bias.reshape(-1)
 
     current = _conv_layer(current, kernels, bias)
    elif kind == 'relu':
     current = tf.nn.relu(current)
    elif kind == 'pool':
     current = _pool_layer(current)
    # 在每一步都保存当前输出的特征图
    net[name] = current
 
  return net

在上面的代码中,我们定义了一个Vgg19的类别专门用来加载Vgg19模型,并且将每一层卷积得到的特征图保存到net中,最后返回这个net,用于代码后续的处理。

测试Vgg19模型

在给出Vgg19的构造模型后,我们下一步就是如何用它,我们的思路如下:

加载本地图片

定义Vgg19模型,传入本地图片

得到返回每一层的特征图

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)
 
代码结果如下所示:
{'conv1_1': <tf.Tensor 'vgg19_1/BiasAdd:0' shape=(1, ?, ?, 64) dtype=float32>,
 'relu1_1': <tf.Tensor 'vgg19_1/Relu:0' shape=(1, ?, ?, 64) dtype=float32>,
 'conv1_2': <tf.Tensor 'vgg19_1/BiasAdd_1:0' shape=(1, ?, ?, 64) dtype=float32>,
 'relu1_2': <tf.Tensor 'vgg19_1/Relu_1:0' shape=(1, ?, ?, 64) dtype=float32>,
 'pool1': <tf.Tensor 'vgg19_1/MaxPool:0' shape=(1, ?, ?, 64) dtype=float32>,
 'conv2_1': <tf.Tensor 'vgg19_1/BiasAdd_2:0' shape=(1, ?, ?, 128) dtype=float32>,
 'relu2_1': <tf.Tensor 'vgg19_1/Relu_2:0' shape=(1, ?, ?, 128) dtype=float32>,
 'conv2_2': <tf.Tensor 'vgg19_1/BiasAdd_3:0' shape=(1, ?, ?, 128) dtype=float32>,
 'relu2_2': <tf.Tensor 'vgg19_1/Relu_3:0' shape=(1, ?, ?, 128) dtype=float32>,
 'pool2': <tf.Tensor 'vgg19_1/MaxPool_1:0' shape=(1, ?, ?, 128) dtype=float32>,
 'conv3_1': <tf.Tensor 'vgg19_1/BiasAdd_4:0' shape=(1, ?, ?, 256) dtype=float32>,
 'relu3_1': <tf.Tensor 'vgg19_1/Relu_4:0' shape=(1, ?, ?, 256) dtype=float32>,
 'conv3_2': <tf.Tensor 'vgg19_1/BiasAdd_5:0' shape=(1, ?, ?, 256) dtype=float32>,
 'relu3_2': <tf.Tensor 'vgg19_1/Relu_5:0' shape=(1, ?, ?, 256) dtype=float32>,
 'conv3_3': <tf.Tensor 'vgg19_1/BiasAdd_6:0' shape=(1, ?, ?, 256) dtype=float32>,
 'relu3_3': <tf.Tensor 'vgg19_1/Relu_6:0' shape=(1, ?, ?, 256) dtype=float32>,
 'conv3_4': <tf.Tensor 'vgg19_1/BiasAdd_7:0' shape=(1, ?, ?, 256) dtype=float32>,
 'relu3_4': <tf.Tensor 'vgg19_1/Relu_7:0' shape=(1, ?, ?, 256) dtype=float32>,
 'pool3': <tf.Tensor 'vgg19_1/MaxPool_2:0' shape=(1, ?, ?, 256) dtype=float32>,
 'conv4_1': <tf.Tensor 'vgg19_1/BiasAdd_8:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu4_1': <tf.Tensor 'vgg19_1/Relu_8:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv4_2': <tf.Tensor 'vgg19_1/BiasAdd_9:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu4_2': <tf.Tensor 'vgg19_1/Relu_9:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv4_3': <tf.Tensor 'vgg19_1/BiasAdd_10:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu4_3': <tf.Tensor 'vgg19_1/Relu_10:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv4_4': <tf.Tensor 'vgg19_1/BiasAdd_11:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu4_4': <tf.Tensor 'vgg19_1/Relu_11:0' shape=(1, ?, ?, 512) dtype=float32>,
 'pool4': <tf.Tensor 'vgg19_1/MaxPool_3:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv5_1': <tf.Tensor 'vgg19_1/BiasAdd_12:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu5_1': <tf.Tensor 'vgg19_1/Relu_12:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv5_2': <tf.Tensor 'vgg19_1/BiasAdd_13:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu5_2': <tf.Tensor 'vgg19_1/Relu_13:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv5_3': <tf.Tensor 'vgg19_1/BiasAdd_14:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu5_3': <tf.Tensor 'vgg19_1/Relu_14:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv5_4': <tf.Tensor 'vgg19_1/BiasAdd_15:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu5_4': <tf.Tensor 'vgg19_1/Relu_15:0' shape=(1, ?, ?, 512) dtype=float32>}

本文提供的测试代码是完成正确的,已经避免了很多使用Vgg19预训练模型的坑操作,比如:给图片添加维度,转换读取图片的的格式等,为什么这么做的详细原因可参考我的另一篇博客:Tensorflow加载Vgg预训练模型的几个注意事项。

到这里,如何使用tensorflow读取Vgg19模型结束了,若是大家有其他疑惑,可在评论区留言,会定时回答。

好了,以上就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python爬虫辅助利器PyQuery模块的安装使用攻略
Apr 24 Python
使用python实现knn算法
Dec 20 Python
Python cookbook(数据结构与算法)实现对不原生支持比较操作的对象排序算法示例
Mar 15 Python
对pandas处理json数据的方法详解
Feb 08 Python
Python常见读写文件操作实例总结【文本、json、csv、pdf等】
Apr 15 Python
OpenCV 轮廓检测的实现方法
Jul 03 Python
Python爬取智联招聘数据分析师岗位相关信息的方法
Aug 13 Python
python利用datetime模块计算程序运行时间问题
Feb 20 Python
python--shutil移动文件到另一个路径的操作
Jul 13 Python
Windows下PyCharm配置Anaconda环境(超详细教程)
Jul 31 Python
Elasticsearch 批量操作
Apr 19 Python
python如何将mat文件转为png
Jul 15 Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
You might like
15种PHP Encoder的比较
2007/03/06 PHP
php,ajax实现分页
2008/03/27 PHP
php学习笔记之 函数声明(二)
2011/06/09 PHP
写出高质量的PHP程序
2012/02/04 PHP
PHP中大于2038年时间戳的问题处理方案
2015/03/03 PHP
php 微信开发获取用户信息如何实现
2016/12/13 PHP
PHP操作Postgresql封装类与应用完整实例
2018/04/24 PHP
javaScript 删除字符串空格多种方法小结
2012/10/24 Javascript
让新消息在网页标题闪烁提示的jQuery代码
2013/11/04 Javascript
JScript中的条件注释详解
2015/04/24 Javascript
JavaScript中实现无缝滚动、分享到侧边栏实例代码
2016/04/06 Javascript
详解照片瀑布流效果(js,jquery分别实现与知识点总结)
2017/01/01 Javascript
AngularJS实现使用路由切换视图的方法
2017/01/24 Javascript
AngularJS 防止页面闪烁的方法
2017/03/09 Javascript
原生JS实现层叠轮播图
2017/05/17 Javascript
关于Ajax的原理以及代码封装详解
2017/09/08 Javascript
vue的diff算法知识点总结
2018/03/29 Javascript
浅析TypeScript 命名空间
2020/03/19 Javascript
python实现上传样本到virustotal并查询扫描信息的方法
2014/10/05 Python
python创建列表并给列表赋初始值的方法
2015/07/28 Python
在Linux系统上通过uWSGI配置Nginx+Python环境的教程
2015/12/25 Python
Python的dict字典结构操作方法学习笔记
2016/05/07 Python
Python计时相关操作详解【time,datetime】
2017/05/26 Python
Python面向对象程序设计之类的定义与继承简单示例
2019/03/18 Python
python之pymysql模块简单应用示例代码
2019/12/16 Python
通俗易懂了解Python装饰器原理
2020/09/17 Python
Python3中的tuple函数知识点讲解
2021/01/03 Python
一些高难度的SQL面试题
2016/11/29 面试题
数控技术专业推荐信
2013/11/01 职场文书
机械制造与自动化应届生求职信
2013/11/16 职场文书
眼镜促销方案
2014/03/15 职场文书
领导班子整改方案和个人整改措施
2014/10/25 职场文书
档案管理员岗位职责
2015/02/12 职场文书
世界名著读书笔记
2015/06/25 职场文书
联村联户简报
2015/07/21 职场文书
学子宴致辞大全
2015/07/27 职场文书