浅谈PyTorch的可重复性问题(如何使实验结果可复现)


Posted in Python onFebruary 20, 2020

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

许多博客都有介绍如何解决这个问题,但是很多都不够全面,往往不能保证结果精确一致。我经过许多调研和实验,总结了以下方法,记录下来。

全部设置可以分为三部分:

1. CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False      # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

2. Pytorch

torch.manual_seed(seed)      # 为CPU设置随机种子
torch.cuda.manual_seed(seed)    # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)  # 为所有GPU设置随机种子

3. Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

最后,关于dataloader:

注意,如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。也就是说,改变num_workers参数,也会对实验结果产生影响。目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

对于不同线程的随机数种子设置,主要通过DataLoader的worker_init_fn参数来实现。默认情况下使用线程ID作为随机数种子。如果需要自己设定,可以参考以下代码:

GLOBAL_SEED = 1
 
def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
 
GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
  global GLOBAL_WORKER_ID
  GLOBAL_WORKER_ID = worker_id
  set_seed(GLOBAL_SEED + worker_id)
 
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)

以上这篇浅谈PyTorch的可重复性问题(如何使实验结果可复现)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用python + hadoop streaming 分布式编程(一) -- 原理介绍,样例程序与本地调试
Jul 14 Python
python通过邮件服务器端口发送邮件的方法
Apr 30 Python
PyCharm 常用快捷键和设置方法
Dec 20 Python
pytorch中tensor的合并与截取方法
Jul 26 Python
Python实现查询某个目录下修改时间最新的文件示例
Aug 29 Python
Linux下升级安装python3.8并配置pip及yum的教程
Jan 02 Python
tensorflow 利用expand_dims和squeeze扩展和压缩tensor维度方式
Feb 07 Python
Python运行异常管理解决方案
Mar 09 Python
pycharm部署、配置anaconda环境的教程
Mar 24 Python
python ETL工具 pyetl
Jun 07 Python
彻底搞懂python 迭代器和生成器
Sep 07 Python
Python实现视频自动打码的示例代码
Apr 08 Python
pytorch 模型的train模式与eval模式实例
Feb 20 #Python
pytorch dataloader 取batch_size时候出现bug的解决方式
Feb 20 #Python
pytorch 使用加载训练好的模型做inference
Feb 20 #Python
pytorch中的inference使用实例
Feb 20 #Python
python encrypt 实现AES加密的实例详解
Feb 20 #Python
Python关于反射的实例代码分享
Feb 20 #Python
Python3监控疫情的完整代码
Feb 20 #Python
You might like
使用php+apc实现上传进度条且在IE7下不显示的问题解决方法
2013/04/25 PHP
php并发加锁示例
2016/10/17 PHP
PHP如何读取由JavaScript设置的Cookie
2017/03/22 PHP
PHP设计模式之外观模式(Facade)入门与应用详解
2019/12/13 PHP
js的闭包的一个示例说明
2008/11/18 Javascript
js判断两个日期是否相等的方法
2013/09/10 Javascript
让jQuery与其他JavaScript库并存避免冲突的方法
2013/12/23 Javascript
JQuery $.each遍历JavaScript数组对象实例
2014/09/01 Javascript
node.js中的http.response.removeHeader方法使用说明
2014/12/14 Javascript
javascript实现瀑布流自适应遇到的问题及解决方案
2015/01/28 Javascript
Jquery使用css方法改变样式实例
2015/05/18 Javascript
js实现发送验证码后的倒计时功能
2015/05/28 Javascript
z-blog SyntaxHighlighter 长代码无法换行解决办法(基于jquery)
2015/11/18 Javascript
利用VUE框架,实现列表分页功能示例代码
2017/01/12 Javascript
Angular 如何使用第三方库的方法
2018/04/18 Javascript
推荐15个最好用的JavaScript代码压缩工具
2019/02/13 Javascript
VUE 组件转换为微信小程序组件的方法
2019/11/06 Javascript
JS加载解析Markdown文档过程详解
2020/05/19 Javascript
JavaScript直接调用函数与call调用的区别实例分析
2020/05/22 Javascript
VUE使用axios调用后台API接口的方法
2020/08/03 Javascript
微信小程序视频弹幕发送功能的实现
2020/12/28 Javascript
[49:02]KG vs Infamous 2019国际邀请赛淘汰赛 败者组BO1 8.20.mp4
2020/07/19 DOTA
Python random模块常用方法
2014/11/03 Python
python-itchat 获取微信群用户信息的实例
2019/02/21 Python
Tensorflow实现酸奶销量预测分析
2019/07/19 Python
python3 requests库文件上传与下载实现详解
2019/08/22 Python
Python虚拟环境的创建和包下载过程分析
2020/06/19 Python
全球领先的中国制造商品在线批发平台:DHgate
2020/01/28 全球购物
通用C#笔试题附答案
2016/11/26 面试题
社会学专业求职信
2014/02/24 职场文书
乡镇消防安全责任书
2014/07/23 职场文书
国家奖学金获奖感言
2014/08/16 职场文书
乡镇党建工作总结2015
2015/05/19 职场文书
保外就医申请书范文
2015/08/06 职场文书
MySQL 数据库 增删查改、克隆、外键 等操作
2022/05/11 MySQL
Vue ECharts实现机舱座位选择展示功能
2022/05/15 Vue.js