python PyTorch参数初始化和Finetune


Posted in Python onFebruary 11, 2018

前言

这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。

参数初始化

参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。

python PyTorch参数初始化和Finetune

所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型
  if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
  elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
    m.weight.data.fill_(1)
    m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。

局部微调

有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
  param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)

# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调

有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
           model.parameters())

optimizer = torch.optim.SGD([
      {'params': base_params},
      {'params': model.fc.parameters(), 'lr': 1e-3}
      ], lr=1e-2, momentum=0.9)

其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Pyramid添加Middleware的方法实例
Nov 27 Python
利用Python绘制MySQL数据图实现数据可视化
Mar 30 Python
在Python中实现贪婪排名算法的教程
Apr 17 Python
在Django框架中编写Contact表单的教程
Jul 17 Python
Django实现快速分页的方法实例
Oct 22 Python
Python实现定时精度可调节的定时器
Apr 15 Python
Python设计模式之享元模式原理与用法实例分析
Jan 11 Python
不到20行代码用Python做一个智能聊天机器人
Apr 19 Python
Python实现爬取亚马逊数据并打印出Excel文件操作示例
May 16 Python
用sqlalchemy构建Django连接池的实例
Aug 29 Python
PyCharm更改字体和界面样式的方法步骤
Sep 27 Python
pytorch 实现在一个优化器中设置多个网络参数的例子
Feb 20 Python
Python装饰器用法示例小结
Feb 11 #Python
python PyTorch预训练示例
Feb 11 #Python
TensorFlow中权重的随机初始化的方法
Feb 11 #Python
python的staticmethod与classmethod实现实例代码
Feb 11 #Python
Python语言的变量认识及操作方法
Feb 11 #Python
利用Opencv中Houghline方法实现直线检测
Feb 11 #Python
tensorflow输出权重值和偏差的方法
Feb 10 #Python
You might like
采集邮箱的php代码(抓取网页中的邮箱地址)
2012/07/17 PHP
从PHP $_SERVER相关参数判断是否支持Rewrite模块
2013/09/26 PHP
php将html转成wml的WAP标记语言实例
2015/07/08 PHP
PHP排序算法之快速排序(Quick Sort)及其优化算法详解
2018/04/21 PHP
jQuery中append()方法用法实例
2015/01/08 Javascript
Nodejs基于LRU算法实现的缓存处理操作示例
2017/03/17 NodeJs
从零开始学习Node.js系列教程之基于connect和express框架的多页面实现数学运算示例
2017/04/13 Javascript
详解NODEJS的http实现
2018/01/04 NodeJs
基于vue cli重构多页面脚手架过程详解
2018/01/23 Javascript
JS求Number类型数组中最大元素方法
2018/04/08 Javascript
vue router 源码概览案例分析
2018/10/09 Javascript
vue ssr 实现方式(学习笔记)
2019/01/18 Javascript
如何优雅地在vue中添加权限控制示例详解
2019/03/07 Javascript
JS异步宏队列微队列原理详解
2020/09/09 Javascript
python使用mailbox打印电子邮件的方法
2015/04/30 Python
Python3学习urllib的使用方法示例
2017/11/29 Python
解决python中遇到字典里key值为None的情况,取不出来的问题
2018/10/17 Python
Python实现微信中找回好友、群聊用户撤回的消息功能示例
2019/08/23 Python
Python Subprocess模块原理及实例
2019/08/26 Python
pytorch 获取tensor维度信息示例
2020/01/03 Python
python数字类型math库原理解析
2020/03/02 Python
解决PyCharm不在run输出运行结果而不是再Console里输出的问题
2020/09/21 Python
django中ImageField的使用详解
2020/12/21 Python
浅谈css3中calc在less编译时被计算的解决办法
2017/12/04 HTML / CSS
HTML5+CSS3绘制锯齿状的矩形
2016/03/01 HTML / CSS
凯特·丝蓓英国官网:Kate Spade英国
2016/11/07 全球购物
印度尼西亚最大的电商平台:Tokopedia(印尼版淘宝)
2017/12/02 全球购物
美国家居装饰网上商店:Lulu & Georgia
2019/09/14 全球购物
AJAX的优缺点都有什么
2015/08/18 面试题
大学生农村教师实习自我鉴定
2013/09/21 职场文书
做一个有道德的人活动实施方案
2014/08/23 职场文书
村支部书记群众路线对照检查材料思想汇报
2014/10/08 职场文书
计划生育工作汇报
2014/10/28 职场文书
药房管理制度范本
2015/08/06 职场文书
导游词之镜泊湖
2019/12/09 职场文书
Python学习之迭代器详解
2022/04/01 Python