浅谈PyTorch的可重复性问题(如何使实验结果可复现)


Posted in Python onFebruary 20, 2020

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

许多博客都有介绍如何解决这个问题,但是很多都不够全面,往往不能保证结果精确一致。我经过许多调研和实验,总结了以下方法,记录下来。

全部设置可以分为三部分:

1. CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False      # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

2. Pytorch

torch.manual_seed(seed)      # 为CPU设置随机种子
torch.cuda.manual_seed(seed)    # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)  # 为所有GPU设置随机种子

3. Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

最后,关于dataloader:

注意,如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。也就是说,改变num_workers参数,也会对实验结果产生影响。目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

对于不同线程的随机数种子设置,主要通过DataLoader的worker_init_fn参数来实现。默认情况下使用线程ID作为随机数种子。如果需要自己设定,可以参考以下代码:

GLOBAL_SEED = 1
 
def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
 
GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
  global GLOBAL_WORKER_ID
  GLOBAL_WORKER_ID = worker_id
  set_seed(GLOBAL_SEED + worker_id)
 
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)

以上这篇浅谈PyTorch的可重复性问题(如何使实验结果可复现)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python中的字典来处理索引统计的方法
May 05 Python
Python实现的朴素贝叶斯分类器示例
Jan 06 Python
Python面向对象之类的内置attr属性示例
Dec 14 Python
python获取url的返回信息方法
Dec 17 Python
利用Python库Scapy解析pcap文件的方法
Jul 23 Python
python redis连接 有序集合去重的代码
Aug 04 Python
python-序列解包(对可迭代元素的快速取值方法)
Aug 24 Python
pycharm修改file type方式
Nov 19 Python
Python TCP通信客户端服务端代码实例
Nov 21 Python
Pyspark读取parquet数据过程解析
Mar 27 Python
python基于爬虫+django,打造个性化API接口
Jan 21 Python
Python实现自动玩连连看的脚本分享
Apr 04 Python
pytorch 模型的train模式与eval模式实例
Feb 20 #Python
pytorch dataloader 取batch_size时候出现bug的解决方式
Feb 20 #Python
pytorch 使用加载训练好的模型做inference
Feb 20 #Python
pytorch中的inference使用实例
Feb 20 #Python
python encrypt 实现AES加密的实例详解
Feb 20 #Python
Python关于反射的实例代码分享
Feb 20 #Python
Python3监控疫情的完整代码
Feb 20 #Python
You might like
Apache设置虚拟WEB
2006/10/09 PHP
PHP与MySQL开发中页面乱码的产生与解决
2008/03/27 PHP
php 删除cookie方法详解
2014/12/01 PHP
php简单获取目录列表的方法
2015/03/24 PHP
php实现的一个简单json rpc框架实例
2015/03/30 PHP
浅谈PHP中foreach/in_array的使用
2015/11/02 PHP
thinkPHP简单遍历数组方法分析
2016/05/16 PHP
PHP实现动态创建XML文档的方法
2018/03/30 PHP
php微信开发之图片回复功能
2018/06/14 PHP
javascript 多种搜索引擎集成的页面实现代码
2010/01/02 Javascript
超链接的禁用属性Disabled使用示例
2014/07/31 Javascript
浅谈javascript中for in 和 for each in的区别
2015/04/23 Javascript
javascript常用的方法分享
2015/07/01 Javascript
Angular的自定义指令以及实例
2016/12/26 Javascript
如何使用angularJs
2017/05/08 Javascript
NodeJs中express框架的send()方法简介
2017/06/20 NodeJs
Vue项目中使用Vux的安装过程
2018/05/01 Javascript
详解JavaScript的变量
2019/04/04 Javascript
[55:42]VG vs VGJ.T 2018国际邀请赛淘汰赛BO1 8.21
2018/08/22 DOTA
Python常见的pandas用法demo示例
2019/03/16 Python
pycharm sciview的图片另存为操作
2020/06/01 Python
利用纯CSS3实现动态的自行车特效源码
2017/01/20 HTML / CSS
HTML5实现自带进度条和滑块滑杆效果
2018/04/17 HTML / CSS
深入理解HTML5定时器requestAnimationFrame的使用
2018/12/12 HTML / CSS
Banana Republic欧盟:美国都市简约风格的代表品牌
2018/05/09 全球购物
都柏林通行卡/城市通票:The Dublin Pass
2020/02/16 全球购物
金蝶的一道SQL笔试题
2012/12/18 面试题
大学系主任推荐信范文
2013/12/24 职场文书
办公室秘书自我鉴定
2014/01/18 职场文书
小学校长竞聘演讲稿
2014/05/16 职场文书
普通党员个人整改措施
2014/10/27 职场文书
离婚案件上诉状
2015/05/23 职场文书
实习证明模板
2015/06/16 职场文书
MySQL 中如何归档数据的实现方法
2022/03/16 SQL Server
如何更改Win11声音输出设备?Win11声音输出设备四种更改方法
2022/04/08 数码科技
关于vue-router-link选择样式设置
2022/04/30 Vue.js