把vgg-face.mat权重迁移到pytorch模型示例


Posted in Python onDecember 27, 2019

最近使用pytorch时,需要用到一个预训练好的人脸识别模型提取人脸ID特征,想到很多人都在用用vgg-face,但是vgg-face没有pytorch的模型,于是写个vgg-face.mat转到pytorch模型的代码

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu May 10 10:41:40 2018
@author: hy
"""
import torch
import math
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
from scipy.io import loadmat
import scipy.misc as sm
import matplotlib.pyplot as plt
 
class vgg16_face(nn.Module):
  def __init__(self,num_classes=2622):
    super(vgg16_face,self).__init__()
    inplace = True
    self.conv1_1 = nn.Conv2d(3,64,kernel_size=(3,3),stride=(1,1),padding=(1,1))
    self.relu1_1 = nn.ReLU(inplace)
    self.conv1_2 = nn.Conv2d(64,64,kernel_size=(3,3),stride=(1,1),padding=(1,1))
    self.relu1_2 = nn.ReLU(inplace)
    self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)
      
    self.conv2_1 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.relu2_1 = nn.ReLU(inplace)
    self.conv2_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.relu2_2 = nn.ReLU(inplace)
    self.pool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)
      
    self.conv3_1 = nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.relu3_1 = nn.ReLU(inplace)
    self.conv3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.relu3_2 = nn.ReLU(inplace)
    self.conv3_3 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.relu3_3 = nn.ReLU(inplace)
    self.pool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)
      
    self.conv4_1 = nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.relu4_1 = nn.ReLU(inplace)
    self.conv4_2 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.relu4_2 = nn.ReLU(inplace)
    self.conv4_3 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.relu4_3 = nn.ReLU(inplace)
    self.pool4 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)
      
    self.conv5_1 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.relu5_1 = nn.ReLU(inplace)
    self.conv5_2 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.relu5_2 = nn.ReLU(inplace)
    self.conv5_3 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.relu5_3 = nn.ReLU(inplace)
    self.pool5 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False) 
      
    self.fc6 = nn.Linear(in_features=25088, out_features=4096, bias=True)
    self.relu6 = nn.ReLU(inplace)
    self.drop6 = nn.Dropout(p=0.5)
    
    self.fc7 = nn.Linear(in_features=4096, out_features=4096, bias=True)
    self.relu7 = nn.ReLU(inplace)
    self.drop7 = nn.Dropout(p=0.5)
    self.fc8 = nn.Linear(in_features=4096, out_features=num_classes, bias=True)
      
    self._initialize_weights()
  def forward(self,x):
    out = self.conv1_1(x)
    x_conv1 = out
    out = self.relu1_1(out)
    out = self.conv1_2(out)
    out = self.relu1_2(out)
    out = self.pool1(out)
    x_pool1 = out
    
    out = self.conv2_1(out)
    out = self.relu2_1(out)
    out = self.conv2_2(out)
    out = self.relu2_2(out)
    out = self.pool2(out)
    x_pool2 = out
    
    out = self.conv3_1(out)
    out = self.relu3_1(out)
    out = self.conv3_2(out)
    out = self.relu3_2(out)
    out = self.conv3_3(out)
    out = self.relu3_3(out)
    out = self.pool3(out)
    x_pool3 = out
    
    out = self.conv4_1(out)
    out = self.relu4_1(out)
    out = self.conv4_2(out)
    out = self.relu4_2(out)
    out = self.conv4_3(out)
    out = self.relu4_3(out)
    out = self.pool4(out)
    x_pool4 = out
    
    out = self.conv5_1(out)
    out = self.relu5_1(out)
    out = self.conv5_2(out)
    out = self.relu5_2(out)
    out = self.conv5_3(out)
    out = self.relu5_3(out)
    out = self.pool5(out)
    x_pool5 = out
    
    out = out.view(out.size(0),-1)
    
    out = self.fc6(out)
    out = self.relu6(out)
    out = self.fc7(out)
    out = self.relu7(out)
    out = self.fc8(out)
    
    return out, x_pool1, x_pool2, x_pool3, x_pool4, x_pool5
 
  def _initialize_weights(self):
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
          m.bias.data.zero_()
      elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
      elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.01)
        m.bias.data.zero_()
     
def copy(vgglayers, dstlayer,idx):
  layer = vgglayers[0][idx]
  kernel, bias = layer[0]['weights'][0][0]
  if idx in [33,35]: # fc7, fc8
    kernel = kernel.squeeze()
    dstlayer.weight.data.copy_(torch.from_numpy(kernel.transpose([1,0]))) # matrix format: axb -> bxa
  elif idx == 31: # fc6
    kernel = kernel.reshape(-1,4096)
    dstlayer.weight.data.copy_(torch.from_numpy(kernel.transpose([1,0]))) # matrix format: axb -> bxa
  else:
    dstlayer.weight.data.copy_(torch.from_numpy(kernel.transpose([3,2,1,0]))) # matrix format: axbxcxd -> dxcxbxa
  dstlayer.bias.data.copy_(torch.from_numpy(bias.reshape(-1)))
 
def get_vggface(vgg_path):
  """1. define pytorch model"""   
  model = vgg16_face()   
  
  """2. get pre-trained weights and other params"""     
  #vgg_path = "/home/hy/vgg-face.mat" # download from http://www.vlfeat.org/matconvnet/pretrained/
  vgg_weights = loadmat(vgg_path)
  data = vgg_weights
  meta = data['meta']
  classes = meta['classes']
  class_names = classes[0][0]['description'][0][0]
  normalization = meta['normalization']
  average_image = np.squeeze(normalization[0][0]['averageImage'][0][0][0][0])
  image_size = np.squeeze(normalization[0][0]['imageSize'][0][0])
  layers = data['layers']
  # =============================================================================
  # for idx,layer in enumerate(layers[0]):
  #   name = layer[0]['name'][0][0]
  #   print idx,name
  # """
  # 0 conv1_1
  # 1 relu1_1
  # 2 conv1_2
  # 3 relu1_2
  # 4 pool1
  # 5 conv2_1
  # 6 relu2_1
  # 7 conv2_2
  # 8 relu2_2
  # 9 pool2
  # 10 conv3_1
  # 11 relu3_1
  # 12 conv3_2
  # 13 relu3_2
  # 14 conv3_3
  # 15 relu3_3
  # 16 pool3
  # 17 conv4_1
  # 18 relu4_1
  # 19 conv4_2
  # 20 relu4_2
  # 21 conv4_3
  # 22 relu4_3
  # 23 pool4
  # 24 conv5_1
  # 25 relu5_1
  # 26 conv5_2
  # 27 relu5_2
  # 28 conv5_3
  # 29 relu5_3
  # 30 pool5
  # 31 fc6
  # 32 relu6
  # 33 fc7
  # 34 relu7
  # 35 fc8
  # 36 prob
  # """
  # =============================================================================
  
  """3. load weights to pytorch model"""  
  copy(layers,model.conv1_1,0)
  copy(layers,model.conv1_2,2)
  copy(layers,model.conv2_1,5)
  copy(layers,model.conv2_2,7)
  copy(layers,model.conv3_1,10)
  copy(layers,model.conv3_2,12)
  copy(layers,model.conv3_3,14)
  copy(layers,model.conv4_1,17)
  copy(layers,model.conv4_2,19)
  copy(layers,model.conv4_3,21)
  copy(layers,model.conv5_1,24)
  copy(layers,model.conv5_2,26)
  copy(layers,model.conv5_3,28)
  copy(layers,model.fc6,31)
  copy(layers,model.fc7,33)
  copy(layers,model.fc8,35)
  return model,class_names,average_image,image_size
 
if __name__ == '__main__':
  """test""" 
  vgg_path = "/home/hy/vgg-face.mat" # download from http://www.vlfeat.org/matconvnet/pretrained/ 
  model,class_names,average_image,image_size = get_vggface(vgg_path) 
  imgpath = "/home/hy/e/avg_face.jpg"
  img = sm.imread(imgpath)
  img = sm.imresize(img,[image_size[0],image_size[1]])
  input_arr = np.float32(img)#-average_image # h,w,c
  x = torch.from_numpy(input_arr.transpose((2,0,1))) # c,h,w
  avg = torch.from_numpy(average_image) # 
  avg = avg.view(3,1,1).expand(3,224,224)
  x = x - avg
  x = x.contiguous()
  x = x.view(1, x.size(0), x.size(1), x.size(2))
  x = Variable(x)
  out, x_pool1, x_pool2, x_pool3, x_pool4, x_pool5 = model(x)
#  plt.imshow(x_pool1.data.numpy()[0,45]) # plot

以上这篇把vgg-face.mat权重迁移到pytorch模型示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 获取et和excel的版本号
Apr 09 Python
python自然语言编码转换模块codecs介绍
Apr 08 Python
Python实现统计英文单词个数及字符串分割代码
May 28 Python
在类Unix系统上开始Python3编程入门
Aug 20 Python
Python数据结构与算法之常见的分配排序法示例【桶排序与基数排序】
Dec 15 Python
修改默认的pip版本为对应python2.7的方法
Nov 06 Python
Python实现高斯函数的三维显示方法
Dec 29 Python
Python中最好用的命令行参数解析工具(argparse)
Aug 23 Python
利用Python校准本地时间的方法教程
Oct 31 Python
python使用 cx_Oracle 模块进行查询操作示例
Nov 28 Python
详解pycharm连接远程linux服务器的虚拟环境的方法
Nov 13 Python
Python爬虫之Selenium设置元素等待的方法
Dec 04 Python
Pytorch 多维数组运算过程的索引处理方式
Dec 27 #Python
Pytorch 之修改Tensor部分值方式
Dec 27 #Python
pytorch 实现tensor与numpy数组转换
Dec 27 #Python
Numpy与Pytorch 矩阵操作方式
Dec 27 #Python
基于python及pytorch中乘法的使用详解
Dec 27 #Python
pytorch:torch.mm()和torch.matmul()的使用
Dec 27 #Python
pytorch点乘与叉乘示例讲解
Dec 27 #Python
You might like
PHP实现的回溯算法示例
2017/08/15 PHP
PDO::quote讲解
2019/01/29 PHP
同一个表单 根据要求递交到不同页面的实现方法小结
2009/08/05 Javascript
jquery操作checkbox实现全选和取消全选
2014/05/02 Javascript
js实现图片在未加载完成前显示加载中字样
2014/09/03 Javascript
上传文件返回的json数据会被提示下载问题解决方案
2014/12/03 Javascript
基于jquery和svg实现超炫酷的动画特效
2014/12/09 Javascript
JavaScript实现iframe自动高度调整和不同主域名跨域
2016/02/27 Javascript
JS查找字符串中出现次数最多的字符
2016/09/05 Javascript
JavaScript计算值然后把值嵌入到html中的实现方法
2016/10/29 Javascript
JS实现图片上传预览功能
2016/11/21 Javascript
vue学习笔记之指令v-text && v-html && v-bind详解
2017/05/12 Javascript
关于Vue.nextTick()的正确使用方法浅析
2017/08/25 Javascript
vue select二级联动第二级默认选中第一个option值的实例
2018/01/10 Javascript
在 Linux/Unix 中不重启 Vim 而重新加载 .vimrc 文件的流程
2018/03/21 Javascript
vue同步父子组件和异步父子组件的生命周期顺序问题
2018/10/07 Javascript
详解实现一个通用的“划词高亮”在线笔记功能
2019/04/23 Javascript
jquery实现选项卡切换代码实例
2019/05/14 jQuery
微信小程序云开发修改云数据库中的数据方法
2019/05/18 Javascript
eslint+prettier统一代码风格的实现方法
2020/07/22 Javascript
给Python的Django框架下搭建的BLOG添加RSS功能的教程
2015/04/08 Python
shelve  用来持久化任意的Python对象实例代码
2016/10/12 Python
python中获得当前目录和上级目录的实现方法
2017/10/12 Python
python的几种矩阵相乘的公式详解
2019/07/10 Python
Python configparser模块应用过程解析
2020/08/14 Python
python跨文件使用全局变量的实现
2020/11/17 Python
Bootstrap File Input文件上传组件
2020/12/01 HTML / CSS
Crocs波兰官方商店:女鞋、男鞋、童鞋、洞洞鞋
2019/10/08 全球购物
机械专业个人求职自荐信格式
2013/09/21 职场文书
房地产出纳岗位职责
2013/12/01 职场文书
饲料采购员岗位职责
2013/12/19 职场文书
2014年公务员思想汇报范文:全心全意为人民服务
2014/03/06 职场文书
元旦晚会感言
2014/03/12 职场文书
毕业设计致谢词
2015/05/14 职场文书
如何写通讯稿
2015/07/22 职场文书
Python如何利用pandas读取csv数据并绘图
2022/07/07 Python