PyTorch 如何设置随机数种子使结果可复现


Posted in Python onMay 12, 2021

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。

因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False            # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

Pytorch

torch.manual_seed(seed)            # 为CPU设置随机种子
torch.cuda.manual_seed(seed)       # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)   # 为所有GPU设置随机种子

Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

Dataloader

如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。

也就是说,改变num_workers参数,也会对实验结果产生影响。

目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

补充:pytorch 固定随机数种子踩过的坑

1.初步固定

def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.enabled = False
     torch.backends.cudnn.benchmark = False
     #torch.backends.cudnn.benchmark = True #for accelerating the running
 setup_seed(2019)

2.继续添加如下代码:

tensor_dataset = ImageList(opt.training_list,transform)
def _init_fn(worker_id): 
    random.seed(10 + worker_id)
    np.random.seed(10 + worker_id)
    torch.manual_seed(10 + worker_id)
    torch.cuda.manual_seed(10 + worker_id)
    torch.cuda.manual_seed_all(10 + worker_id)
dataloader = DataLoader(tensor_dataset,                        
                    batch_size=opt.batchSize,     
                    shuffle=True,     
                    num_workers=opt.workers,
                    worker_init_fn=_init_fn)

3.在上面的操作之后发现加载的数据多次试验大部分一致了

但是仍然有些数据是不一致的,后来发现是pytorch版本的问题,将原先的0.3.1版本升级到1.1.0版本,问题解决

4.按照上面的操作后虽然解决了问题

但是由于将cudnn.benchmark设置为False,运行速度降低到原来的1/3,所以继续探索,最终解决方案是把第1步变为如下,同时将该部分代码尽可能放在主程序最开始的部分,例如:

import torch
import torch.nn as nn
from torch.nn import init
import pdb
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader, Dataset
import sys
gpu_id = "3,2"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
print('GPU: ',gpu_id)
def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     cudnn.deterministic = True
     #cudnn.benchmark = False
     #cudnn.enabled = False

setup_seed(2019)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
一个简单的python程序实例(通讯录)
Nov 29 Python
Python之os操作方法(详解)
Jun 15 Python
Python 内置函数进制转换的用法(十进制转二进制、八进制、十六进制)
Apr 30 Python
Python查看微信撤回消息代码
Jun 07 Python
Python的条件锁与事件共享详解
Sep 12 Python
Python中生成一个指定长度的随机字符串实现示例
Nov 06 Python
最小二乘法及其python实现详解
Feb 24 Python
Django中使用Json返回数据的实现方法
Jun 03 Python
python subprocess pipe 实时输出日志的操作
Dec 05 Python
PyChon中关于Jekins的详细安装(推荐)
Dec 28 Python
python 装饰器的基本使用
Jan 13 Python
七个Python必备的GUI库
Apr 27 Python
Python Parser的用法
May 12 #Python
pytorch MSELoss计算平均的实现方法
May 12 #Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
You might like
在PHP中养成7个面向对象的好习惯
2010/07/17 PHP
体育彩票排列三组选三算法分享
2014/03/07 PHP
PHP网页游戏学习之Xnova(ogame)源码解读(十一)
2014/06/25 PHP
Joomla调用系统自带编辑器的实现方法
2016/05/05 PHP
PHP中的empty、isset、isnull的区别与使用实例
2019/03/22 PHP
JS 中document.URL 和 windows.location.href 的区别
2009/11/11 Javascript
jquery蒙版控件实现代码
2010/12/08 Javascript
js实现字符串的16进制编码不加密
2014/04/25 Javascript
Node.js 的异步 IO 性能探讨
2014/10/08 Javascript
《JavaScript DOM 编程艺术》读书笔记之JavaScript 简史
2015/01/09 Javascript
jquery实现聚光灯效果的方法
2015/02/06 Javascript
jquery带有索引按钮且自动轮播切换特效代码分享
2015/09/15 Javascript
JavaScript数组去重的五种方法
2015/11/05 Javascript
Node.js之网络通讯模块实现浅析
2017/04/01 Javascript
详解nodejs微信公众号开发——6.自定义菜单
2017/04/13 NodeJs
JavaScript实现的选择排序算法实例分析
2017/04/14 Javascript
Angularjs根据json文件动态生成路由状态的实现方法
2017/04/17 Javascript
JS实现简易的图片拖拽排序实例代码
2017/06/09 Javascript
vue中get请求如何传递数组参数的方法示例
2019/11/08 Javascript
[03:42]2014DOTA2国际邀请赛 第三日比赛排位扑朔迷离
2014/07/12 DOTA
python实现连接mongodb的方法
2015/05/08 Python
人生苦短我用python python如何快速入门?
2018/03/12 Python
Python正则匹配判断手机号是否合法的方法
2020/12/09 Python
浅谈Python大神都是这样处理XML文件的
2019/05/31 Python
python 整数越界问题详解
2019/06/27 Python
django 中使用DateTime常用的时间查询方式
2019/12/03 Python
python中format函数如何使用
2020/06/22 Python
10行Python代码实现Web自动化管控的示例代码
2020/08/14 Python
CSS3 制作旋转的大风车(充满童年回忆)
2013/01/30 HTML / CSS
html5的新增的标签和废除的标签简要概述
2013/02/20 HTML / CSS
八年级数学教学反思
2014/01/31 职场文书
大学生学期自我鉴定
2014/03/19 职场文书
经典禁毒标语
2014/06/16 职场文书
爱护环境建议书
2015/09/14 职场文书
小学六年级班主任工作经验交流材料
2015/11/02 职场文书
2019预备党员转正申请书模板2篇!
2019/08/07 职场文书