浅谈PyTorch的可重复性问题(如何使实验结果可复现)


Posted in Python onFebruary 20, 2020

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

许多博客都有介绍如何解决这个问题,但是很多都不够全面,往往不能保证结果精确一致。我经过许多调研和实验,总结了以下方法,记录下来。

全部设置可以分为三部分:

1. CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False      # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

2. Pytorch

torch.manual_seed(seed)      # 为CPU设置随机种子
torch.cuda.manual_seed(seed)    # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)  # 为所有GPU设置随机种子

3. Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

最后,关于dataloader:

注意,如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。也就是说,改变num_workers参数,也会对实验结果产生影响。目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

对于不同线程的随机数种子设置,主要通过DataLoader的worker_init_fn参数来实现。默认情况下使用线程ID作为随机数种子。如果需要自己设定,可以参考以下代码:

GLOBAL_SEED = 1
 
def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
 
GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
  global GLOBAL_WORKER_ID
  GLOBAL_WORKER_ID = worker_id
  set_seed(GLOBAL_SEED + worker_id)
 
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)

以上这篇浅谈PyTorch的可重复性问题(如何使实验结果可复现)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python模块学习 re 正则表达式
May 19 Python
Python与Redis的连接教程
Apr 22 Python
使用Python的Bottle框架写一个简单的服务接口的示例
Aug 25 Python
用Python写飞机大战游戏之pygame入门(4):获取鼠标的位置及运动
Nov 05 Python
Python基础知识_浅谈用户交互
May 31 Python
python实现用户管理系统
Jan 10 Python
Python抓取聚划算商品分析页面获取商品信息并以XML格式保存到本地
Feb 23 Python
Python计算时间间隔(精确到微妙)的代码实例
Feb 26 Python
Django如何开发简单的查询接口详解
May 17 Python
解决pycharm启动后总是不停的updating indices...indexing的问题
Nov 27 Python
python进行参数传递的方法
May 12 Python
Python绘制K线图之可视化神器pyecharts的使用
Mar 02 Python
pytorch 模型的train模式与eval模式实例
Feb 20 #Python
pytorch dataloader 取batch_size时候出现bug的解决方式
Feb 20 #Python
pytorch 使用加载训练好的模型做inference
Feb 20 #Python
pytorch中的inference使用实例
Feb 20 #Python
python encrypt 实现AES加密的实例详解
Feb 20 #Python
Python关于反射的实例代码分享
Feb 20 #Python
Python3监控疫情的完整代码
Feb 20 #Python
You might like
生成卡号php代码
2008/04/09 PHP
快速配置PHPMyAdmin方法
2008/06/05 PHP
php在线打包程序源码
2008/07/27 PHP
php自动跳转中英文页面
2008/07/29 PHP
PHP模块 Memcached功能多于Memcache
2011/06/14 PHP
thinkPHP模板引擎用法示例
2016/12/08 PHP
Laravel基础-关于引入公共文件的两种方式
2019/10/18 PHP
javascript 对象定义方法 简单易学
2009/03/22 Javascript
js网页实时倒计时精确到秒级
2014/02/10 Javascript
原生js结合html5制作简易的双色子游戏
2015/03/30 Javascript
nodejs简单实现中英文翻译
2015/05/04 NodeJs
jQuery.prop() 使用详解
2015/07/19 Javascript
JS获取字符串实际长度(包含汉字)的简单方法
2016/08/11 Javascript
AngularJS定时器的使用与移除操作方法【interval与timeout】
2016/12/14 Javascript
关于vue-router的那些事儿
2018/05/23 Javascript
node中modules.exports与exports导出的区别
2018/06/08 Javascript
微信小程序new Date()方法失效问题解决方法
2019/07/29 Javascript
[52:29]DOTA2上海特级锦标赛主赛事日 - 2 胜者组第一轮#3Secret VS OG第三局
2016/03/03 DOTA
python应用程序在windows下不出现cmd窗口的办法
2014/05/29 Python
python通过post提交数据的方法
2015/05/06 Python
python通过ssh-powershell监控windows的方法
2015/06/02 Python
对numpy 数组和矩阵的乘法的进一步理解
2018/04/04 Python
Python脚本完成post接口测试的实例
2018/12/17 Python
使用Python脚本zabbix自定义key监控oracle连接状态
2019/08/28 Python
Python环境Pillow( PIL )图像处理工具使用解析
2019/09/12 Python
python的pip有什么用
2020/06/17 Python
Python读写Excel表格的方法
2021/03/02 Python
实列教程 一款基于jquery和css3的响应式二级导航菜单
2014/11/13 HTML / CSS
德国化妆品和天然化妆品网上商店:kosmetikfuchs.de
2017/06/09 全球购物
英国音乐设备和乐器商店:Gear4music
2017/10/16 全球购物
应届大专毕业生自我鉴定
2014/04/08 职场文书
校园活动策划方案
2014/06/13 职场文书
学习礼仪心得体会
2014/09/01 职场文书
励志语录:你若不勇敢,谁替你坚强
2019/11/08 职场文书
想创业成功,需要掌握这些要点
2019/12/06 职场文书
python 模拟在天空中放风筝的示例代码
2021/04/21 Python