pytorch中的weight-initilzation用法


Posted in Python onJune 24, 2020

pytorch中的权值初始化

官方论坛对weight-initilzation的讨论

torch.nn.Module.apply(fn)

torch.nn.Module.apply(fn)
# 递归的调用weights_init函数,遍历nn.Module的submodule作为参数
# 常用来对模型的参数进行初始化
# fn是对参数进行初始化的函数的句柄,fn以nn.Module或者自己定义的nn.Module的子类作为参数
# fn (Module -> None) ? function to be applied to each submodule
# Returns: self
# Return type: Module

例子:

def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02) 
  # m.weight.data是卷积核参数, m.bias.data是偏置项参数
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型实例
netG.apply(weights_init) # 递归的调用weights_init函数,遍历netG的submodule作为参数
#-*-coding:utf-8-*-
import torch
from torch.autograd import Variable

# 对模型参数进行初始化
# 官方论坛链接:https://discuss.pytorch.org/t/weight-initilzation/157/3

# 方法一
# 单独定义一个weights_init函数,输入参数是m(torch.nn.module或者自己定义的继承nn.module的子类)
# 然后使用net.apply()进行参数初始化
# m.__class__.__name__ 获得nn.module的名字
# https://github.com/pytorch/examples/blob/master/dcgan/main.py#L90-L96
def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02)
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型实例
netG.apply(weights_init) # 递归的调用weights_init函数,遍历netG的submodule作为参数

# function to be applied to each submodule

# 方法二
# 1. 使用net.modules()遍历模型中的网络层的类型 2. 对其中的m层的weigth.data(tensor)部分进行初始化操作
# Another initialization example from PyTorch Vision resnet implementation.
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L112-L118
class ResNet(nn.Module):
 def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(ResNet, self).__init__()
  self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
        bias=False)
  self.bn1 = nn.BatchNorm2d(64)
  self.relu = nn.ReLU(inplace=True)
  self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  self.layer1 = self._make_layer(block, 64, layers[0])
  self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
  self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
  self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
  self.avgpool = nn.AvgPool2d(7, stride=1)
  self.fc = nn.Linear(512 * block.expansion, num_classes)
  # 权值参数初始化
  for m in self.modules():
   if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
   elif isinstance(m, nn.BatchNorm2d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()

# 方法三
# 自己知道网络中参数的顺序和类型, 然后将参数依次读取出来,调用torch.nn.init中的方法进行初始化
net = AlexNet(2)
params = list(net.parameters()) # params依次为Conv2d参数和Bias参数
# 或者
conv1Params = list(net.conv1.parameters())
# 其中,conv1Params[0]表示卷积核参数, conv1Params[1]表示bias项参数
# 然后使用torch.nn.init中函数进行初始化
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.constant(tensor, 0)

# net.modules()迭代的返回: AlexNet,Sequential,Conv2d,ReLU,MaxPool2d,LRN,AvgPool3d....,Conv2d,...,Conv2d,...,Linear,
# 这里,只有Conv2d和Linear才有参数
# net.children()只返回实际存在的子模块: Sequential,Sequential,Sequential,Sequential,Sequential,Sequential,Sequential,Linear

# 附AlexNet的定义
class AlexNet(nn.Module):
 def __init__(self, num_classes = 2): # 默认为两类,猫和狗
#   super().__init__() # python3
  super(AlexNet, self).__init__()
  # 开始构建AlexNet网络模型,5层卷积,3层全连接层
  # 5层卷积层
  self.conv1 = nn.Sequential(
   nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2),
   LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
  )
  self.conv2 = nn.Sequential(
   nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, groups=2, padding=2),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2),
   LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
  )
  self.conv3 = nn.Sequential(
   nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
   nn.ReLU(inplace=True)
  )
  self.conv4 = nn.Sequential(
   nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
   nn.ReLU(inplace=True)
  )
  self.conv5 = nn.Sequential(
   nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2)
  )
  # 3层全连接层
  # 前向计算的时候,最开始输入需要进行view操作,将3D的tensor变为1D
  self.fc6 = nn.Sequential(
   nn.Linear(in_features=6*6*256, out_features=4096),
   nn.ReLU(inplace=True),
   nn.Dropout()
  )
  self.fc7 = nn.Sequential(
   nn.Linear(in_features=4096, out_features=4096),
   nn.ReLU(inplace=True),
   nn.Dropout()
  )
  self.fc8 = nn.Linear(in_features=4096, out_features=num_classes)

 def forward(self, x):
  x = self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(x)))))
  x = x.view(-1, 6*6*256)
  x = self.fc8(self.fc7(self.fc6(x)))
  return x

补充知识:pytorch Load部分weights

我们从网上down下来的模型与我们的模型可能就存在一个层的差异,此时我们就需要重新训练所有的参数是不合理的。

因此我们可以加载相同的参数,而忽略不同的参数,代码如下:

pretrained_dict = torch.load(“model.pth”)
  model_dict = et.state_dict()
  pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
  model_dict.update(pretrained_dict)
  net.load_state_dict(model_dict)

以上这篇pytorch中的weight-initilzation用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的装饰器用法详解
Jan 14 Python
Python闭包实现计数器的方法
May 05 Python
Python编程实现的图片识别功能示例
Aug 03 Python
浅析Python中的赋值和深浅拷贝
Aug 15 Python
解决python3 网络请求路径包含中文的问题
May 10 Python
python根据url地址下载小文件的实例
Dec 18 Python
Pandas分组与排序的实现
Jul 23 Python
python实现桌面托盘气泡提示
Jul 29 Python
python 怎样将dataframe中的字符串日期转化为日期的方法
Sep 26 Python
python处理excel绘制雷达图
Oct 18 Python
Python获取android设备cpu和内存占用情况
Nov 15 Python
python爬虫多次请求超时的几种重试方法(6种)
Dec 01 Python
pytorch查看模型weight与grad方式
Jun 24 #Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
You might like
Yii2简单实现给表单添加验证码的方法
2016/07/18 PHP
奇妙的js
2007/09/24 Javascript
jQuery 行级解析读取XML文件(附源码)
2009/10/12 Javascript
基于jquery的一个OutlookBar类,动态创建导航条
2010/11/19 Javascript
JS添加删除一组文本框并对输入信息加以验证判断其正确性
2013/04/11 Javascript
如何使用jQuery来处理图片坏链具体实现步骤
2013/05/02 Javascript
jquery $.each() 使用小探
2013/08/23 Javascript
JS使用ajax方法获取指定url的head信息中指定字段值的方法
2015/03/24 Javascript
js给selected添加options的方法
2015/05/06 Javascript
javascript设计模式之对象工厂函数与构造函数详解
2015/07/30 Javascript
各式各样的导航条效果css3结合jquery代码实现
2016/09/17 Javascript
jQuery简单获取DIV和A标签元素位置的方法
2017/02/07 Javascript
JavaScript运动框架 解决防抖动问题、悬浮对联(二)
2017/05/17 Javascript
angularjs实现上拉加载和下拉刷新数据功能
2017/06/12 Javascript
详解React Native顶|底部导航使用小技巧
2017/09/14 Javascript
React Native 截屏组件的示例代码
2017/12/06 Javascript
vue-cli 3.x 修改dist路径的方法
2018/09/19 Javascript
jQuery操作动画完整实例分析
2020/01/10 jQuery
通过实例了解Javascript柯里化流程
2020/03/03 Javascript
js实现登录时记住密码的方法分析
2020/04/05 Javascript
[53:15]2018DOTA2亚洲邀请赛3月29日 小组赛A组 KG VS OG
2018/03/30 DOTA
python基础教程之分支、循环简单用法
2016/06/16 Python
解决PyCharm的Python.exe已经停止工作的问题
2018/11/29 Python
Flask框架踩坑之ajax跨域请求实现
2019/02/22 Python
对Django外键关系的描述
2019/07/26 Python
Python使用Tkinter实现滚动抽奖器效果
2020/01/06 Python
用python实现一个简单计算器(完整DEMO)
2020/10/14 Python
基于Python的接口自动化unittest测试框架和ddt数据驱动详解
2021/01/27 Python
HTML5 Web Database 数据库的SQL语句的使用方法
2012/12/09 HTML / CSS
HTML5 DeviceOrientation实现手机网站摇一摇功能代码实例
2015/04/24 HTML / CSS
日本索尼音乐商店:Sony Music Shop
2018/07/17 全球购物
Conforama瑞士:家具、厨房、电器、装饰
2020/09/06 全球购物
MYSQL基础面试题
2012/05/13 面试题
车间主管岗位职责
2013/11/14 职场文书
springboot @ConfigurationProperties和@PropertySource的区别
2021/06/11 Java/Android
关于Redis的主从复制及哨兵问题
2022/06/16 Redis