python PyTorch参数初始化和Finetune


Posted in Python onFebruary 11, 2018

前言

这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。

参数初始化

参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。

python PyTorch参数初始化和Finetune

所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型
  if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
  elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
    m.weight.data.fill_(1)
    m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。

局部微调

有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
  param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)

# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调

有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
           model.parameters())

optimizer = torch.optim.SGD([
      {'params': base_params},
      {'params': model.fc.parameters(), 'lr': 1e-3}
      ], lr=1e-2, momentum=0.9)

其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于scrapy实现的简单蜘蛛采集程序
Apr 17 Python
Python中字典映射类型的学习教程
Aug 20 Python
Python采用Django制作简易的知乎日报API
Aug 03 Python
python僵尸进程产生的原因
Jul 21 Python
对python实现模板生成脚本的方法详解
Jan 30 Python
Python实现的爬取小说爬虫功能示例
Mar 30 Python
详解用pyecharts Geo实现动态数据热力图城市找不到问题解决
Jun 26 Python
在linux系统下安装python librtmp包的实现方法
Jul 22 Python
python对XML文件的操作实现代码
Mar 27 Python
Django集成MongoDB实现过程解析
Dec 01 Python
使用Python下载抖音各大V视频的思路详解
Feb 06 Python
python使用pywinauto驱动微信客户端实现公众号爬虫
May 19 Python
Python装饰器用法示例小结
Feb 11 #Python
python PyTorch预训练示例
Feb 11 #Python
TensorFlow中权重的随机初始化的方法
Feb 11 #Python
python的staticmethod与classmethod实现实例代码
Feb 11 #Python
Python语言的变量认识及操作方法
Feb 11 #Python
利用Opencv中Houghline方法实现直线检测
Feb 11 #Python
tensorflow输出权重值和偏差的方法
Feb 10 #Python
You might like
PHP 5.0对象模型深度探索之属性和方法
2008/03/27 PHP
深入解析php中的foreach函数
2013/08/31 PHP
ThinkPHP中的三大自动简介
2014/08/22 PHP
php include类文件超时问题处理
2015/02/06 PHP
PHP直接修改表内容DataGrid功能实现代码
2015/09/24 PHP
php中的依赖注入实例详解
2019/08/14 PHP
PHP实现简易用户登录系统
2020/07/10 PHP
JS去除字符串的空格增强版(可以去除中间的空格)
2009/08/26 Javascript
jquery 图片截取工具jquery.imagecropper.js
2010/04/09 Javascript
jquery教程限制文本框只能输入数字和小数点示例分享
2014/01/13 Javascript
js点击事件链接的问题解决
2014/04/25 Javascript
JQuery记住用户名密码实现下次自动登录功能
2015/04/27 Javascript
浅析函数声明和函数表达式——函数声明的声明提前
2016/05/03 Javascript
node.js 中国天气预报 简单实现
2016/06/06 Javascript
AngularJS HTML DOM详解及示例代码
2016/08/17 Javascript
JS使用正则表达式实现常用的表单验证功能分析
2020/04/30 Javascript
python网络编程之UDP通信实例(含服务器端、客户端、UDP广播例子)
2014/04/25 Python
跟老齐学Python之眼花缭乱的运算符
2014/09/14 Python
Python3.6基于正则实现的计算器示例【无优化简单注释版】
2018/06/14 Python
python 实现方阵的对角线遍历示例
2019/11/29 Python
使用 Python 处理3万多条数据只要几秒钟
2020/01/19 Python
浅谈Pytorch中的自动求导函数backward()所需参数的含义
2020/02/29 Python
Python使用struct处理二进制(pack和unpack用法)
2020/11/12 Python
CSS3五个技巧给你的网站带来出色的效果
2009/04/02 HTML / CSS
凯蒂·佩里个人女鞋品牌:Katy Perry Collections
2019/04/04 全球购物
外企测试工程师面试题
2015/02/01 面试题
幼儿园家长会欢迎词
2014/01/09 职场文书
建筑设计专业求职自我评价
2014/03/02 职场文书
春节请假条
2014/04/11 职场文书
2014教师党员自我评议(5篇)
2014/09/20 职场文书
财务工作失职检讨书
2014/11/21 职场文书
红色电影观后感
2015/06/18 职场文书
承兑汇票延期证明
2015/06/23 职场文书
2016年党校科级干部培训班学习心得体会
2016/01/06 职场文书
写作技巧:优秀文案必备的3种结构
2019/08/19 职场文书
导游词之桂林山水
2019/09/20 职场文书