PyTorch 如何设置随机数种子使结果可复现


Posted in Python onMay 12, 2021

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。

因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False            # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

Pytorch

torch.manual_seed(seed)            # 为CPU设置随机种子
torch.cuda.manual_seed(seed)       # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)   # 为所有GPU设置随机种子

Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

Dataloader

如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。

也就是说,改变num_workers参数,也会对实验结果产生影响。

目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

补充:pytorch 固定随机数种子踩过的坑

1.初步固定

def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.enabled = False
     torch.backends.cudnn.benchmark = False
     #torch.backends.cudnn.benchmark = True #for accelerating the running
 setup_seed(2019)

2.继续添加如下代码:

tensor_dataset = ImageList(opt.training_list,transform)
def _init_fn(worker_id): 
    random.seed(10 + worker_id)
    np.random.seed(10 + worker_id)
    torch.manual_seed(10 + worker_id)
    torch.cuda.manual_seed(10 + worker_id)
    torch.cuda.manual_seed_all(10 + worker_id)
dataloader = DataLoader(tensor_dataset,                        
                    batch_size=opt.batchSize,     
                    shuffle=True,     
                    num_workers=opt.workers,
                    worker_init_fn=_init_fn)

3.在上面的操作之后发现加载的数据多次试验大部分一致了

但是仍然有些数据是不一致的,后来发现是pytorch版本的问题,将原先的0.3.1版本升级到1.1.0版本,问题解决

4.按照上面的操作后虽然解决了问题

但是由于将cudnn.benchmark设置为False,运行速度降低到原来的1/3,所以继续探索,最终解决方案是把第1步变为如下,同时将该部分代码尽可能放在主程序最开始的部分,例如:

import torch
import torch.nn as nn
from torch.nn import init
import pdb
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader, Dataset
import sys
gpu_id = "3,2"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
print('GPU: ',gpu_id)
def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     cudnn.deterministic = True
     #cudnn.benchmark = False
     #cudnn.enabled = False

setup_seed(2019)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python入门篇之正则表达式
Oct 20 Python
Python获取DLL和EXE文件版本号的方法
Mar 10 Python
在python中使用正则表达式查找可嵌套字符串组
Oct 24 Python
Python socket实现的简单通信功能示例
Aug 21 Python
python使用writerows写csv文件产生多余空行的处理方法
Aug 01 Python
使用turtle绘制五角星、分形树
Oct 06 Python
Python 爬虫实现增加播客访问量的方法实现
Oct 31 Python
python ffmpeg任意提取视频帧的方法
Feb 21 Python
Python 3.9的到来到底是意味着什么
Oct 14 Python
python+selenium自动化实战携带cookies模拟登陆微博
Jan 19 Python
matplotlib之pyplot模块之标题(title()和suptitle())
Feb 22 Python
在python中实现导入一个需要传参的模块
May 12 Python
Python Parser的用法
May 12 #Python
pytorch MSELoss计算平均的实现方法
May 12 #Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
You might like
php使用mysqli向数据库添加数据的方法
2015/03/20 PHP
PHP实现原生态图片上传封装类方法
2016/11/08 PHP
Avengerls vs Newbee BO3 第三场2.18
2021/03/10 DOTA
JavaScript小技巧 2.5 则
2010/09/12 Javascript
jquery中eq和get的区别与使用方法
2011/04/14 Javascript
PHP 数组current和next用法分享
2015/03/05 Javascript
原生js实现的贪吃蛇网页版游戏完整实例
2015/05/18 Javascript
jquery限定文本框只能输入数字(整数和小数)
2016/01/08 Javascript
分享jQuery网页元素拖拽插件
2020/12/01 Javascript
用原生JS对AJAX做简单封装的实例代码
2016/07/13 Javascript
微信小程序购物商城系统开发系列-工具篇的介绍
2016/11/21 Javascript
JS判断是否手机或pad访问实现方法
2016/12/09 Javascript
angular.js实现购物车功能
2017/10/23 Javascript
微信小程序图片选择区域裁剪实现方法
2017/12/02 Javascript
微信小程序switch组件使用详解
2018/01/31 Javascript
vue中v-show和v-if的异同及v-show用法
2019/06/06 Javascript
微信小程序自定义底部弹出框动画
2020/11/18 Javascript
[46:43]DOTA2上海特级锦标赛主赛事日 - 1 胜者组第一轮#2LGD VS MVP.Phx第二局
2016/03/02 DOTA
python 读取目录下csv文件并绘制曲线v111的方法
2018/07/06 Python
Python将json文件写入ES数据库的方法
2019/04/10 Python
详解pandas删除缺失数据(pd.dropna()方法)
2019/06/25 Python
Python Django框架防御CSRF攻击的方法分析
2019/10/18 Python
Python猴子补丁Monkey Patch用法实例解析
2020/03/23 Python
详解CSS3中使用gradient实现渐变效果的方法
2015/08/18 HTML / CSS
canvas中普通动效与粒子动效的实现代码示例
2019/01/03 HTML / CSS
Shein英国:女性时尚网上商店
2019/04/10 全球购物
2014年党员整改措施
2014/10/24 职场文书
主持人开幕词
2015/01/29 职场文书
公积金接收函格式
2015/01/30 职场文书
医务人员医德考评自我评价
2015/03/03 职场文书
庆七一晚会主持词
2015/06/30 职场文书
2016年读书月活动总结范文
2016/04/06 职场文书
留学文书中的个人陈述,应该注意哪些问题?
2019/08/23 职场文书
python办公自动化之excel的操作
2021/05/23 Python
MySQL 常见存储引擎的优劣
2021/06/02 MySQL
python周期任务调度工具Schedule使用详解
2021/11/23 Python