python PyTorch参数初始化和Finetune


Posted in Python onFebruary 11, 2018

前言

这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。

参数初始化

参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。

python PyTorch参数初始化和Finetune

所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型
  if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
  elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
    m.weight.data.fill_(1)
    m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。

局部微调

有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
  param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)

# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调

有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
           model.parameters())

optimizer = torch.optim.SGD([
      {'params': base_params},
      {'params': model.fc.parameters(), 'lr': 1e-3}
      ], lr=1e-2, momentum=0.9)

其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的引用和拷贝浅析
Nov 22 Python
Python中的pprint折腾记
Jan 21 Python
Python的Flask框架中实现简单的登录功能的教程
Apr 20 Python
Python urllib、urllib2、httplib抓取网页代码实例
May 09 Python
详解python并发获取snmp信息及性能测试
Mar 27 Python
Python实现PS滤镜碎片特效功能示例
Jan 24 Python
修复 Django migration 时遇到的问题解决
Jun 14 Python
Python进程间通信 multiProcessing Queue队列实现详解
Sep 23 Python
spyder 在控制台(console)执行python文件,debug python程序方式
Apr 20 Python
浅谈多卡服务器下隐藏部分 GPU 和 TensorFlow 的显存使用设置
Jun 30 Python
解决Django响应JsonResponse返回json格式数据报错问题
Aug 09 Python
如何完美的建立一个python项目
Oct 09 Python
Python装饰器用法示例小结
Feb 11 #Python
python PyTorch预训练示例
Feb 11 #Python
TensorFlow中权重的随机初始化的方法
Feb 11 #Python
python的staticmethod与classmethod实现实例代码
Feb 11 #Python
Python语言的变量认识及操作方法
Feb 11 #Python
利用Opencv中Houghline方法实现直线检测
Feb 11 #Python
tensorflow输出权重值和偏差的方法
Feb 10 #Python
You might like
php设计模式之正面模式实例分析【星际争霸游戏案例】
2020/03/24 PHP
Javascript注入技巧
2007/06/22 Javascript
js控制滚动条缓慢滚动到顶部实现代码
2013/03/20 Javascript
JS 作用域与作用域链详解
2015/04/07 Javascript
微信小程序 wxapp地图 map详解
2016/10/31 Javascript
jQuery学习笔记之入门
2016/12/14 Javascript
jQuery实现文档树效果
2017/02/20 Javascript
JS实现求数组起始项到终止项之和的方法【基于数组扩展函数】
2017/06/13 Javascript
JS按条件 serialize() 对应标签的使用方法
2017/07/24 Javascript
js字符串处理之绝妙的代码
2019/04/05 Javascript
浅谈ECMAScript 中的Array类型
2019/06/10 Javascript
layui实现图片虚拟路径上传,预览和删除的例子
2019/09/25 Javascript
nodejs对mongodb数据库的增加修删该查实例代码
2020/01/05 NodeJs
[46:48]DOTA2上海特级锦标赛A组小组赛#2 Secret VS CDEC第三局
2016/02/25 DOTA
Python判断字符串与大小写转换
2015/06/08 Python
python爬虫的工作原理
2017/03/05 Python
Python升级导致yum、pip报错的解决方法
2017/09/06 Python
Python 获取中文字拼音首个字母的方法
2018/11/28 Python
PyQt5组件读取参数的实例
2019/06/25 Python
python如何实现异步调用函数执行
2019/07/08 Python
Python 获取 datax 执行结果保存到数据库的方法
2019/07/11 Python
django 中的聚合函数,分组函数,F 查询,Q查询
2019/07/25 Python
基于python实现模拟数据结构模型
2020/06/12 Python
世界上获奖最多的手机镜头:Olloclip
2018/03/03 全球购物
美国伊甸园兄弟种子公司:Eden Brothers
2018/07/01 全球购物
美国摩托车头盔、零件、齿轮及配件商店:Cycle Gear
2019/06/12 全球购物
简述数组与指针的区别
2014/01/02 面试题
介绍一下Linux内核的排队自旋锁
2014/01/04 面试题
外包公司软件测试工程师
2014/11/01 面试题
什么时候用assert
2015/05/08 面试题
初一地理教学反思
2014/01/16 职场文书
服务承诺书格式
2014/05/21 职场文书
承诺书应该怎么写?
2019/09/10 职场文书
Redis模仿手机验证码发送的实现示例
2021/11/02 Redis
MybatisPlus EntityWrapper如何自定义SQL
2022/03/22 Java/Android
零基础学java之带返回值的方法的定义和调用
2022/04/10 Java/Android