浅谈PyTorch的可重复性问题(如何使实验结果可复现)


Posted in Python onFebruary 20, 2020

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

许多博客都有介绍如何解决这个问题,但是很多都不够全面,往往不能保证结果精确一致。我经过许多调研和实验,总结了以下方法,记录下来。

全部设置可以分为三部分:

1. CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False      # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

2. Pytorch

torch.manual_seed(seed)      # 为CPU设置随机种子
torch.cuda.manual_seed(seed)    # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)  # 为所有GPU设置随机种子

3. Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

最后,关于dataloader:

注意,如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。也就是说,改变num_workers参数,也会对实验结果产生影响。目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

对于不同线程的随机数种子设置,主要通过DataLoader的worker_init_fn参数来实现。默认情况下使用线程ID作为随机数种子。如果需要自己设定,可以参考以下代码:

GLOBAL_SEED = 1
 
def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
 
GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
  global GLOBAL_WORKER_ID
  GLOBAL_WORKER_ID = worker_id
  set_seed(GLOBAL_SEED + worker_id)
 
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)

以上这篇浅谈PyTorch的可重复性问题(如何使实验结果可复现)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python简单日志处理类分享
Feb 14 Python
python中pandas.DataFrame的简单操作方法(创建、索引、增添与删除)
Mar 12 Python
Python 逐行分割大txt文件的方法
Oct 10 Python
[原创]windows下Anaconda的安装与配置正解(Anaconda入门教程)
Apr 05 Python
python处理两种分隔符的数据集方法
Dec 12 Python
python3 property装饰器实现原理与用法示例
May 15 Python
对django 模型 unique together的示例讲解
Aug 06 Python
浅析python中while循环和for循环
Nov 19 Python
Pytorch Tensor基本数学运算详解
Dec 30 Python
python数据类型强制转换实例详解
Jun 22 Python
DRF使用simple JWT身份验证的实现
Jan 14 Python
Python turtle实现贪吃蛇游戏
Jun 18 Python
pytorch 模型的train模式与eval模式实例
Feb 20 #Python
pytorch dataloader 取batch_size时候出现bug的解决方式
Feb 20 #Python
pytorch 使用加载训练好的模型做inference
Feb 20 #Python
pytorch中的inference使用实例
Feb 20 #Python
python encrypt 实现AES加密的实例详解
Feb 20 #Python
Python关于反射的实例代码分享
Feb 20 #Python
Python3监控疫情的完整代码
Feb 20 #Python
You might like
ThinkPHP自动验证失败的解决方法
2011/06/09 PHP
php超快高效率统计大文件行数
2015/07/05 PHP
Zend Framework实现具有基本功能的留言本(附demo源码下载)
2016/03/22 PHP
发布一个高效的JavaScript分析、压缩工具 JavaScript Analyser
2007/11/30 Javascript
在JavaScript中获取请求的URL参数
2010/12/22 Javascript
Jquery获取元素的父容器对象示例代码
2014/02/10 Javascript
对Web开发中前端框架与前端类库的一些思考
2015/03/27 Javascript
如何改进javascript代码的性能
2015/04/02 Javascript
Bootstrap下拉菜单效果实例代码分享
2016/06/30 Javascript
Javascript学习之谈谈JS的全局变量跟局部变量(推荐)
2016/08/28 Javascript
javascript 内置对象及常见API详细介绍
2016/11/01 Javascript
原生JS实现垂直手风琴效果
2017/02/19 Javascript
JS实现的点击表头排序功能示例
2017/03/27 Javascript
JS库之Three.js 简易入门教程(详解之一)
2017/09/13 Javascript
JS基于对象的特性实现去除数组中重复项功能详解
2017/11/17 Javascript
简单谈谈CommonsChunkPlugin抽取公共模块
2017/12/31 Javascript
p5.js入门教程之平滑过渡(Easing)
2018/03/16 Javascript
vue中Npm run build 根据环境传递参数方法来打包不同域名
2018/03/29 Javascript
小程序实现列表多个批量倒计时
2021/01/29 Javascript
vue canvas绘制矩形并解决由clearRec带来的闪屏问题
2019/09/02 Javascript
微信小程序实现轨迹回放的示例代码
2019/12/13 Javascript
vue-cli创建的项目中的gitHooks原理解析
2020/02/14 Javascript
JS中准确判断变量类型的方法
2020/06/01 Javascript
深入理解Django的自定义过滤器
2017/10/17 Python
python opencv将表格图片按照表格框线分割和识别
2019/10/30 Python
python标准库OS模块函数列表与实例全解
2020/03/10 Python
Python求区间正整数内所有素数之和的方法实例
2020/10/13 Python
css3使用animation属性实现炫酷效果(推荐)
2020/02/04 HTML / CSS
暇步士官网:Hush Puppies
2016/09/22 全球购物
阿巴庭院:Abba Patio
2019/06/18 全球购物
Linux常见面试题
2013/03/18 面试题
事业单位个人应聘自荐信
2013/09/21 职场文书
安全生产责任书范本
2014/04/15 职场文书
运动会加油稿20字
2014/11/15 职场文书
初二数学教学反思
2016/02/17 职场文书
mysql 直接拷贝data 目录下文件还原数据的实现
2021/07/25 MySQL