python PyTorch参数初始化和Finetune


Posted in Python onFebruary 11, 2018

前言

这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。

参数初始化

参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。

python PyTorch参数初始化和Finetune

所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型
  if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
  elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
    m.weight.data.fill_(1)
    m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。

局部微调

有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
  param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)

# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调

有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
           model.parameters())

optimizer = torch.optim.SGD([
      {'params': base_params},
      {'params': model.fc.parameters(), 'lr': 1e-3}
      ], lr=1e-2, momentum=0.9)

其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅要分析Python程序与C程序的结合使用
Apr 07 Python
Python操作列表之List.insert()方法的使用
May 20 Python
Python标准库sched模块使用指南
Jul 06 Python
解决python3中自定义wsgi函数,make_server函数报错的问题
Nov 21 Python
简单了解什么是神经网络
Dec 23 Python
Python enumerate索引迭代代码解析
Jan 19 Python
python编程嵌套函数实例代码
Feb 11 Python
Python批处理删除和重命名文件夹的实例
Jul 11 Python
python 解决flask 图片在线浏览或者直接下载的问题
Jan 09 Python
Python 实现加密过的PDF文件转WORD格式
Feb 04 Python
pycharm使用技巧之自动调整代码格式总结
Nov 04 Python
Django filter动态过滤与排序实现过程解析
Nov 26 Python
Python装饰器用法示例小结
Feb 11 #Python
python PyTorch预训练示例
Feb 11 #Python
TensorFlow中权重的随机初始化的方法
Feb 11 #Python
python的staticmethod与classmethod实现实例代码
Feb 11 #Python
Python语言的变量认识及操作方法
Feb 11 #Python
利用Opencv中Houghline方法实现直线检测
Feb 11 #Python
tensorflow输出权重值和偏差的方法
Feb 10 #Python
You might like
Ajax PHP 边学边练 之三 数据库
2009/11/26 PHP
PHP CodeBase:将时间显示为"刚刚""n分钟/小时前"的方法详解
2013/06/06 PHP
php使用mb_check_encoding检查字符串在指定的编码里是否有效
2013/11/07 PHP
PHP去掉json字符串中的反斜杠\及去掉双引号前的反斜杠
2015/09/30 PHP
Yii框架ACF(accessController)简单权限控制操作示例
2019/04/26 PHP
jQuery开发者都需要知道的5个小技巧
2010/01/08 Javascript
js定义对象或数组直接量时各浏览器对多余逗号的处理(json)
2011/03/05 Javascript
基于pthread_create,readlink,getpid等函数的学习与总结
2013/07/17 Javascript
jquery遍历之parent()和parents()的区别及parentsUntil()方法详解
2013/12/02 Javascript
点评js异步加载的4种方式
2015/12/22 Javascript
VUEJS实战之构建基础并渲染出列表(1)
2016/06/13 Javascript
基于MVC+EasyUI的web开发框架之使用云打印控件C-Lodop打印页面或套打报关运单信息
2016/08/29 Javascript
利用Angularjs实现幻灯片效果
2016/09/07 Javascript
JavaScript生成验证码并实现验证功能
2016/09/24 Javascript
JavaScript和jQuery获取input框的绝对位置实现方法
2016/10/13 Javascript
web打印小结
2017/01/11 Javascript
canvas绘制环形进度条
2017/02/23 Javascript
vue与bootstrap实现时间选择器的示例代码
2017/08/26 Javascript
微信小程序实时聊天WebSocket
2018/07/05 Javascript
JavaScript中变量、指针和引用功能与操作示例
2018/08/04 Javascript
Vue.js@2.6.10更新内置错误处机制Fundebug同步支持相应错误监控
2019/05/13 Javascript
jQuery属性选择器用法实例分析
2019/06/28 jQuery
Layui实现主窗口和Iframe层参数传递
2019/11/14 Javascript
Python计算开方、立方、圆周率,精确到小数点后任意位的方法
2018/07/17 Python
Python 20行简单实现有道在线翻译的详解
2019/05/15 Python
Python全栈之列表数据类型详解
2019/10/01 Python
python 实现rolling和apply函数的向下取值操作
2020/06/08 Python
Bootstrap File Input文件上传组件
2020/12/01 HTML / CSS
VICHY薇姿英国官网:全球专业敏感肌护肤领先品牌
2017/07/04 全球购物
巴西体育用品商店:Lojão dos Esportes
2018/07/21 全球购物
Amcal中文官网:澳洲综合性连锁药房
2019/03/28 全球购物
意向书范文
2014/03/31 职场文书
难忘的一课教学反思
2014/04/30 职场文书
2014年教师业务工作总结
2014/12/19 职场文书
致地震灾区的慰问信
2015/03/23 职场文书
Python进程池与进程锁之语法学习
2022/04/11 Python