Pytorch的mean和std调查实例


Posted in Python onJanuary 02, 2020

如下所示:

# coding: utf-8

from __future__ import print_function
import copy
import click
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms

import matplotlib.pyplot as plt
import load_caffemodel
import scipy.io as sio

# if model has LSTM
# torch.backends.cudnn.enabled = False

imgpath = 'D:/ck/files_detected_face224/'   

imgname = 'S055_002_00000025.png' # anger
image_path = imgpath + imgname

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
raw_image = cv2.imread(image_path)[..., ::-1]
print(raw_image.shape)
raw_image = cv2.resize(raw_image, (224, ) * 2)
image = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(
    mean=mean_file,
    std =std_file,
    #mean = mean_file,
    #std = std_file,
  )
])(raw_image).unsqueeze(0)

print(image.shape)

convert_image1 = image.numpy()
convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W
convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))
convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C
print(convert_image1.shape)

convert_image1 = convert_image1 * 255

diff = raw_image - convert_image1
err = np.max(diff)
print(err)
plt.imshow(np.uint8(convert_image1))
plt.show()

结论:

input_image = (raw_image / 255 - mean) ./ std

下面调查均值文件和方差文件是如何生成的:

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
# coding: utf-8
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms

dataset_names = ('cifar10','cifar100','mnist')

parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,
          help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')

args = parser.parse_args()

data_dir = os.path.join('.', args.dataset)

print(args.dataset)
args.dataset = 'cifar10'
if args.dataset == "cifar10":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(train_set.train_data.mean(axis=(0,1,2))/255)
  print(train_set.train_data.std(axis=(0,1,2))/255)

  # imshow image
  train_data = train_set.train_data
  ind = 100
  img0 = train_data[ind,...]
  ## test channel number, in total , the correct channel is : RGB,not like BGR in caffe
  # error produce
  #b,g,r=cv2.split(img0)
  #img0=cv2.merge([r,g,b])

  print(img0.shape)
  print(type(img0))
  plt.imshow(img0)
  plt.show() # in ship in sea

  #img0 = cv2.resize(img0,(224,224))
  #cv2.imshow('img0',img0)
  #cv2.waitKey()

elif args.dataset == "cifar100":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(np.mean(train_set.train_data, axis=(0,1,2))/255)
  print(np.std(train_set.train_data, axis=(0,1,2))/255)

elif args.dataset == "mnist":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(list(train_set.train_data.size()))
  print(train_set.train_data.float().mean()/255)
  print(train_set.train_data.float().std()/255)

结果:

cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[ 0.49139968 0.48215841 0.44653091]
[ 0.24703223 0.24348513 0.26158784]
(32, 32, 3)
<class 'numpy.ndarray'>

使用matlab检测是如何计算mean_file和std_file的:

% load cifar10 dataset

data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));

temp = mean(train_data,1);
disp(size(temp));

train_data = double(train_data);

% compute mean_file 
mean_val = mean(mean(mean(train_data,1),2),3)/255;


% compute std_file 
temp1 = train_data(:,:,:,1);
std_val1 = std(temp1(:))/255;

temp2 = train_data(:,:,:,2);
std_val2 = std(temp2(:))/255;

temp3 = train_data(:,:,:,3);
std_val3 = std(temp3(:))/255;

mean_val = squeeze(mean_val);
std_val = [std_val1, std_val2, std_val3];

disp(mean_val);
disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]
%     std_val: [0.2470, 0.2435, 0.2616]

均值计算的过程也可以遵循标准差的计算过程。为 了简单,例如对于一个矩阵,所有元素的均值,等于两个方向上先后均值。所以会直接采用如下的形式:

mean_val = mean(mean(mean(train_data,1),2),3)/255;

标准差的计算是每一个通道的对所有样本的求标准差。然后再除以255。

以上这篇Pytorch的mean和std调查实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python的Crypto模块实现AES加密实例代码
Jan 22 Python
浅谈python中真正关闭socket的方法
Dec 18 Python
Python3多线程版TCP端口扫描器
Aug 31 Python
python被修饰的函数消失问题解决(基于wraps函数)
Nov 04 Python
安装完Python包然后找不到模块的解决步骤
Feb 13 Python
Win 10下Anaconda虚拟环境的教程
May 18 Python
浅谈pytorch中torch.max和F.softmax函数的维度解释
Jun 28 Python
Python3爬虫中识别图形验证码的实例讲解
Jul 30 Python
详解torch.Tensor的4种乘法
Sep 03 Python
python爬虫调度器用法及实例代码
Nov 30 Python
python 使用xlsxwriter循环向excel中插入数据和图片的操作
Jan 01 Python
Python自然语言处理之切分算法详解
Apr 25 Python
pytorch 图像预处理之减去均值,除以方差的实例
Jan 02 #Python
Linux下升级安装python3.8并配置pip及yum的教程
Jan 02 #Python
pytorch实现focal loss的两种方式小结
Jan 02 #Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
You might like
CodeIgniter框架过滤HTML危险代码
2014/06/12 PHP
PHP+Mysql基于事务处理实现转账功能的方法
2015/07/08 PHP
PHP实现生成模糊图片的方法示例
2017/12/21 PHP
一个报数游戏js版(约瑟夫环问题)
2010/08/05 Javascript
jquery中的查找parents与closest方法之间的区别
2013/12/02 Javascript
JavaScript 学习笔记之基础中的基础
2015/01/13 Javascript
总结Javascript中的隐式类型转换
2016/08/24 Javascript
vue.js实现仿原生ios时间选择组件实例代码
2016/12/21 Javascript
WebView启动支付宝客户端支付失败的问题小结
2017/01/11 Javascript
详解基于angular-cli配置代理解决跨域请求问题
2017/07/05 Javascript
使用JS组件实现带ToolTip验证框的实例代码
2017/08/23 Javascript
微信小程序中setInterval的使用方法
2017/09/29 Javascript
解决webpack+Vue引入iView找不到字体文件的问题
2018/09/28 Javascript
angularjs使用div模拟textarea文本框的方法
2018/10/02 Javascript
[40:04]Secret vs Infamous 2019国际邀请赛淘汰赛 败者组 BO3 第二场 8.23
2019/09/05 DOTA
Python3里的super()和__class__使用介绍
2015/04/23 Python
Python中特殊函数集锦
2015/07/27 Python
python中子类调用父类函数的方法示例
2017/08/18 Python
python的dataframe和matrix的互换方法
2018/04/11 Python
关于python写入文件自动换行的问题
2018/06/23 Python
python实战教程之自动扫雷
2018/07/13 Python
Python ORM框架Peewee用法详解
2020/04/29 Python
使用layui实现左侧菜单栏及动态操作tab项的方法
2020/11/10 HTML / CSS
阿迪达斯比利时官方商城:adidas比利时
2016/10/10 全球购物
纪伊国屋新加坡网上书店:Kinokuniya新加坡
2017/12/29 全球购物
护理专业毕业生自我鉴定
2013/10/08 职场文书
2013年学期结束动员演讲稿
2014/01/07 职场文书
医学院毕业生自荐信范文
2014/03/06 职场文书
竞聘书怎么写,如何写?
2014/03/31 职场文书
学生党员公开承诺书
2014/05/28 职场文书
2014乡镇干部纪律作风整顿思想汇报
2014/09/13 职场文书
庐山导游词
2015/02/03 职场文书
2019个人工作计划书的格式及范文!
2019/07/04 职场文书
MySQL Shell的介绍以及安装
2021/04/24 MySQL
python开发的自动化运维工具ansible详解
2021/08/07 Python
关于MySQL中explain工具的使用
2023/05/08 MySQL