浅谈PyTorch的可重复性问题(如何使实验结果可复现)


Posted in Python onFebruary 20, 2020

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

许多博客都有介绍如何解决这个问题,但是很多都不够全面,往往不能保证结果精确一致。我经过许多调研和实验,总结了以下方法,记录下来。

全部设置可以分为三部分:

1. CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False      # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

2. Pytorch

torch.manual_seed(seed)      # 为CPU设置随机种子
torch.cuda.manual_seed(seed)    # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)  # 为所有GPU设置随机种子

3. Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

最后,关于dataloader:

注意,如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。也就是说,改变num_workers参数,也会对实验结果产生影响。目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

对于不同线程的随机数种子设置,主要通过DataLoader的worker_init_fn参数来实现。默认情况下使用线程ID作为随机数种子。如果需要自己设定,可以参考以下代码:

GLOBAL_SEED = 1
 
def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
 
GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
  global GLOBAL_WORKER_ID
  GLOBAL_WORKER_ID = worker_id
  set_seed(GLOBAL_SEED + worker_id)
 
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)

以上这篇浅谈PyTorch的可重复性问题(如何使实验结果可复现)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的Django REST框架中的序列化及请求和返回
Apr 11 Python
Python 实现 贪吃蛇大作战 代码分享
Sep 07 Python
python多线程socket编程之多客户端接入
Sep 12 Python
python3.5基于TCP实现文件传输
Mar 20 Python
Python使用gRPC传输协议教程
Oct 16 Python
Python从函数参数类型引出元组实例分析
May 28 Python
python 调试冷知识(小结)
Nov 11 Python
Anaconda+VSCode配置tensorflow开发环境的教程详解
Mar 30 Python
python程序如何进行保存
Jul 03 Python
从零开始的TensorFlow+VScode开发环境搭建的步骤(图文)
Aug 31 Python
史上最详细的Python打包成exe文件教程
Jan 17 Python
如何用python清洗文件中的数据
Jun 18 Python
pytorch 模型的train模式与eval模式实例
Feb 20 #Python
pytorch dataloader 取batch_size时候出现bug的解决方式
Feb 20 #Python
pytorch 使用加载训练好的模型做inference
Feb 20 #Python
pytorch中的inference使用实例
Feb 20 #Python
python encrypt 实现AES加密的实例详解
Feb 20 #Python
Python关于反射的实例代码分享
Feb 20 #Python
Python3监控疫情的完整代码
Feb 20 #Python
You might like
德劲1103的维修打理经验
2021/03/02 无线电
虹吸式咖啡壶操作
2021/03/03 冲泡冲煮
用文本作数据处理
2006/10/09 PHP
PHP date函数参数详解
2006/11/27 PHP
PHP 设置MySQL连接字符集的方法
2011/01/02 PHP
解析thinkphp中的导入文件标签
2013/06/20 PHP
PHP中list()函数用法实例简析
2016/01/08 PHP
PHP识别二维码的方法(php-zbarcode安装与使用)
2016/07/07 PHP
jquery如何改变html标签的样式(两种实现方法)
2013/01/16 Javascript
利用CSS、JavaScript及Ajax实现高效的图片预加载
2013/10/16 Javascript
基于zepto的移动端轻量级日期插件--date_picker
2016/03/04 Javascript
jQuery+css实现非常漂亮的水平导航菜单效果
2016/07/27 Javascript
轻松掌握JavaScript代理模式
2016/08/26 Javascript
在Docker快速部署Node.js应用的详细步骤
2016/09/02 Javascript
jQuery插件zTree实现更新根节点中第i个节点名称的方法示例
2017/03/08 Javascript
详解AngularJS2 Http服务
2017/06/26 Javascript
Vue.js 踩坑记之双向绑定
2018/05/03 Javascript
[01:46]新英雄登场
2019/09/10 DOTA
python 实现堆排序算法代码
2012/06/05 Python
Python深入学习之上下文管理器
2014/08/31 Python
Python找出list中最常出现元素的方法
2016/06/14 Python
Python的爬虫程序编写框架Scrapy入门学习教程
2016/07/02 Python
Python 转换RGB颜色值的示例代码
2019/10/13 Python
DJango的创建和使用详解(默认数据库sqlite3)
2019/11/18 Python
python+Django+pycharm+mysql 搭建首个web项目详解
2019/11/29 Python
tensorflow生成多个tfrecord文件实例
2020/02/17 Python
django 模型中的计算字段实例
2020/05/19 Python
css3.0 图形构成实例练习二
2013/03/19 HTML / CSS
自定义html标记替换html5新增元素
2008/10/17 HTML / CSS
银行实习生的自我评价
2014/01/13 职场文书
学习交流会主持词
2014/04/01 职场文书
1000字打架检讨书
2014/11/03 职场文书
2014年会计工作总结
2014/11/27 职场文书
学校工会工作总结2015
2015/05/19 职场文书
2015年高三教学工作总结
2015/07/21 职场文书
SpringDataJPA实体类关系映射配置方式
2021/12/06 Java/Android