pytorch中的weight-initilzation用法


Posted in Python onJune 24, 2020

pytorch中的权值初始化

官方论坛对weight-initilzation的讨论

torch.nn.Module.apply(fn)

torch.nn.Module.apply(fn)
# 递归的调用weights_init函数,遍历nn.Module的submodule作为参数
# 常用来对模型的参数进行初始化
# fn是对参数进行初始化的函数的句柄,fn以nn.Module或者自己定义的nn.Module的子类作为参数
# fn (Module -> None) ? function to be applied to each submodule
# Returns: self
# Return type: Module

例子:

def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02) 
  # m.weight.data是卷积核参数, m.bias.data是偏置项参数
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型实例
netG.apply(weights_init) # 递归的调用weights_init函数,遍历netG的submodule作为参数
#-*-coding:utf-8-*-
import torch
from torch.autograd import Variable

# 对模型参数进行初始化
# 官方论坛链接:https://discuss.pytorch.org/t/weight-initilzation/157/3

# 方法一
# 单独定义一个weights_init函数,输入参数是m(torch.nn.module或者自己定义的继承nn.module的子类)
# 然后使用net.apply()进行参数初始化
# m.__class__.__name__ 获得nn.module的名字
# https://github.com/pytorch/examples/blob/master/dcgan/main.py#L90-L96
def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02)
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型实例
netG.apply(weights_init) # 递归的调用weights_init函数,遍历netG的submodule作为参数

# function to be applied to each submodule

# 方法二
# 1. 使用net.modules()遍历模型中的网络层的类型 2. 对其中的m层的weigth.data(tensor)部分进行初始化操作
# Another initialization example from PyTorch Vision resnet implementation.
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L112-L118
class ResNet(nn.Module):
 def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(ResNet, self).__init__()
  self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
        bias=False)
  self.bn1 = nn.BatchNorm2d(64)
  self.relu = nn.ReLU(inplace=True)
  self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  self.layer1 = self._make_layer(block, 64, layers[0])
  self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
  self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
  self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
  self.avgpool = nn.AvgPool2d(7, stride=1)
  self.fc = nn.Linear(512 * block.expansion, num_classes)
  # 权值参数初始化
  for m in self.modules():
   if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
   elif isinstance(m, nn.BatchNorm2d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()

# 方法三
# 自己知道网络中参数的顺序和类型, 然后将参数依次读取出来,调用torch.nn.init中的方法进行初始化
net = AlexNet(2)
params = list(net.parameters()) # params依次为Conv2d参数和Bias参数
# 或者
conv1Params = list(net.conv1.parameters())
# 其中,conv1Params[0]表示卷积核参数, conv1Params[1]表示bias项参数
# 然后使用torch.nn.init中函数进行初始化
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.constant(tensor, 0)

# net.modules()迭代的返回: AlexNet,Sequential,Conv2d,ReLU,MaxPool2d,LRN,AvgPool3d....,Conv2d,...,Conv2d,...,Linear,
# 这里,只有Conv2d和Linear才有参数
# net.children()只返回实际存在的子模块: Sequential,Sequential,Sequential,Sequential,Sequential,Sequential,Sequential,Linear

# 附AlexNet的定义
class AlexNet(nn.Module):
 def __init__(self, num_classes = 2): # 默认为两类,猫和狗
#   super().__init__() # python3
  super(AlexNet, self).__init__()
  # 开始构建AlexNet网络模型,5层卷积,3层全连接层
  # 5层卷积层
  self.conv1 = nn.Sequential(
   nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2),
   LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
  )
  self.conv2 = nn.Sequential(
   nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, groups=2, padding=2),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2),
   LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
  )
  self.conv3 = nn.Sequential(
   nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
   nn.ReLU(inplace=True)
  )
  self.conv4 = nn.Sequential(
   nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
   nn.ReLU(inplace=True)
  )
  self.conv5 = nn.Sequential(
   nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2)
  )
  # 3层全连接层
  # 前向计算的时候,最开始输入需要进行view操作,将3D的tensor变为1D
  self.fc6 = nn.Sequential(
   nn.Linear(in_features=6*6*256, out_features=4096),
   nn.ReLU(inplace=True),
   nn.Dropout()
  )
  self.fc7 = nn.Sequential(
   nn.Linear(in_features=4096, out_features=4096),
   nn.ReLU(inplace=True),
   nn.Dropout()
  )
  self.fc8 = nn.Linear(in_features=4096, out_features=num_classes)

 def forward(self, x):
  x = self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(x)))))
  x = x.view(-1, 6*6*256)
  x = self.fc8(self.fc7(self.fc6(x)))
  return x

补充知识:pytorch Load部分weights

我们从网上down下来的模型与我们的模型可能就存在一个层的差异,此时我们就需要重新训练所有的参数是不合理的。

因此我们可以加载相同的参数,而忽略不同的参数,代码如下:

pretrained_dict = torch.load(“model.pth”)
  model_dict = et.state_dict()
  pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
  model_dict.update(pretrained_dict)
  net.load_state_dict(model_dict)

以上这篇pytorch中的weight-initilzation用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 随机数生成的代码的详细分析
May 15 Python
python写的ARP攻击代码实例
Jun 04 Python
举例讲解Linux系统下Python调用系统Shell的方法
Nov 07 Python
Windows上使用virtualenv搭建Python+Flask开发环境
Jun 07 Python
python利用MethodType绑定方法到类示例代码
Aug 27 Python
python实现生命游戏的示例代码(Game of Life)
Jan 24 Python
python生成不重复随机数和对list乱序的解决方法
Apr 09 Python
python 正确保留多位小数的实例
Jul 16 Python
解决PyCharm不运行脚本,而是运行单元测试的问题
Jan 17 Python
python 机器学习之支持向量机非线性回归SVR模型
Jun 26 Python
python 实时调取摄像头的示例代码
Nov 25 Python
Python Selenium异常处理的实例分析
Feb 28 Python
pytorch查看模型weight与grad方式
Jun 24 #Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
You might like
现磨咖啡骗局!现磨咖啡=新鲜咖啡?现磨咖啡背后的猫腻你不懂!
2019/03/28 冲泡冲煮
Windows下的PHP5.0安装配制详解
2006/09/05 PHP
浅谈php中mysql与mysqli的区别分析
2013/06/10 PHP
dedecms函数分享之获取某一栏目所有子栏目
2014/05/19 PHP
PHP中几个可以提高运行效率的代码写法、技巧分享
2014/08/21 PHP
一个php生成16位随机数的代码(两种方法)
2014/09/16 PHP
PHP用户管理中常用接口调用实例及解析(含源码)
2017/03/09 PHP
PHP编程计算两个时间段是否有交集的实现方法(不算边界重叠)
2017/05/30 PHP
php使用imagecopymerge()函数创建半透明水印
2018/01/25 PHP
JavaScript中如何通过arguments对象实现对象的重载
2014/05/12 Javascript
JQuery中$.each 和$(selector).each()的区别详解
2015/03/13 Javascript
javascript实现tab切换的四种方法
2015/11/05 Javascript
详细解读Jquery各Ajax函数($.get(),$.post(),$.ajax(),$.getJSON())
2016/08/15 Javascript
微信小程序 tabs选项卡效果的实现
2017/01/05 Javascript
es7学习教程之fetch解决异步嵌套问题的方法示例
2017/07/21 Javascript
在vue中实现简单页面逆传值的方法
2017/11/27 Javascript
JS中移除非数字最多保留一位小数
2018/05/09 Javascript
js中获取URL参数的共用方法getRequest()方法实例详解
2018/10/24 Javascript
vue 表单验证按钮事件交由父组件触发的方法
2018/12/17 Javascript
vue项目中在外部js文件中直接调用vue实例的方法比如说this
2019/04/28 Javascript
vue.js+ElementUI实现进度条提示密码强度效果
2020/01/18 Javascript
微信小程序开发之获取用户手机号码(php接口解密)
2020/05/17 Javascript
[45:15]Optic vs VP 2018国际邀请赛淘汰赛BO3 第一场 8.24
2018/08/25 DOTA
Python中常见的数据类型小结
2015/08/29 Python
玩转python爬虫之正则表达式
2016/02/17 Python
python list格式数据excel导出方法
2018/10/31 Python
django基于存储在前端的token用户认证解析
2019/08/06 Python
Python更换pip源方法过程解析
2020/05/19 Python
python进度条显示-tqmd模块的实现示例
2020/08/23 Python
CSS3制作ajax loader icon实现思路及代码
2013/08/25 HTML / CSS
canvas与html5实现视频截图功能示例
2016/12/15 HTML / CSS
Hotter Shoes英国官网:英伦风格,舒适的鞋子
2017/12/28 全球购物
Bose英国官方网站:美国知名音响品牌
2020/01/26 全球购物
安踏广告词改编版
2014/03/21 职场文书
颁奖晚会主持词
2014/03/25 职场文书
八项规定个人对照检查材料思想汇报
2014/09/25 职场文书