PyTorch 如何设置随机数种子使结果可复现


Posted in Python onMay 12, 2021

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。

因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False            # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

Pytorch

torch.manual_seed(seed)            # 为CPU设置随机种子
torch.cuda.manual_seed(seed)       # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)   # 为所有GPU设置随机种子

Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

Dataloader

如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。

也就是说,改变num_workers参数,也会对实验结果产生影响。

目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

补充:pytorch 固定随机数种子踩过的坑

1.初步固定

def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.enabled = False
     torch.backends.cudnn.benchmark = False
     #torch.backends.cudnn.benchmark = True #for accelerating the running
 setup_seed(2019)

2.继续添加如下代码:

tensor_dataset = ImageList(opt.training_list,transform)
def _init_fn(worker_id): 
    random.seed(10 + worker_id)
    np.random.seed(10 + worker_id)
    torch.manual_seed(10 + worker_id)
    torch.cuda.manual_seed(10 + worker_id)
    torch.cuda.manual_seed_all(10 + worker_id)
dataloader = DataLoader(tensor_dataset,                        
                    batch_size=opt.batchSize,     
                    shuffle=True,     
                    num_workers=opt.workers,
                    worker_init_fn=_init_fn)

3.在上面的操作之后发现加载的数据多次试验大部分一致了

但是仍然有些数据是不一致的,后来发现是pytorch版本的问题,将原先的0.3.1版本升级到1.1.0版本,问题解决

4.按照上面的操作后虽然解决了问题

但是由于将cudnn.benchmark设置为False,运行速度降低到原来的1/3,所以继续探索,最终解决方案是把第1步变为如下,同时将该部分代码尽可能放在主程序最开始的部分,例如:

import torch
import torch.nn as nn
from torch.nn import init
import pdb
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader, Dataset
import sys
gpu_id = "3,2"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
print('GPU: ',gpu_id)
def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     cudnn.deterministic = True
     #cudnn.benchmark = False
     #cudnn.enabled = False

setup_seed(2019)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python collections模块实例讲解
Apr 07 Python
在Python中处理字符串之isdecimal()方法的使用
May 20 Python
Windows中使用wxPython和py2exe开发Python的GUI程序的实例教程
Jul 11 Python
十行代码使用Python写一个USB病毒
Jun 21 Python
python的json中方法及jsonpath模块用法分析
Dec 06 Python
Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取
Jun 30 Python
一文解决django 2.2与mysql兼容性问题
Jul 15 Python
python中return不返回值的问题解析
Jul 22 Python
python在一个范围内取随机数的简单实例
Aug 16 Python
超级实用的8个Python列表技巧
Aug 24 Python
如何用Python编写一个电子考勤系统
Feb 08 Python
使用tensorflow 实现反向传播求导
May 26 Python
Python Parser的用法
May 12 #Python
pytorch MSELoss计算平均的实现方法
May 12 #Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
You might like
用 php 编写的日历
2006/10/09 PHP
PHP生成网页快照 不用COM不用扩展.
2010/02/11 PHP
PHP 简易输出CSV表格文件的方法详解
2013/06/20 PHP
php检测文件编码的方法示例
2014/04/25 PHP
Win2003+apache+PHP+SqlServer2008 配置生产环境
2014/07/29 PHP
php获取英文姓名首字母的方法
2015/07/13 PHP
yii插入数据库防并发的简单代码
2017/05/27 PHP
php实现生成PDF文件的方法示例【基于FPDF类库】
2018/07/21 PHP
浅谈Laravel POST,PUT,PATCH 路由的区别
2019/10/15 PHP
为指定元素增加样式的js代码
2009/12/09 Javascript
JSON 学习之JSON in JavaScript详细使用说明
2010/02/23 Javascript
通过DOM脚本去设置样式信息
2010/09/19 Javascript
js兼容火狐显示上传图片预览效果的方法
2015/05/21 Javascript
初步认识JavaScript函数库jQuery
2015/06/18 Javascript
JavaScript节点及列表操作实例小结
2015/08/05 Javascript
JavaScript 判断一个对象{}是否为空对象的简单方法
2016/10/09 Javascript
javascript iframe跨域详解
2016/10/26 Javascript
BootStrap table删除指定行的注意事项(笔记整理)
2017/02/05 Javascript
动态统计当前输入内容的字节、字符数的实例详解
2017/10/27 Javascript
使用react实现手机号的数据同步显示功能的示例代码
2018/04/03 Javascript
vue.js前后端数据交互之提交数据操作详解
2018/04/24 Javascript
Vue.js实现双向数据绑定方法(表单自动赋值、表单自动取值)
2018/08/27 Javascript
JS实现点击拉拽轮播图pc端移动端适配
2018/09/05 Javascript
Jquery遍历筛选数组的几种方法和遍历解析json对象,Map()方法详解以及数组中查询某值是否存在
2019/01/18 jQuery
uni-app从安装到卸载的入门教程
2020/05/15 Javascript
教你使用python画一朵花送女朋友
2018/03/29 Python
pytorch实现从本地加载 .pth 格式模型
2020/02/14 Python
python GUI库图形界面开发之PyQt5计数器控件QSpinBox详细使用方法与实例
2020/02/28 Python
python之openpyxl模块的安装和基本用法(excel管理)
2021/02/03 Python
教师产假请假条
2014/04/10 职场文书
法制演讲稿
2014/09/10 职场文书
2015年初三班主任工作总结
2015/05/21 职场文书
2016入党积极分子党课学习心得体会
2015/10/09 职场文书
Nginx的反向代理实例详解
2021/03/31 Servers
通过shell脚本对mysql的增删改查及my.cnf的配置
2021/07/07 MySQL
一文搞懂Java中的注解和反射
2022/06/21 Java/Android