浅谈PyTorch的可重复性问题(如何使实验结果可复现)


Posted in Python onFebruary 20, 2020

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

许多博客都有介绍如何解决这个问题,但是很多都不够全面,往往不能保证结果精确一致。我经过许多调研和实验,总结了以下方法,记录下来。

全部设置可以分为三部分:

1. CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False      # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

2. Pytorch

torch.manual_seed(seed)      # 为CPU设置随机种子
torch.cuda.manual_seed(seed)    # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)  # 为所有GPU设置随机种子

3. Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

最后,关于dataloader:

注意,如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。也就是说,改变num_workers参数,也会对实验结果产生影响。目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

对于不同线程的随机数种子设置,主要通过DataLoader的worker_init_fn参数来实现。默认情况下使用线程ID作为随机数种子。如果需要自己设定,可以参考以下代码:

GLOBAL_SEED = 1
 
def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
 
GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
  global GLOBAL_WORKER_ID
  GLOBAL_WORKER_ID = worker_id
  set_seed(GLOBAL_SEED + worker_id)
 
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)

以上这篇浅谈PyTorch的可重复性问题(如何使实验结果可复现)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python编程-将Python程序转化为可执行程序[整理]
Apr 09 Python
Python学习笔记(二)基础语法
Jun 06 Python
python BeautifulSoup设置页面编码的方法
Apr 03 Python
使用Python对Access读写操作
Mar 30 Python
分享一下Python数据分析常用的8款工具
Apr 29 Python
对pytorch中的梯度更新方法详解
Aug 20 Python
python垃圾回收机制(GC)原理解析
Dec 30 Python
python 爬取疫情数据的源码
Feb 09 Python
关于TensorFlow新旧版本函数接口变化详解
Feb 10 Python
python语言的优势是什么
Jun 17 Python
Python读写Excel表格的方法
Mar 02 Python
一些让Python代码简洁的实用技巧总结
Aug 23 Python
pytorch 模型的train模式与eval模式实例
Feb 20 #Python
pytorch dataloader 取batch_size时候出现bug的解决方式
Feb 20 #Python
pytorch 使用加载训练好的模型做inference
Feb 20 #Python
pytorch中的inference使用实例
Feb 20 #Python
python encrypt 实现AES加密的实例详解
Feb 20 #Python
Python关于反射的实例代码分享
Feb 20 #Python
Python3监控疫情的完整代码
Feb 20 #Python
You might like
PHP开启opcache提升代码性能
2015/04/26 PHP
php命令行(cli)下执行PHP脚本文件的相对路径的问题解决方法
2015/05/25 PHP
PHP实现求连续子数组最大和问题2种解决方法
2017/12/26 PHP
可兼容php5与php7的cURL文件上传功能实例分析
2018/05/11 PHP
PHP实现常用排序算法的方法
2020/02/05 PHP
使用SyntaxHighlighter实现HTML高亮显示代码的方法
2010/02/04 Javascript
javascipt匹配单行和多行注释的正则表达式
2013/11/20 Javascript
jquery改变disabled的boolean状态的三种方法
2013/12/13 Javascript
Egret引擎开发指南之编译项目
2014/09/03 Javascript
JQuery操作元素的css样式
2015/03/09 Javascript
JS实现先显示大图后自动收起显示小图的广告代码
2015/09/04 Javascript
详解maxlength属性在textarea里奇怪的表现
2015/12/27 Javascript
jQuery图片轮播插件——前端开发必看
2016/05/31 Javascript
js实现右键菜单功能
2016/11/28 Javascript
jQuery实现搜索页面关键字的功能
2017/02/16 Javascript
深入理解Javascript中的作用域链和闭包
2017/04/25 Javascript
NodeJs安装npm包一直失败的解决方法
2017/04/28 NodeJs
使用vue点击li,获取当前点击li父辈元素的属性值方法
2018/09/12 Javascript
Angular设置别名alias的方法
2018/11/08 Javascript
koa大型web项目中使用路由装饰器的方法示例
2019/04/02 Javascript
Vue computed 计算属性代码实例
2020/04/22 Javascript
Element InputNumber 计数器的实现示例
2020/08/03 Javascript
[12:29]《一刀刀一天》之DOTA全时刻19:蝙蝠骑士田伯光再度不举
2014/06/10 DOTA
Python使用稀疏矩阵节省内存实例
2014/06/27 Python
浅谈flask截获所有访问及before/after_request修饰器
2018/01/18 Python
浅析Python 中几种字符串格式化方法及其比较
2019/07/02 Python
Python3enumrate和range对比及示例详解
2019/07/13 Python
python利用tkinter实现屏保
2019/07/30 Python
Python操作redis和mongoDB的方法
2019/12/19 Python
PyCharm License Activation激活码失效问题的解决方法(图文详解)
2020/03/12 Python
Python判断三段线能否构成三角形的代码
2020/04/12 Python
html5的新增的标签和废除的标签简要概述
2013/02/20 HTML / CSS
普通党员对照检查材料
2014/08/28 职场文书
小学班级特色活动方案
2014/08/31 职场文书
2015公务员试用期工作总结
2014/12/12 职场文书
初中班主任工作总结2015
2015/05/13 职场文书