python PyTorch参数初始化和Finetune


Posted in Python onFebruary 11, 2018

前言

这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。

参数初始化

参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。

python PyTorch参数初始化和Finetune

所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型
  if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
  elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
    m.weight.data.fill_(1)
    m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。

局部微调

有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
  param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)

# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调

有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
           model.parameters())

optimizer = torch.optim.SGD([
      {'params': base_params},
      {'params': model.fc.parameters(), 'lr': 1e-3}
      ], lr=1e-2, momentum=0.9)

其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中请使用isinstance()判断变量类型
Aug 25 Python
详解python 字符串和日期之间转换 StringAndDate
May 04 Python
Python实现绘制双柱状图并显示数值功能示例
Jun 23 Python
python使用matplotlib绘制热图
Nov 07 Python
pandas 数据归一化以及行删除例程的方法
Nov 10 Python
使用python判断你是青少年还是老年人
Nov 29 Python
对python 多个分隔符split 的实例详解
Dec 20 Python
关于Python字符串显示u...的解决方式
Mar 06 Python
python小白切忌乱用表达式
May 29 Python
python的链表基础知识点
Sep 13 Python
OpenCV灰度化之后图片为绿色的解决
Dec 01 Python
python 如何读、写、解析CSV文件
Mar 03 Python
Python装饰器用法示例小结
Feb 11 #Python
python PyTorch预训练示例
Feb 11 #Python
TensorFlow中权重的随机初始化的方法
Feb 11 #Python
python的staticmethod与classmethod实现实例代码
Feb 11 #Python
Python语言的变量认识及操作方法
Feb 11 #Python
利用Opencv中Houghline方法实现直线检测
Feb 11 #Python
tensorflow输出权重值和偏差的方法
Feb 10 #Python
You might like
解析php获取字符串的编码格式的方法(函数)
2013/06/21 PHP
php中的curl_multi系列函数使用例子
2014/07/29 PHP
php文件扩展名判断及获取文件扩展名的N种方法
2015/09/12 PHP
PHP识别二维码的方法(php-zbarcode安装与使用)
2016/07/07 PHP
Laravel服务容器绑定的几种方法总结
2020/06/14 PHP
jQuery页面图片伴随滚动条逐渐显示的小例子
2013/03/21 Javascript
JS中获取数据库中的值的方法
2013/07/14 Javascript
js带按钮的提示框可供选择示例代码
2013/09/17 Javascript
js 显示base64编码的二进制流网页图片
2014/04/04 Javascript
JavaScript中停止执行setInterval和setTimeout事件的方法
2015/05/14 Javascript
Knockout结合Bootstrap创建动态UI实现产品列表管理
2016/09/14 Javascript
AngularJS中watch监听用法分析
2016/11/04 Javascript
nodejs实现邮件发送服务实例分享
2017/03/29 NodeJs
[js高手之路]图解javascript的原型(prototype)对象,原型链实例
2017/08/28 Javascript
NodeJS爬虫实例之糗事百科
2017/12/14 NodeJs
vue 通过下拉框组件学习vue中的父子通讯
2017/12/19 Javascript
基于axios封装fetch方法及调用实例
2018/02/05 Javascript
webpack中使用iconfont字体图标的方法
2018/02/22 Javascript
JavaScript DOM元素常见操作详解【添加、删除、修改等】
2018/05/09 Javascript
[54:26]完美世界DOTA2联赛PWL S3 Forest vs Rebirth 第一场 12.10
2020/12/12 DOTA
Python2.x版本中maketrans()方法的使用介绍
2015/05/19 Python
python  创建一个保留重复值的列表的补码
2018/10/15 Python
解决Python二维数组赋值问题
2019/11/28 Python
keras实现基于孪生网络的图片相似度计算方式
2020/06/11 Python
python基于opencv 实现图像时钟
2021/01/04 Python
CSS3使用transition实现的鼠标悬停淡入淡出
2015/01/09 HTML / CSS
使用CSS3的appearance属性改变元素的外观的方法
2015/12/12 HTML / CSS
html5 Canvas画图教程(6)—canvas里画曲线之arcTo方法
2013/01/09 HTML / CSS
英国现代家具和装饰网站:PN Home
2018/08/16 全球购物
Homestay中文官网:全球寄宿家庭
2018/10/18 全球购物
建筑工程实习自我鉴定
2013/09/19 职场文书
车间主任岗位职责
2014/03/16 职场文书
食品安全宣传标语
2014/06/07 职场文书
逃课检讨书范文
2015/05/06 职场文书
教你用python实现一个无界面的小型图书管理系统
2021/05/21 Python
SQL Server内存机制浅探
2022/04/06 SQL Server