PyTorch 如何设置随机数种子使结果可复现


Posted in Python onMay 12, 2021

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。

因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False            # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

Pytorch

torch.manual_seed(seed)            # 为CPU设置随机种子
torch.cuda.manual_seed(seed)       # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)   # 为所有GPU设置随机种子

Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

Dataloader

如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。

也就是说,改变num_workers参数,也会对实验结果产生影响。

目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

补充:pytorch 固定随机数种子踩过的坑

1.初步固定

def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.enabled = False
     torch.backends.cudnn.benchmark = False
     #torch.backends.cudnn.benchmark = True #for accelerating the running
 setup_seed(2019)

2.继续添加如下代码:

tensor_dataset = ImageList(opt.training_list,transform)
def _init_fn(worker_id): 
    random.seed(10 + worker_id)
    np.random.seed(10 + worker_id)
    torch.manual_seed(10 + worker_id)
    torch.cuda.manual_seed(10 + worker_id)
    torch.cuda.manual_seed_all(10 + worker_id)
dataloader = DataLoader(tensor_dataset,                        
                    batch_size=opt.batchSize,     
                    shuffle=True,     
                    num_workers=opt.workers,
                    worker_init_fn=_init_fn)

3.在上面的操作之后发现加载的数据多次试验大部分一致了

但是仍然有些数据是不一致的,后来发现是pytorch版本的问题,将原先的0.3.1版本升级到1.1.0版本,问题解决

4.按照上面的操作后虽然解决了问题

但是由于将cudnn.benchmark设置为False,运行速度降低到原来的1/3,所以继续探索,最终解决方案是把第1步变为如下,同时将该部分代码尽可能放在主程序最开始的部分,例如:

import torch
import torch.nn as nn
from torch.nn import init
import pdb
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader, Dataset
import sys
gpu_id = "3,2"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
print('GPU: ',gpu_id)
def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     cudnn.deterministic = True
     #cudnn.benchmark = False
     #cudnn.enabled = False

setup_seed(2019)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python中的复制操作及copy模块中的浅拷贝与深拷贝方法
Jul 02 Python
用virtualenv建立多个Python独立虚拟开发环境
Jul 06 Python
浅谈python 里面的单下划线与双下划线的区别
Dec 01 Python
python按时间排序目录下的文件实现方法
Oct 17 Python
python使用xlsxwriter实现有向无环图到Excel的转换
Dec 12 Python
基于Python对数据shape的常见操作详解
Dec 25 Python
解决python3 requests headers参数不能有中文的问题
Aug 21 Python
python实现的发邮件功能示例
Sep 11 Python
python实现12306登录并保存cookie的方法示例
Dec 17 Python
Win系统PyQt5安装和使用教程
Dec 25 Python
Python map及filter函数使用方法解析
Aug 06 Python
Django自带用户认证系统使用方法解析
Nov 12 Python
Python Parser的用法
May 12 #Python
pytorch MSELoss计算平均的实现方法
May 12 #Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
You might like
PHP JSON 数据解析代码
2010/05/26 PHP
PHP CURL获取cookies模拟登录的方法
2013/11/04 PHP
JS面向对象编程浅析
2011/08/28 Javascript
javascript实现tabs选项卡切换效果(自写原生js)
2013/03/19 Javascript
js实现动态改变字体大小代码
2014/01/02 Javascript
node.js中的path.resolve方法使用说明
2014/12/08 Javascript
js完美实现@提到好友特效(兼容各大浏览器)
2015/03/16 Javascript
JS图片定时翻滚效果实现方法
2016/06/21 Javascript
js实现功能比较全面的全选和多选
2017/03/02 Javascript
如何安装控制器JavaScript生成插件详解
2018/10/21 Javascript
深入解析vue 源码目录及构建过程分析
2019/04/24 Javascript
nodeJS与MySQL实现分页数据以及倒序数据
2020/06/05 NodeJs
JS简易计算器实例讲解
2020/06/30 Javascript
[03:49]2016完美“圣”典风云人物:AMS专访
2016/12/06 DOTA
[04:50]2019DOTA2高校联赛秋季赛四强集锦
2019/12/27 DOTA
Python cookbook(数据结构与算法)找出序列中出现次数最多的元素算法示例
2018/03/15 Python
pytorch 更改预训练模型网络结构的方法
2019/08/19 Python
Pytorch中的VGG实现修改最后一层FC
2020/01/15 Python
Django实现列表页商品数据返回教程
2020/04/03 Python
英国二手物品交易网站:Preloved
2017/10/06 全球购物
西班牙电子产品购物网站:Electronicamente
2018/07/26 全球购物
意大利婴儿产品网上商店:Mukako
2018/10/14 全球购物
Magee 1866官网:Donegal粗花呢外套和大衣专家
2019/11/01 全球购物
香港艺人陈冠希创办的潮流品牌:JUICESTORE
2021/03/04 全球购物
华美博弈C/VC工程师笔试试题
2012/07/16 面试题
教师自荐信
2013/12/10 职场文书
主持人婚宴答谢词
2014/01/28 职场文书
党的群众路线学习材料
2014/05/16 职场文书
四风查摆问题及整改措施
2014/10/10 职场文书
2014年便民服务中心工作总结
2014/12/20 职场文书
民间借贷借条范本
2015/05/25 职场文书
关于保护环境的建议书
2019/06/24 职场文书
HTML5中 rem适配方案与 viewport 适配问题详解
2021/04/27 HTML / CSS
详细聊聊浏览器是如何看闭包的
2021/11/11 Javascript
python中pymysql包操作数据库方法
2022/04/19 Python
手把手带你彻底卸载MySQL数据库
2022/06/14 MySQL