浅谈PyTorch的可重复性问题(如何使实验结果可复现)


Posted in Python onFebruary 20, 2020

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

许多博客都有介绍如何解决这个问题,但是很多都不够全面,往往不能保证结果精确一致。我经过许多调研和实验,总结了以下方法,记录下来。

全部设置可以分为三部分:

1. CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False      # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

2. Pytorch

torch.manual_seed(seed)      # 为CPU设置随机种子
torch.cuda.manual_seed(seed)    # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)  # 为所有GPU设置随机种子

3. Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

最后,关于dataloader:

注意,如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。也就是说,改变num_workers参数,也会对实验结果产生影响。目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

对于不同线程的随机数种子设置,主要通过DataLoader的worker_init_fn参数来实现。默认情况下使用线程ID作为随机数种子。如果需要自己设定,可以参考以下代码:

GLOBAL_SEED = 1
 
def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
 
GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
  global GLOBAL_WORKER_ID
  GLOBAL_WORKER_ID = worker_id
  set_seed(GLOBAL_SEED + worker_id)
 
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)

以上这篇浅谈PyTorch的可重复性问题(如何使实验结果可复现)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅谈python中的变量默认是什么类型
Sep 11 Python
Python中的sort()方法使用基础教程
Jan 08 Python
Python入门_学会创建并调用函数的方法
May 16 Python
Python实现简单遗传算法(SGA)
Jan 29 Python
用python实现对比两张图片的不同
Feb 05 Python
详解python的sorted函数对字典按key排序和按value排序
Aug 10 Python
python3+django2开发一个简单的人员管理系统过程详解
Jul 23 Python
python代码打印100-999之间的回文数示例
Nov 24 Python
python集成开发环境配置(pycharm)
Feb 14 Python
python删除指定列或多列单个或多个内容实例
Jun 28 Python
python使用XPath解析数据爬取起点小说网数据
Apr 22 Python
pytorch常用数据类型所占字节数对照表一览
May 17 Python
pytorch 模型的train模式与eval模式实例
Feb 20 #Python
pytorch dataloader 取batch_size时候出现bug的解决方式
Feb 20 #Python
pytorch 使用加载训练好的模型做inference
Feb 20 #Python
pytorch中的inference使用实例
Feb 20 #Python
python encrypt 实现AES加密的实例详解
Feb 20 #Python
Python关于反射的实例代码分享
Feb 20 #Python
Python3监控疫情的完整代码
Feb 20 #Python
You might like
关于我转生变成史莱姆这档事:第二季PV上线,萌王2021年回归
2020/05/06 日漫
PHP中usort在值相同时改变原始位置问题的解决方法
2011/11/27 PHP
解析PHP中一些可能会被忽略的问题
2013/06/21 PHP
PHP文件去掉PHP注释空格的函数分析(PHP代码压缩)
2013/07/02 PHP
php调用c接口无错版介绍
2014/03/11 PHP
ThinkPHP CURD方法之data方法详解
2014/06/18 PHP
Javascript的匿名函数小结
2009/12/31 Javascript
Jquery知识点一 Jquery的ready和Dom的onload的区别
2011/01/15 Javascript
如何确保JavaScript的执行顺序 之jQuery.html并非万能钥匙
2011/03/03 Javascript
js 遍历对象的属性的代码
2011/12/29 Javascript
JavaScript异步编程:异步数据收集的具体方法
2013/08/19 Javascript
9款2014最热门jQuery实用特效推荐
2014/12/07 Javascript
JavaScript function函数种类详解
2016/02/22 Javascript
Bootstrap CDN和本地化环境搭建
2016/10/26 Javascript
Vue自定义事件(详解)
2017/08/19 Javascript
vue实现页面加载动画效果
2017/09/19 Javascript
微信小程序实现折叠展开效果
2018/07/19 Javascript
ES6数组与对象的解构赋值详解
2019/06/14 Javascript
简单学习5种处理Vue.js异常的方法
2019/06/17 Javascript
JS数组reduce()方法原理及使用技巧解析
2020/07/14 Javascript
JavaScript实现沿五角星形线摆动的小圆实例详解
2020/07/28 Javascript
[01:19:23]2018DOTA2亚洲邀请赛 4.5 淘汰赛 Mineski vs VG 第二场
2018/04/06 DOTA
python二分法实现实例
2013/11/21 Python
python网络编程学习笔记(七):HTML和XHTML解析(HTMLParser、BeautifulSoup)
2014/06/09 Python
Django中url的反向查询的方法
2018/03/14 Python
Python字符串内置函数功能与用法总结
2019/04/16 Python
Python 写了个新型冠状病毒疫情传播模拟程序
2020/02/14 Python
python中什么是面向对象
2020/06/11 Python
CSS3实现可关闭的下拉手风琴菜单效果
2015/08/31 HTML / CSS
美国最大的城市服装和运动鞋零售商:Jimmy Jazz
2016/11/19 全球购物
信用社实习人员自我鉴定
2013/09/20 职场文书
老教师工作总结的自我评价
2013/09/27 职场文书
机关干部作风建设剖析材料
2014/10/23 职场文书
先进工作者事迹材料
2014/12/23 职场文书
结婚十年感言
2015/07/31 职场文书
九年级英语教学反思
2016/02/15 职场文书