PyTorch 如何设置随机数种子使结果可复现


Posted in Python onMay 12, 2021

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。

因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False            # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

Pytorch

torch.manual_seed(seed)            # 为CPU设置随机种子
torch.cuda.manual_seed(seed)       # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)   # 为所有GPU设置随机种子

Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

Dataloader

如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。

也就是说,改变num_workers参数,也会对实验结果产生影响。

目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

补充:pytorch 固定随机数种子踩过的坑

1.初步固定

def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.enabled = False
     torch.backends.cudnn.benchmark = False
     #torch.backends.cudnn.benchmark = True #for accelerating the running
 setup_seed(2019)

2.继续添加如下代码:

tensor_dataset = ImageList(opt.training_list,transform)
def _init_fn(worker_id): 
    random.seed(10 + worker_id)
    np.random.seed(10 + worker_id)
    torch.manual_seed(10 + worker_id)
    torch.cuda.manual_seed(10 + worker_id)
    torch.cuda.manual_seed_all(10 + worker_id)
dataloader = DataLoader(tensor_dataset,                        
                    batch_size=opt.batchSize,     
                    shuffle=True,     
                    num_workers=opt.workers,
                    worker_init_fn=_init_fn)

3.在上面的操作之后发现加载的数据多次试验大部分一致了

但是仍然有些数据是不一致的,后来发现是pytorch版本的问题,将原先的0.3.1版本升级到1.1.0版本,问题解决

4.按照上面的操作后虽然解决了问题

但是由于将cudnn.benchmark设置为False,运行速度降低到原来的1/3,所以继续探索,最终解决方案是把第1步变为如下,同时将该部分代码尽可能放在主程序最开始的部分,例如:

import torch
import torch.nn as nn
from torch.nn import init
import pdb
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader, Dataset
import sys
gpu_id = "3,2"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
print('GPU: ',gpu_id)
def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     cudnn.deterministic = True
     #cudnn.benchmark = False
     #cudnn.enabled = False

setup_seed(2019)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python使用os模块的os.walk遍历文件夹示例
Jan 27 Python
python操作xml文件详细介绍
Jun 09 Python
python获取Linux下文件版本信息、公司名和产品名的方法
Oct 05 Python
Python动态加载模块的3种方法
Nov 22 Python
python搜索指定目录的方法
Apr 29 Python
Python中struct模块对字节流/二进制流的操作教程
Jan 21 Python
对numpy中轴与维度的理解
Apr 18 Python
Python里字典的基本用法(包括嵌套字典)
Feb 27 Python
详谈tensorflow gfile文件的用法
Feb 05 Python
浅谈Pycharm的项目文件名是红色的原因及解决方式
Jun 01 Python
Python numpy大矩阵运算内存不足如何解决
Nov 19 Python
Python排序函数的使用方法详解
Dec 11 Python
Python Parser的用法
May 12 #Python
pytorch MSELoss计算平均的实现方法
May 12 #Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
You might like
PHP微信开发之模板消息回复
2016/06/24 PHP
PHP xpath()函数讲解
2019/02/11 PHP
JavaScript 动态将数字金额转化为中文大写金额
2009/05/14 Javascript
jquery1.4 教程二 ajax方法的改进
2010/02/25 Javascript
JS中不为人知的五种声明Number的方式简要概述
2013/02/22 Javascript
JS格式化数字金额用逗号隔开保留两位小数
2013/10/18 Javascript
jquery UI Datepicker时间控件的使用方法(基础版)
2015/11/07 Javascript
Javascript将JSON日期格式化
2016/08/23 Javascript
Angularjs 实现分页功能及示例代码
2016/09/14 Javascript
JS实现快速的导航下拉菜单动画效果附源码下载
2016/11/01 Javascript
jquery 给动态生成的标签绑定事件的几种方法总结
2018/02/24 jQuery
小程序实现多选框功能
2018/10/30 Javascript
js中int和string数据类型互相转化实例
2019/01/16 Javascript
微信小程序缓存过期时间的使用详情
2019/05/12 Javascript
JavaScript将数组转换为链表的方法
2020/02/16 Javascript
[04:39]显微镜下的DOTA2第十三期—Pis卡尔个人秀
2014/04/04 DOTA
Python随机生成数据后插入到PostgreSQL
2016/07/28 Python
ubuntu系统下 python链接mysql数据库的方法
2017/01/09 Python
Python列表解析配合if else的方法
2018/06/23 Python
用Django写天气预报查询网站
2018/10/21 Python
详解django自定义中间件处理
2018/11/21 Python
Python模拟百度自动输入搜索功能的实例
2019/02/14 Python
打包python 加icon 去掉cmd黑窗口方法
2019/06/24 Python
使用python打印十行杨辉三角过程详解
2019/07/10 Python
CSS实现定位元素居中的方法
2015/06/23 HTML / CSS
CSS3实现多重边框的方法总结
2016/05/31 HTML / CSS
Parfume Klik丹麦:香水网上商店
2018/07/10 全球购物
类的核心特性有哪些
2014/01/01 面试题
秘书岗位职责
2013/11/18 职场文书
竞选副班长演讲稿
2014/04/24 职场文书
关于爱国的演讲稿
2014/05/07 职场文书
2019求职信:应届生求职信范文
2019/04/24 职场文书
python tkinter实现定时关机
2021/04/21 Python
启动Tomcat时出现大量乱码的解决方法
2021/06/21 Java/Android
通过Qt连接OpenGauss数据库的详细教程
2021/06/23 PostgreSQL
mysql全面解析json/数组
2022/07/07 MySQL