Tensorflow加载Vgg预训练模型操作


Posted in Python onMay 26, 2020

很多深度神经网络模型需要加载预训练过的Vgg参数,比如说:风格迁移、目标检测、图像标注等计算机视觉中常见的任务。那么到底如何加载Vgg模型呢?Vgg文件的参数到底有何意义呢?加载后的模型该如何使用呢?

本文将以Vgg19为例子,详细说明Tensorflow如何加载Vgg预训练模型。

实验环境

GTX1050-ti, cuda9.0

Window10, Tensorflow 1.12

展示Vgg19构造

import tensorflow as tf
 
import numpy as np
import scipy.io
 
data_path = 'model/vgg19.mat' # data_path指下载下来的Vgg19预训练模型的文件地址
 
# 读取Vgg19文件
data = scipy.io.loadmat(data_path)
# 打印Vgg19的数据类型及其组成
print("type: ", type(data))
print("data.keys: ", data.keys())
 
# 得到对应卷积核的矩阵
weights = data['layers'][0]
# 定义Vgg19的组成
layers = (
 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
 
 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
 
 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
 
 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
 
 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
 'relu5_3', 'conv5_4', 'relu5_4'
)
 
# 打印Vgg19不同卷积层所对应的维度
for i, name in enumerate(layers):
 kind = name[:4]
 if kind == 'conv':
  print("%s: %s" % (name, weights[i][0][0][2][0][0].shape))
 elif kind == 'relu':
  print(name)
 elif kind == 'pool':
  print(name)
 
 
代码输出结果如下:
type: <class 'dict'>
data.keys: dict_keys(['__header__', '__version__', '__globals__', 'layers', 'meta'])
 
conv1_1: (3, 3, 3, 64)
relu1_1
conv1_2: (3, 3, 64, 64)
relu1_2
pool1
conv2_1: (3, 3, 64, 128)
relu2_1
conv2_2: (3, 3, 128, 128)
relu2_2
pool2
conv3_1: (3, 3, 128, 256)
relu3_1
conv3_2: (3, 3, 256, 256)
relu3_2
conv3_3: (3, 3, 256, 256)
relu3_3
conv3_4: (3, 3, 256, 256)
relu3_4
pool3
conv4_1: (3, 3, 256, 512)
relu4_1
conv4_2: (3, 3, 512, 512)
relu4_2
conv4_3: (3, 3, 512, 512)
relu4_3
conv4_4: (3, 3, 512, 512)
relu4_4
pool4
conv5_1: (3, 3, 512, 512)
relu5_1
conv5_2: (3, 3, 512, 512)
relu5_2
conv5_3: (3, 3, 512, 512)
relu5_3
conv5_4: (3, 3, 512, 512)
relu5_4

那么Vgg19真实的网络结构是怎么样子的呢,如下图所示:

Tensorflow加载Vgg预训练模型操作

在本文,主要讨论卷积模块,大家通过对比可以发现,我们打印出来的Vgg19结构及其卷积核的构造的确如论文中给出的Vgg19结构一致。

构建Vgg19模型

def _conv_layer(input, weights, bias):
 conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, 1, 1, 1),
   padding='SAME')
 return tf.nn.bias_add(conv, bias)
 
def _pool_layer(input):
 return tf.nn.max_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1),
   padding='SAME')
 
class VGG19:
 layers = (
  'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
 
  'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
 
  'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
  'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
 
  'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
  'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
 
  'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
  'relu5_3', 'conv5_4', 'relu5_4'
 )
 
 def __init__(self, data_path):
  data = scipy.io.loadmat(data_path)
 
  self.weights = data['layers'][0]
 
 def feed_forward(self, input_image, scope=None):
  # 定义net用来保存模型每一步输出的特征图
  net = {}
  current = input_image
 
  with tf.variable_scope(scope):
   for i, name in enumerate(self.layers):
    kind = name[:4]
    if kind == 'conv':
     kernels = self.weights[i][0][0][2][0][0]
     bias = self.weights[i][0][0][2][0][1]
 
     kernels = np.transpose(kernels, (1, 0, 2, 3))
     bias = bias.reshape(-1)
 
     current = _conv_layer(current, kernels, bias)
    elif kind == 'relu':
     current = tf.nn.relu(current)
    elif kind == 'pool':
     current = _pool_layer(current)
    # 在每一步都保存当前输出的特征图
    net[name] = current
 
  return net

在上面的代码中,我们定义了一个Vgg19的类别专门用来加载Vgg19模型,并且将每一层卷积得到的特征图保存到net中,最后返回这个net,用于代码后续的处理。

测试Vgg19模型

在给出Vgg19的构造模型后,我们下一步就是如何用它,我们的思路如下:

加载本地图片

定义Vgg19模型,传入本地图片

得到返回每一层的特征图

image_path = "data/test.jpg" # 本地的测试图片
 
image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))
 
# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)
 
# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)
 
代码结果如下所示:
{'conv1_1': <tf.Tensor 'vgg19_1/BiasAdd:0' shape=(1, ?, ?, 64) dtype=float32>,
 'relu1_1': <tf.Tensor 'vgg19_1/Relu:0' shape=(1, ?, ?, 64) dtype=float32>,
 'conv1_2': <tf.Tensor 'vgg19_1/BiasAdd_1:0' shape=(1, ?, ?, 64) dtype=float32>,
 'relu1_2': <tf.Tensor 'vgg19_1/Relu_1:0' shape=(1, ?, ?, 64) dtype=float32>,
 'pool1': <tf.Tensor 'vgg19_1/MaxPool:0' shape=(1, ?, ?, 64) dtype=float32>,
 'conv2_1': <tf.Tensor 'vgg19_1/BiasAdd_2:0' shape=(1, ?, ?, 128) dtype=float32>,
 'relu2_1': <tf.Tensor 'vgg19_1/Relu_2:0' shape=(1, ?, ?, 128) dtype=float32>,
 'conv2_2': <tf.Tensor 'vgg19_1/BiasAdd_3:0' shape=(1, ?, ?, 128) dtype=float32>,
 'relu2_2': <tf.Tensor 'vgg19_1/Relu_3:0' shape=(1, ?, ?, 128) dtype=float32>,
 'pool2': <tf.Tensor 'vgg19_1/MaxPool_1:0' shape=(1, ?, ?, 128) dtype=float32>,
 'conv3_1': <tf.Tensor 'vgg19_1/BiasAdd_4:0' shape=(1, ?, ?, 256) dtype=float32>,
 'relu3_1': <tf.Tensor 'vgg19_1/Relu_4:0' shape=(1, ?, ?, 256) dtype=float32>,
 'conv3_2': <tf.Tensor 'vgg19_1/BiasAdd_5:0' shape=(1, ?, ?, 256) dtype=float32>,
 'relu3_2': <tf.Tensor 'vgg19_1/Relu_5:0' shape=(1, ?, ?, 256) dtype=float32>,
 'conv3_3': <tf.Tensor 'vgg19_1/BiasAdd_6:0' shape=(1, ?, ?, 256) dtype=float32>,
 'relu3_3': <tf.Tensor 'vgg19_1/Relu_6:0' shape=(1, ?, ?, 256) dtype=float32>,
 'conv3_4': <tf.Tensor 'vgg19_1/BiasAdd_7:0' shape=(1, ?, ?, 256) dtype=float32>,
 'relu3_4': <tf.Tensor 'vgg19_1/Relu_7:0' shape=(1, ?, ?, 256) dtype=float32>,
 'pool3': <tf.Tensor 'vgg19_1/MaxPool_2:0' shape=(1, ?, ?, 256) dtype=float32>,
 'conv4_1': <tf.Tensor 'vgg19_1/BiasAdd_8:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu4_1': <tf.Tensor 'vgg19_1/Relu_8:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv4_2': <tf.Tensor 'vgg19_1/BiasAdd_9:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu4_2': <tf.Tensor 'vgg19_1/Relu_9:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv4_3': <tf.Tensor 'vgg19_1/BiasAdd_10:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu4_3': <tf.Tensor 'vgg19_1/Relu_10:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv4_4': <tf.Tensor 'vgg19_1/BiasAdd_11:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu4_4': <tf.Tensor 'vgg19_1/Relu_11:0' shape=(1, ?, ?, 512) dtype=float32>,
 'pool4': <tf.Tensor 'vgg19_1/MaxPool_3:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv5_1': <tf.Tensor 'vgg19_1/BiasAdd_12:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu5_1': <tf.Tensor 'vgg19_1/Relu_12:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv5_2': <tf.Tensor 'vgg19_1/BiasAdd_13:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu5_2': <tf.Tensor 'vgg19_1/Relu_13:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv5_3': <tf.Tensor 'vgg19_1/BiasAdd_14:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu5_3': <tf.Tensor 'vgg19_1/Relu_14:0' shape=(1, ?, ?, 512) dtype=float32>,
 'conv5_4': <tf.Tensor 'vgg19_1/BiasAdd_15:0' shape=(1, ?, ?, 512) dtype=float32>,
 'relu5_4': <tf.Tensor 'vgg19_1/Relu_15:0' shape=(1, ?, ?, 512) dtype=float32>}

本文提供的测试代码是完成正确的,已经避免了很多使用Vgg19预训练模型的坑操作,比如:给图片添加维度,转换读取图片的的格式等,为什么这么做的详细原因可参考我的另一篇博客:Tensorflow加载Vgg预训练模型的几个注意事项。

到这里,如何使用tensorflow读取Vgg19模型结束了,若是大家有其他疑惑,可在评论区留言,会定时回答。

好了,以上就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入理解Python对Json的解析
Feb 14 Python
Python模糊查询本地文件夹去除文件后缀的实例(7行代码)
Nov 09 Python
python实现俄罗斯方块
Jun 26 Python
Django项目开发中cookies和session的常用操作分析
Jul 03 Python
对PyQt5的输入对话框使用(QInputDialog)详解
Jun 25 Python
python打造爬虫代理池过程解析
Aug 15 Python
python实现俄罗斯方块小游戏
Apr 24 Python
Django中FilePathField字段的用法
May 21 Python
python程序需要编译吗
Jun 19 Python
Python爬取微信小程序通用方法代码实例详解
Sep 29 Python
Python实现定时监测网站运行状态的示例代码
Sep 30 Python
python 合并多个excel中同名的sheet
Jan 22 Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
You might like
利用curl抓取远程页面内容的示例代码
2013/07/23 PHP
Yii2.0多文件上传实例说明
2017/07/24 PHP
JavaScript 学习笔记(十六) js事件
2010/02/01 Javascript
javascript动态改变img的src属性图片不显示的解决方法
2010/10/20 Javascript
js判断浏览器是否支持严格模式的方法
2016/10/04 Javascript
jquery+css3问卷答题卡翻页动画效果示例
2016/10/26 Javascript
JavaScript和JQuery获取DIV值的方法示例
2017/03/07 Javascript
解析jquery easyui tree异步加载子节点问题
2017/03/08 Javascript
Vue.js 2.0学习教程之从基础到组件详解
2017/04/24 Javascript
20行JS代码实现网页刮刮乐效果
2017/06/23 Javascript
vue页面切换到滚动页面显示顶部的实例
2018/03/13 Javascript
vue.js通过路由实现经典的三栏布局实例代码
2018/07/08 Javascript
Vuex的基本概念、项目搭建以及入坑点
2018/11/04 Javascript
ES6 async、await的基本使用方法示例
2020/06/06 Javascript
JavaScript实现鼠标移入随机变换颜色
2020/11/24 Javascript
Django在Win7下的安装及创建项目hello word简明教程
2014/07/14 Python
python使用arp欺骗伪造网关的方法
2015/04/24 Python
详解python字节码
2018/02/07 Python
Python numpy 提取矩阵的某一行或某一列的实例
2018/04/03 Python
Django框架实现的分页demo示例
2019/05/25 Python
python中字符串数组逆序排列方法总结
2019/06/23 Python
Python变量格式化输出实现原理解析
2020/08/06 Python
Python经典五人分鱼实例讲解
2021/01/04 Python
css3 中的新特性加强记忆详解
2016/04/16 HTML / CSS
使用phonegap进行本地存储的实现方法
2017/03/31 HTML / CSS
说一下mysql, oracle等常见数据库的分页实现方案
2012/09/29 面试题
农村婚礼证婚词
2014/01/10 职场文书
新闻发布会主持词
2014/03/28 职场文书
结婚保证书范文
2014/04/29 职场文书
大学班级计划书
2014/04/29 职场文书
祖国在我心中演讲稿300字
2014/05/04 职场文书
团拜会策划方案
2014/06/07 职场文书
员工年终自我评价
2014/09/14 职场文书
医生辞职信范文
2015/03/02 职场文书
用javascript制作qq注册动态页面
2021/04/14 Javascript
详解Vue router路由
2021/11/20 Vue.js