PyTorch 如何设置随机数种子使结果可复现


Posted in Python onMay 12, 2021

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。

因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False            # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

Pytorch

torch.manual_seed(seed)            # 为CPU设置随机种子
torch.cuda.manual_seed(seed)       # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)   # 为所有GPU设置随机种子

Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

Dataloader

如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。

也就是说,改变num_workers参数,也会对实验结果产生影响。

目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

补充:pytorch 固定随机数种子踩过的坑

1.初步固定

def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.enabled = False
     torch.backends.cudnn.benchmark = False
     #torch.backends.cudnn.benchmark = True #for accelerating the running
 setup_seed(2019)

2.继续添加如下代码:

tensor_dataset = ImageList(opt.training_list,transform)
def _init_fn(worker_id): 
    random.seed(10 + worker_id)
    np.random.seed(10 + worker_id)
    torch.manual_seed(10 + worker_id)
    torch.cuda.manual_seed(10 + worker_id)
    torch.cuda.manual_seed_all(10 + worker_id)
dataloader = DataLoader(tensor_dataset,                        
                    batch_size=opt.batchSize,     
                    shuffle=True,     
                    num_workers=opt.workers,
                    worker_init_fn=_init_fn)

3.在上面的操作之后发现加载的数据多次试验大部分一致了

但是仍然有些数据是不一致的,后来发现是pytorch版本的问题,将原先的0.3.1版本升级到1.1.0版本,问题解决

4.按照上面的操作后虽然解决了问题

但是由于将cudnn.benchmark设置为False,运行速度降低到原来的1/3,所以继续探索,最终解决方案是把第1步变为如下,同时将该部分代码尽可能放在主程序最开始的部分,例如:

import torch
import torch.nn as nn
from torch.nn import init
import pdb
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader, Dataset
import sys
gpu_id = "3,2"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
print('GPU: ',gpu_id)
def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     cudnn.deterministic = True
     #cudnn.benchmark = False
     #cudnn.enabled = False

setup_seed(2019)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python分割文件的常用方法
Nov 01 Python
python实现简单日期工具类
Apr 24 Python
详解Python 切片语法
Jun 10 Python
Python列表的切片实例讲解
Aug 20 Python
Python统计分析模块statistics用法示例
Sep 06 Python
使用pygame编写Flappy bird小游戏
Mar 14 Python
容易被忽略的Python内置类型
Sep 03 Python
python调用win32接口进行截图的示例
Nov 11 Python
python中用Scrapy实现定时爬虫的实例讲解
Jan 18 Python
Python编程super应用场景及示例解析
Oct 05 Python
关于python中模块和重载的问题
Nov 02 Python
Python中super().__init__()测试以及理解
Dec 06 Python
Python Parser的用法
May 12 #Python
pytorch MSELoss计算平均的实现方法
May 12 #Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
You might like
php smarty 二级分类代码和模版循环例子
2011/06/16 PHP
PHP中的排序函数sort、asort、rsort、krsort、ksort区别分析
2014/08/18 PHP
PHP编写学校网站上新生注册登陆程序的实例分享
2016/03/21 PHP
删除PHP数组中头部、尾部、任意元素的实现代码
2017/04/10 PHP
Laravel框架集成UEditor编辑器的方法图文与实例详解
2019/04/17 PHP
PHP7.3.10编译安装教程
2019/10/08 PHP
javascript JSON操作入门实例
2010/04/16 Javascript
JavaScript实现点击文字切换登录窗口的方法
2015/05/11 Javascript
解决Vue2.0自带浏览器里无法打开的原因(兼容处理)
2017/07/28 Javascript
Vue.js实现输入框绑定的实例代码
2017/08/24 Javascript
JavaScript正则表达式函数总结(常用)
2018/02/22 Javascript
vue中的模态对话框组件实现过程
2018/05/01 Javascript
vue js秒转天数小时分钟秒的实例代码
2018/08/08 Javascript
layui内置模块layim发送图片添加加载动画的方法
2019/09/23 Javascript
vue实现鼠标移过出现下拉二级菜单功能
2019/12/12 Javascript
vue实现移动端input上传视频、音频
2020/08/18 Javascript
微信小程序 接入腾讯地图的两种写法
2021/01/12 Javascript
Python的Django框架中从url中捕捉文本的方法
2015/07/20 Python
Python操作mysql数据库实现增删查改功能的方法
2018/01/15 Python
python中的json总结
2018/10/11 Python
python之消除前缀重命名的方法
2018/10/21 Python
django将数组传递给前台模板的方法
2019/08/06 Python
Python爬虫实现HTTP网络请求多种实现方式
2020/06/19 Python
IE9下html5初试小刀
2010/09/21 HTML / CSS
美国珠宝网上商店:Jeulia
2016/09/01 全球购物
英国网上花店:Bunches
2016/11/29 全球购物
美国高级工作服品牌:Carhartt
2018/01/25 全球购物
微软瑞士官方网站:Microsoft瑞士
2018/04/20 全球购物
双语教学实施方案
2014/03/23 职场文书
态度决定一切演讲稿
2014/05/20 职场文书
面试通知邮件
2015/04/20 职场文书
2015年世界无烟日活动方案
2015/05/04 职场文书
撤诉申请书法院范本
2015/05/18 职场文书
初中运动会闭幕词范本3篇
2019/12/09 职场文书
使用这 6个Vue加载动画库来减少我们网站的跳出率
2021/05/18 Vue.js
解决redis批量删除key值的问题
2022/03/23 Redis