pytorch中的weight-initilzation用法


Posted in Python onJune 24, 2020

pytorch中的权值初始化

官方论坛对weight-initilzation的讨论

torch.nn.Module.apply(fn)

torch.nn.Module.apply(fn)
# 递归的调用weights_init函数,遍历nn.Module的submodule作为参数
# 常用来对模型的参数进行初始化
# fn是对参数进行初始化的函数的句柄,fn以nn.Module或者自己定义的nn.Module的子类作为参数
# fn (Module -> None) ? function to be applied to each submodule
# Returns: self
# Return type: Module

例子:

def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02) 
  # m.weight.data是卷积核参数, m.bias.data是偏置项参数
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型实例
netG.apply(weights_init) # 递归的调用weights_init函数,遍历netG的submodule作为参数
#-*-coding:utf-8-*-
import torch
from torch.autograd import Variable

# 对模型参数进行初始化
# 官方论坛链接:https://discuss.pytorch.org/t/weight-initilzation/157/3

# 方法一
# 单独定义一个weights_init函数,输入参数是m(torch.nn.module或者自己定义的继承nn.module的子类)
# 然后使用net.apply()进行参数初始化
# m.__class__.__name__ 获得nn.module的名字
# https://github.com/pytorch/examples/blob/master/dcgan/main.py#L90-L96
def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02)
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型实例
netG.apply(weights_init) # 递归的调用weights_init函数,遍历netG的submodule作为参数

# function to be applied to each submodule

# 方法二
# 1. 使用net.modules()遍历模型中的网络层的类型 2. 对其中的m层的weigth.data(tensor)部分进行初始化操作
# Another initialization example from PyTorch Vision resnet implementation.
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L112-L118
class ResNet(nn.Module):
 def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(ResNet, self).__init__()
  self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
        bias=False)
  self.bn1 = nn.BatchNorm2d(64)
  self.relu = nn.ReLU(inplace=True)
  self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  self.layer1 = self._make_layer(block, 64, layers[0])
  self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
  self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
  self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
  self.avgpool = nn.AvgPool2d(7, stride=1)
  self.fc = nn.Linear(512 * block.expansion, num_classes)
  # 权值参数初始化
  for m in self.modules():
   if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
   elif isinstance(m, nn.BatchNorm2d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()

# 方法三
# 自己知道网络中参数的顺序和类型, 然后将参数依次读取出来,调用torch.nn.init中的方法进行初始化
net = AlexNet(2)
params = list(net.parameters()) # params依次为Conv2d参数和Bias参数
# 或者
conv1Params = list(net.conv1.parameters())
# 其中,conv1Params[0]表示卷积核参数, conv1Params[1]表示bias项参数
# 然后使用torch.nn.init中函数进行初始化
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.constant(tensor, 0)

# net.modules()迭代的返回: AlexNet,Sequential,Conv2d,ReLU,MaxPool2d,LRN,AvgPool3d....,Conv2d,...,Conv2d,...,Linear,
# 这里,只有Conv2d和Linear才有参数
# net.children()只返回实际存在的子模块: Sequential,Sequential,Sequential,Sequential,Sequential,Sequential,Sequential,Linear

# 附AlexNet的定义
class AlexNet(nn.Module):
 def __init__(self, num_classes = 2): # 默认为两类,猫和狗
#   super().__init__() # python3
  super(AlexNet, self).__init__()
  # 开始构建AlexNet网络模型,5层卷积,3层全连接层
  # 5层卷积层
  self.conv1 = nn.Sequential(
   nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2),
   LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
  )
  self.conv2 = nn.Sequential(
   nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, groups=2, padding=2),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2),
   LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
  )
  self.conv3 = nn.Sequential(
   nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
   nn.ReLU(inplace=True)
  )
  self.conv4 = nn.Sequential(
   nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
   nn.ReLU(inplace=True)
  )
  self.conv5 = nn.Sequential(
   nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2)
  )
  # 3层全连接层
  # 前向计算的时候,最开始输入需要进行view操作,将3D的tensor变为1D
  self.fc6 = nn.Sequential(
   nn.Linear(in_features=6*6*256, out_features=4096),
   nn.ReLU(inplace=True),
   nn.Dropout()
  )
  self.fc7 = nn.Sequential(
   nn.Linear(in_features=4096, out_features=4096),
   nn.ReLU(inplace=True),
   nn.Dropout()
  )
  self.fc8 = nn.Linear(in_features=4096, out_features=num_classes)

 def forward(self, x):
  x = self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(x)))))
  x = x.view(-1, 6*6*256)
  x = self.fc8(self.fc7(self.fc6(x)))
  return x

补充知识:pytorch Load部分weights

我们从网上down下来的模型与我们的模型可能就存在一个层的差异,此时我们就需要重新训练所有的参数是不合理的。

因此我们可以加载相同的参数,而忽略不同的参数,代码如下:

pretrained_dict = torch.load(“model.pth”)
  model_dict = et.state_dict()
  pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
  model_dict.update(pretrained_dict)
  net.load_state_dict(model_dict)

以上这篇pytorch中的weight-initilzation用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python进阶教程之函数参数的多种传递方法
Aug 30 Python
让python 3支持mysqldb的解决方法
Feb 14 Python
基于python代码实现简易滤除数字的方法
Jul 17 Python
Python aiohttp百万并发极限测试实例分析
Oct 26 Python
python GUI库图形界面开发之PyQt5简单绘图板实例与代码分析
Mar 08 Python
PYcharm 激活方法(推荐)
Mar 23 Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 Python
python等待10秒执行下一命令的方法
Jul 19 Python
使用Python将xmind脑图转成excel用例的实现代码(一)
Oct 12 Python
Python通过递归函数输出嵌套列表元素
Oct 15 Python
如何利用Python matplotlib绘制雷达图
Dec 21 Python
golang特有程序结构入门教程
Jun 02 Python
pytorch查看模型weight与grad方式
Jun 24 #Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
You might like
MOTOROLA 摩托罗拉 MODEL 66-XI五灯中波收音机
2021/03/02 无线电
php 在线打包_支持子目录
2008/06/28 PHP
学习php设计模式 php实现模板方法模式
2015/12/08 PHP
php正则表达式验证(邮件地址、Url地址、电话号码、邮政编码)
2016/03/14 PHP
php获取小程序码的实现代码(B类接口)
2020/06/13 PHP
取得一定长度的内容,处理中文
2006/12/20 Javascript
Prototype 学习 Prototype对象
2009/07/12 Javascript
Jquery ui css framework
2010/06/28 Javascript
javascript实用方法总结
2015/02/06 Javascript
jQuery提示插件alertify使用指南
2015/04/21 Javascript
jQuery实现图片渐入渐出切换展示效果
2015/08/15 Javascript
jquery+html5烂漫爱心表白动画代码分享
2015/08/24 Javascript
JS实现1000以内被3或5整除的数字之和
2016/02/18 Javascript
jquery实现图片上传前本地预览功能
2016/05/10 Javascript
iframe中使用jquery进行查找的方法【案例分析】
2016/06/17 Javascript
jQuery实现指定区域外单击关闭指定层的方法【经典】
2016/06/22 Javascript
认识jQuery的Promise的具体使用方法
2017/10/10 jQuery
Element-UI踩坑之Pagination组件的使用
2018/10/29 Javascript
jquery实现自定义树形表格的方法【自定义树形结构table】
2019/07/12 jQuery
vue子路由跳转实现tab选项卡
2019/07/24 Javascript
vue-cli3使用mock数据的方法分析
2020/03/16 Javascript
python dict乱码如何解决
2020/06/07 Python
Pycharm新手使用教程(图文详解)
2020/09/17 Python
python遍历路径破解表单的示例
2020/11/21 Python
澳大利亚女士时装在线:Rockmans
2018/09/26 全球购物
阿迪达斯希腊官方网上商店:adidas希腊
2019/04/06 全球购物
美国轻奢时尚购物网站:REVOLVE(支持中文)
2020/07/18 全球购物
Woods官网:加拿大最古老、最受尊敬的户外品牌之一
2020/09/12 全球购物
财务会计专业毕业生自荐信
2013/10/02 职场文书
演讲稿怎么写才完美
2014/01/02 职场文书
文化宣传方案
2014/03/13 职场文书
学校师德师风整改措施
2014/10/27 职场文书
公司岗位说明书
2015/10/08 职场文书
党员公开承诺书2016
2016/03/24 职场文书
Anaconda安装pytorch及配置PyCharm 2021环境
2021/06/04 Python
Feign调用传输文件异常的解决
2021/06/24 Java/Android