浅谈PyTorch的可重复性问题(如何使实验结果可复现)


Posted in Python onFebruary 20, 2020

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

许多博客都有介绍如何解决这个问题,但是很多都不够全面,往往不能保证结果精确一致。我经过许多调研和实验,总结了以下方法,记录下来。

全部设置可以分为三部分:

1. CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False      # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

2. Pytorch

torch.manual_seed(seed)      # 为CPU设置随机种子
torch.cuda.manual_seed(seed)    # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)  # 为所有GPU设置随机种子

3. Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

最后,关于dataloader:

注意,如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。也就是说,改变num_workers参数,也会对实验结果产生影响。目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

对于不同线程的随机数种子设置,主要通过DataLoader的worker_init_fn参数来实现。默认情况下使用线程ID作为随机数种子。如果需要自己设定,可以参考以下代码:

GLOBAL_SEED = 1
 
def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
 
GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
  global GLOBAL_WORKER_ID
  GLOBAL_WORKER_ID = worker_id
  set_seed(GLOBAL_SEED + worker_id)
 
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)

以上这篇浅谈PyTorch的可重复性问题(如何使实验结果可复现)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python显示天气预报
Mar 02 Python
Python中Collection的使用小技巧
Aug 18 Python
Django发送html邮件的方法
May 26 Python
基于python元祖与字典与集合的粗浅认识
Aug 23 Python
python 随机数使用方法,推导以及字符串,双色球小程序实例
Sep 12 Python
Python3实战之爬虫抓取网易云音乐的热门评论
Oct 09 Python
python3写爬取B站视频弹幕功能
Dec 22 Python
Python元组知识点总结
Feb 18 Python
python查看数据类型的方法
Oct 12 Python
Python Web静态服务器非堵塞模式实现方法示例
Nov 21 Python
python通过对字典的排序,对json字段进行排序的实例
Feb 27 Python
Django 项目布局方法(值得推荐)
Mar 22 Python
pytorch 模型的train模式与eval模式实例
Feb 20 #Python
pytorch dataloader 取batch_size时候出现bug的解决方式
Feb 20 #Python
pytorch 使用加载训练好的模型做inference
Feb 20 #Python
pytorch中的inference使用实例
Feb 20 #Python
python encrypt 实现AES加密的实例详解
Feb 20 #Python
Python关于反射的实例代码分享
Feb 20 #Python
Python3监控疫情的完整代码
Feb 20 #Python
You might like
PHP下载文件函数与用法示例
2019/09/27 PHP
利用ASP发送和接收XML数据的处理方法与代码
2007/11/13 Javascript
Fastest way to build an HTML string(拼装html字符串的最快方法)
2011/08/20 Javascript
jQuery下的动画处理总结
2013/10/10 Javascript
jquery对单选框,多选框,文本框等常见操作小结
2014/01/08 Javascript
jQuery移除tr无效的解决方法(tr是动态添加)
2014/09/22 Javascript
JS实现不规则TAB选项卡效果代码
2015/09/16 Javascript
jQuery插件开发精品教程(让你的jQuery更上一个台阶)
2015/11/07 Javascript
自学实现angularjs依赖注入
2016/12/20 Javascript
详解AngularJS之$window窗口对象
2018/01/17 Javascript
Vue2.0中集成UEditor富文本编辑器的方法
2018/03/03 Javascript
Vue中v-for的数据分组实例
2018/03/07 Javascript
vue移动端下拉刷新和上拉加载的实现代码
2018/09/08 Javascript
JS去除字符串最后的逗号实例分析【四种方法】
2019/06/20 Javascript
jQuery事件委托代码实践详解
2019/06/21 jQuery
解决vue 表格table列求和的问题
2019/11/06 Javascript
[02:49]2014DOTA2电竞也是体育项目! 势要把荣誉带回中国!
2014/07/20 DOTA
[01:10]DOTA2次级职业联赛 - U5战队宣传片
2014/12/01 DOTA
使用Python装饰器在Django框架下去除冗余代码的教程
2015/04/16 Python
python识别图像并提取文字的实现方法
2019/06/28 Python
python 通过视频url获取视频的宽高方式
2019/12/10 Python
Python递归及尾递归优化操作实例分析
2020/02/01 Python
Numpy中的数组搜索中np.where方法详细介绍
2021/01/08 Python
css3实现超炫风车特效
2014/11/12 HTML / CSS
html5 Canvas画图教程(8)—canvas里画曲线之bezierCurveTo方法
2013/01/09 HTML / CSS
怎样在 Applet 中建立自己的菜单(MenuBar/Menu)?
2012/06/20 面试题
社区庆中秋节活动方案
2014/02/07 职场文书
保安公司服务承诺书
2014/05/28 职场文书
生日宴会策划方案
2014/06/03 职场文书
学校食品安全责任书
2015/01/29 职场文书
2015年护士节慰问信
2015/03/23 职场文书
2015年档案管理工作总结
2015/04/08 职场文书
中学教师师德师风承诺书
2015/04/28 职场文书
建党伟业电影观后感
2015/06/01 职场文书
详解JS ES6编码规范
2021/05/07 Javascript
教你用python实现12306余票查询
2021/06/30 Python