python PyTorch参数初始化和Finetune


Posted in Python onFebruary 11, 2018

前言

这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。

参数初始化

参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。

python PyTorch参数初始化和Finetune

所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型
  if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
  elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
    m.weight.data.fill_(1)
    m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。

局部微调

有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
  param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)

# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调

有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
           model.parameters())

optimizer = torch.optim.SGD([
      {'params': base_params},
      {'params': model.fc.parameters(), 'lr': 1e-3}
      ], lr=1e-2, momentum=0.9)

其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python struct.unpack
Sep 06 Python
python抓取豆瓣图片并自动保存示例学习
Jan 10 Python
日常整理python执行系统命令的常见方法(全)
Oct 22 Python
Python Pexpect库的简单使用方法
Jan 29 Python
django 中QuerySet特性功能详解
Jul 25 Python
Python中list循环遍历删除数据的正确方法
Sep 02 Python
python super的使用方法及实例详解
Sep 25 Python
python使用HTMLTestRunner导出饼图分析报告的方法
Dec 30 Python
Python自省及反射原理实例详解
Jul 06 Python
python报错TypeError: ‘NoneType‘ object is not subscriptable的解决方法
Nov 05 Python
总结python 三种常见的内存泄漏场景
Nov 20 Python
利用Python第三方库实现预测NBA比赛结果
Jun 21 Python
Python装饰器用法示例小结
Feb 11 #Python
python PyTorch预训练示例
Feb 11 #Python
TensorFlow中权重的随机初始化的方法
Feb 11 #Python
python的staticmethod与classmethod实现实例代码
Feb 11 #Python
Python语言的变量认识及操作方法
Feb 11 #Python
利用Opencv中Houghline方法实现直线检测
Feb 11 #Python
tensorflow输出权重值和偏差的方法
Feb 10 #Python
You might like
php简单封装了一些常用JS操作
2007/02/25 PHP
在PHP中读取和写入WORD文档的代码
2008/04/09 PHP
PHP的autoload机制的实现解析
2012/09/15 PHP
php中0,null,empty,空,false,字符串关系的详细介绍
2013/06/20 PHP
10 个经典PHP函数
2013/10/17 PHP
php使用crypt()函数进行加密
2017/06/08 PHP
js正则表达exec与match的区别说明
2014/01/29 Javascript
JavaScript使用yield模拟多线程的方法
2015/03/19 Javascript
jQuery Mobile弹出窗、弹出层知识汇总
2016/01/05 Javascript
基于JavaScript实现全屏透明遮罩div层锁屏效果
2016/01/26 Javascript
深入理解bootstrap框架之入门准备
2016/10/09 Javascript
浅谈jquery页面初始化的4种方式
2016/11/27 Javascript
jQuery使用ajax_动力节点Java学院整理
2017/07/05 jQuery
bootstrap插件treeview实现全选父节点下所有子节点和反选功能
2017/07/21 Javascript
修改UA在PC中访问只能在微信中打开的链接方法
2017/11/27 Javascript
jquery获取transform里的值实现方法
2017/12/12 jQuery
详解Vue+axios+Node+express实现文件上传(用户头像上传)
2018/08/10 Javascript
微信小程序个人中心的列表控件实现代码
2020/04/26 Javascript
Vue基本指令实例图文讲解
2021/02/25 Vue.js
[42:32]DOTA2上海特级锦标赛B组资格赛#2 Fnatic VS Spirit第二局
2016/02/27 DOTA
[49:40]2018DOTA2亚洲邀请赛小组赛 A组加赛 TNC vs Newbee
2018/04/03 DOTA
Python中返回字典键的值的values()方法使用
2015/05/22 Python
全面了解Nginx, WSGI, Flask之间的关系
2018/01/09 Python
PyQt5的PyQtGraph实践系列3之实时数据更新绘制图形
2019/05/13 Python
Python imutils 填充图片周边为黑色的实现
2020/01/19 Python
绝对令人的惊叹的CSS3折叠效果(3D效果)整理
2012/12/30 HTML / CSS
html5与css3小应用
2013/04/03 HTML / CSS
html5 自定义播放器核心代码
2013/12/20 HTML / CSS
绘画设计学生的个人自我评价
2013/09/20 职场文书
音乐学个人的自荐书范文
2013/11/26 职场文书
专升本个人自我评价
2013/12/22 职场文书
高中课程设置方案
2014/05/28 职场文书
不同意离婚代理词
2015/05/23 职场文书
如何使JavaScript休眠或等待
2021/04/27 Javascript
Win11怎样将锁屏账户头像图片改成动画视频
2021/11/21 数码科技
代码复现python目标检测yolo3详解预测
2022/05/06 Python