PyTorch 如何设置随机数种子使结果可复现


Posted in Python onMay 12, 2021

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。

因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False            # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

Pytorch

torch.manual_seed(seed)            # 为CPU设置随机种子
torch.cuda.manual_seed(seed)       # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)   # 为所有GPU设置随机种子

Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

Dataloader

如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。

也就是说,改变num_workers参数,也会对实验结果产生影响。

目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

补充:pytorch 固定随机数种子踩过的坑

1.初步固定

def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.enabled = False
     torch.backends.cudnn.benchmark = False
     #torch.backends.cudnn.benchmark = True #for accelerating the running
 setup_seed(2019)

2.继续添加如下代码:

tensor_dataset = ImageList(opt.training_list,transform)
def _init_fn(worker_id): 
    random.seed(10 + worker_id)
    np.random.seed(10 + worker_id)
    torch.manual_seed(10 + worker_id)
    torch.cuda.manual_seed(10 + worker_id)
    torch.cuda.manual_seed_all(10 + worker_id)
dataloader = DataLoader(tensor_dataset,                        
                    batch_size=opt.batchSize,     
                    shuffle=True,     
                    num_workers=opt.workers,
                    worker_init_fn=_init_fn)

3.在上面的操作之后发现加载的数据多次试验大部分一致了

但是仍然有些数据是不一致的,后来发现是pytorch版本的问题,将原先的0.3.1版本升级到1.1.0版本,问题解决

4.按照上面的操作后虽然解决了问题

但是由于将cudnn.benchmark设置为False,运行速度降低到原来的1/3,所以继续探索,最终解决方案是把第1步变为如下,同时将该部分代码尽可能放在主程序最开始的部分,例如:

import torch
import torch.nn as nn
from torch.nn import init
import pdb
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader, Dataset
import sys
gpu_id = "3,2"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
print('GPU: ',gpu_id)
def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     cudnn.deterministic = True
     #cudnn.benchmark = False
     #cudnn.enabled = False

setup_seed(2019)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
使用Python简单的实现树莓派的WEB控制
Feb 18 Python
python爬虫框架talonspider简单介绍
Jun 09 Python
Python实现矩阵转置的方法分析
Nov 24 Python
对pandas写入读取h5文件的方法详解
Dec 28 Python
Python中将两个或多个list合成一个list的方法小结
May 12 Python
详解python实现交叉验证法与留出法
Jul 11 Python
Python 类,property属性(简化属性的操作),@property,property()用法示例
Oct 12 Python
Django REST framework 单元测试实例解析
Nov 07 Python
如何在Django中使用聚合的实现示例
Mar 23 Python
python爬虫实现爬取同一个网站的多页数据的实例讲解
Jan 18 Python
Python基于argparse与ConfigParser库进行入参解析与ini parser
Feb 02 Python
Python 多线程之threading 模块的使用
Apr 14 Python
Python Parser的用法
May 12 #Python
pytorch MSELoss计算平均的实现方法
May 12 #Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
You might like
关于手调机和数调机的选择
2021/03/02 无线电
php简单复制文件的方法
2016/05/09 PHP
用Greasemonkey 脚本收藏网站会员信息到本地
2009/10/26 Javascript
nodejs入门详解(多篇文章结合)
2012/03/07 NodeJs
JavaScript数据结构与算法之栈详解
2015/03/12 Javascript
JS+DIV+CSS排版布局实现美观的选项卡效果
2015/10/10 Javascript
jQuery学习笔记之Ajax用法实例详解
2015/12/01 Javascript
JQuery实现Ajax加载图片的方法
2015/12/24 Javascript
jQuery插件Validate实现自定义表单验证
2016/01/18 Javascript
Node.js实现文件上传
2016/07/05 Javascript
用Vue.js实现监听属性的变化
2016/11/17 Javascript
js实现常见的工具条效果
2017/03/02 Javascript
解决VUEX兼容IE上的报错问题
2018/03/01 Javascript
详解如何用babel转换es6的class语法
2018/04/03 Javascript
webpack中的热刷新与热加载的区别
2018/04/09 Javascript
微信小程序实现美团菜单
2018/06/06 Javascript
使用Javascript简单计算器
2018/11/17 Javascript
vue中img src 动态加载本地json的图片路径写法
2019/04/25 Javascript
解决Vue+Electron下Vuex的Dispatch没有效果问题
2019/05/20 Javascript
cordova+vue+webapp使用html5获取地理位置的方法
2019/07/06 Javascript
js微信分享接口调用详解
2019/07/23 Javascript
layui写后台表格思路和赋值用法详解
2019/11/14 Javascript
Node Express用法详解【安装、使用、路由、中间件、模板引擎等】
2020/05/13 Javascript
python 多线程应用介绍
2012/12/19 Python
python使用新浪微博api上传图片到微博示例
2014/01/10 Python
Linux下使用python自动修改本机网关代码分享
2015/05/21 Python
浅谈python字符串方法的简单使用
2016/07/18 Python
python实现微信远程控制电脑
2018/02/22 Python
python numpy 一维数组转变为多维数组的实例
2018/07/02 Python
如何实现Django Rest framework版本控制
2019/07/25 Python
Python基于数列实现购物车程序过程详解
2020/06/09 Python
以设计师精品品质提供快速时尚:PopJulia
2018/01/09 全球购物
2014年应届大学生毕业自我鉴定
2014/01/31 职场文书
三八红旗集体先进事迹材料
2014/05/22 职场文书
鼋头渚导游词
2015/02/05 职场文书
初中班主任工作随笔
2015/08/15 职场文书