PyTorch 如何设置随机数种子使结果可复现


Posted in Python onMay 12, 2021

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。

因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False            # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

Pytorch

torch.manual_seed(seed)            # 为CPU设置随机种子
torch.cuda.manual_seed(seed)       # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)   # 为所有GPU设置随机种子

Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

Dataloader

如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。

也就是说,改变num_workers参数,也会对实验结果产生影响。

目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

补充:pytorch 固定随机数种子踩过的坑

1.初步固定

def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.enabled = False
     torch.backends.cudnn.benchmark = False
     #torch.backends.cudnn.benchmark = True #for accelerating the running
 setup_seed(2019)

2.继续添加如下代码:

tensor_dataset = ImageList(opt.training_list,transform)
def _init_fn(worker_id): 
    random.seed(10 + worker_id)
    np.random.seed(10 + worker_id)
    torch.manual_seed(10 + worker_id)
    torch.cuda.manual_seed(10 + worker_id)
    torch.cuda.manual_seed_all(10 + worker_id)
dataloader = DataLoader(tensor_dataset,                        
                    batch_size=opt.batchSize,     
                    shuffle=True,     
                    num_workers=opt.workers,
                    worker_init_fn=_init_fn)

3.在上面的操作之后发现加载的数据多次试验大部分一致了

但是仍然有些数据是不一致的,后来发现是pytorch版本的问题,将原先的0.3.1版本升级到1.1.0版本,问题解决

4.按照上面的操作后虽然解决了问题

但是由于将cudnn.benchmark设置为False,运行速度降低到原来的1/3,所以继续探索,最终解决方案是把第1步变为如下,同时将该部分代码尽可能放在主程序最开始的部分,例如:

import torch
import torch.nn as nn
from torch.nn import init
import pdb
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader, Dataset
import sys
gpu_id = "3,2"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
print('GPU: ',gpu_id)
def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     cudnn.deterministic = True
     #cudnn.benchmark = False
     #cudnn.enabled = False

setup_seed(2019)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python复制与引用用法分析
Apr 08 Python
Python操作Word批量生成文章的方法
Jul 28 Python
python 添加用户设置密码并发邮件给root用户
Jul 25 Python
Python基于递归和非递归算法求两个数最大公约数、最小公倍数示例
May 21 Python
Python3中exp()函数用法分析
Feb 19 Python
详解Django-restframework 之频率源码分析
Feb 27 Python
python 使用turtule绘制递归图形(螺旋、二叉树、谢尔宾斯基三角形)
May 30 Python
快速解决jupyter启动卡死的问题
Apr 10 Python
基于Python实现体育彩票选号器功能代码实例
Sep 16 Python
Python多线程 Queue 模块常见用法
Jul 04 Python
python数据可视化JupyterLab实用扩展程序Mito
Nov 20 Python
Python日志模块logging用法
Jun 05 Python
Python Parser的用法
May 12 #Python
pytorch MSELoss计算平均的实现方法
May 12 #Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
You might like
php checkdate、getdate等日期时间函数操作详解
2010/03/11 PHP
php上的memcache和memcached两个pecl库
2010/03/29 PHP
php小偷相关截取函数备忘
2010/11/28 PHP
探讨PHP中this,self,parent的区别详解
2013/06/08 PHP
php var_export与var_dump 输出的不同
2013/08/09 PHP
php多维数组去掉重复值示例分享
2014/03/02 PHP
ThinkPHP3.1新特性之动态设置自动完成及自动验证示例代码
2014/06/23 PHP
ThinkPHP实现将本地文件打包成zip下载
2014/06/26 PHP
php处理复杂xml数据示例
2016/07/11 PHP
jquery 结合C#后台的数组对文章的关键字自动添加链接的代码
2011/07/15 Javascript
nodejs入门详解(多篇文章结合)
2012/03/07 NodeJs
Extjs4 GridPanel的主要配置参数详细介绍
2013/04/18 Javascript
js获取html文件的思路及示例
2013/09/17 Javascript
jQuery根据name属性进行查找的用法分析
2016/06/23 Javascript
纯JS焦点图特效实例(可一个页面多用)
2016/12/07 Javascript
Vue.js轮播图走马灯代码实例(全)
2019/05/08 Javascript
JS+html5实现异步上传图片显示上传文件进度条功能示例
2019/11/09 Javascript
nuxt.js 在middleware(中间件)中实现路由鉴权操作
2020/11/06 Javascript
[09:13]2014DOTA2国际邀请赛 中国区预选赛coser表演
2014/05/23 DOTA
python操作字典类型的常用方法(推荐)
2016/05/16 Python
基于Django contrib Comments 评论模块(详解)
2017/12/08 Python
对Python2与Python3中__bool__方法的差异详解
2018/11/01 Python
python爬虫刷访问量 2019 7月
2019/08/01 Python
Python3 集合set入门基础
2020/02/10 Python
Python3基本输入与输出操作实例分析
2020/02/14 Python
使用Python将语音转换为文本的方法
2020/08/10 Python
Python爬虫之Selenium多窗口切换的实现
2020/12/04 Python
Python 虚拟环境工作原理解析
2020/12/24 Python
Web页面中八种创建多列等高(等高列布局)的实现技术
2012/12/24 HTML / CSS
HTML5中input[type='date']自定义样式与日历校验功能的实现代码
2017/07/11 HTML / CSS
幼儿园中秋节活动方案
2014/02/06 职场文书
卫生院健康教育实施方案
2014/06/07 职场文书
学困生转化工作总结
2015/08/13 职场文书
2016十一国庆节感言
2015/12/09 职场文书
jackson json序列化实现首字母大写,第二个字母需小写
2021/06/29 Java/Android
vue中利用mqtt服务端实现即时通讯的步骤记录
2021/07/01 Vue.js