PyTorch 如何设置随机数种子使结果可复现


Posted in Python onMay 12, 2021

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。

因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False            # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

Pytorch

torch.manual_seed(seed)            # 为CPU设置随机种子
torch.cuda.manual_seed(seed)       # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)   # 为所有GPU设置随机种子

Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

Dataloader

如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。

也就是说,改变num_workers参数,也会对实验结果产生影响。

目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

补充:pytorch 固定随机数种子踩过的坑

1.初步固定

def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.enabled = False
     torch.backends.cudnn.benchmark = False
     #torch.backends.cudnn.benchmark = True #for accelerating the running
 setup_seed(2019)

2.继续添加如下代码:

tensor_dataset = ImageList(opt.training_list,transform)
def _init_fn(worker_id): 
    random.seed(10 + worker_id)
    np.random.seed(10 + worker_id)
    torch.manual_seed(10 + worker_id)
    torch.cuda.manual_seed(10 + worker_id)
    torch.cuda.manual_seed_all(10 + worker_id)
dataloader = DataLoader(tensor_dataset,                        
                    batch_size=opt.batchSize,     
                    shuffle=True,     
                    num_workers=opt.workers,
                    worker_init_fn=_init_fn)

3.在上面的操作之后发现加载的数据多次试验大部分一致了

但是仍然有些数据是不一致的,后来发现是pytorch版本的问题,将原先的0.3.1版本升级到1.1.0版本,问题解决

4.按照上面的操作后虽然解决了问题

但是由于将cudnn.benchmark设置为False,运行速度降低到原来的1/3,所以继续探索,最终解决方案是把第1步变为如下,同时将该部分代码尽可能放在主程序最开始的部分,例如:

import torch
import torch.nn as nn
from torch.nn import init
import pdb
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader, Dataset
import sys
gpu_id = "3,2"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
print('GPU: ',gpu_id)
def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     cudnn.deterministic = True
     #cudnn.benchmark = False
     #cudnn.enabled = False

setup_seed(2019)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python利用hook技术破解https的实例代码
Mar 25 Python
Python OS模块常用函数说明
May 23 Python
python生成随机图形验证码详解
Nov 08 Python
Python GUI Tkinter简单实现个性签名设计
Jun 19 Python
Python socket套接字实现C/S模式远程命令执行功能案例
Jul 06 Python
Django 实现admin后台显示图片缩略图的例子
Jul 28 Python
Pytorch Tensor的索引与切片例子
Aug 18 Python
Python 2种方法求某个范围内的所有素数(质数)
Jan 31 Python
Python类的绑定方法和非绑定方法实例解析
Mar 04 Python
jupyter notebook插入本地图片的实现
Apr 13 Python
使用python创建生成动态链接库dll的方法
May 09 Python
五分钟带你搞懂python 迭代器与生成器
Aug 30 Python
Python Parser的用法
May 12 #Python
pytorch MSELoss计算平均的实现方法
May 12 #Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
You might like
PHP手机号码归属地查询代码(API接口/mysql)
2012/09/04 PHP
PHP中空字符串介绍0、null、empty和false之间的关系
2012/09/25 PHP
PHP数组传递是值传递而非引用传递概念纠正
2013/01/31 PHP
php 魔术方法详解
2014/11/11 PHP
PHP中UNIX时间戳和日期间的转换与计算实例
2014/11/19 PHP
Yii+upload实现AJAX上传图片的方法
2016/07/13 PHP
15个款优秀的 jQuery 图片特效插件推荐
2011/11/21 Javascript
js函数调用常用方法详解
2012/12/03 Javascript
javascript获取重复次数最多的字符
2015/07/08 Javascript
JS实现横向拉伸动感伸缩菜单效果代码
2015/09/04 Javascript
js面向对象之常见创建对象的几种方式(工厂模式、构造函数模式、原型模式)
2015/11/09 Javascript
JS中BOM相关知识点总结(必看篇)
2016/11/22 Javascript
jquery网页日历显示控件calendar3.1使用详解
2016/11/24 Javascript
vue iview实现动态路由和权限验证功能
2018/04/17 Javascript
element-ui table span-method(行合并)的实现代码
2018/12/20 Javascript
vue-router命名视图的使用讲解
2019/01/19 Javascript
序列化模块json代码实例详解
2020/03/03 Javascript
python和pyqt实现360的CLable控件
2014/02/21 Python
python 利用for循环 保存多个图像或者文件的实例
2018/11/09 Python
PyQt5实现QLineEdit添加clicked信号的方法
2019/06/25 Python
Python识别快递条形码及Tesseract-OCR使用详解
2019/07/15 Python
python二分法查找算法实现方法【递归与非递归】
2019/12/06 Python
python GUI库图形界面开发之PyQt5复选框控件QCheckBox详细使用方法与实例
2020/02/28 Python
Tensorflow tensor 数学运算和逻辑运算方式
2020/06/30 Python
HTML5 语音搜索只需一句代码
2013/01/03 HTML / CSS
Gerry Weber德国官网:优质女性时装,德国最大的时装公司之一
2019/11/02 全球购物
惠普新加坡官方商店:HP Singapore
2020/04/17 全球购物
瀑布模型都有哪些优缺点
2014/06/23 面试题
为什么要做架构设计
2015/07/08 面试题
Prototype如何为一个Ajax添加一个参数
2015/12/06 面试题
协议书格式
2014/04/23 职场文书
建筑工地门卫岗位职责
2014/04/30 职场文书
公司董事长助理工作职责
2014/07/12 职场文书
医学检验专业自荐信
2014/09/18 职场文书
合理化建议书
2015/02/04 职场文书
幼儿园小班个人工作总结
2015/02/12 职场文书