Pytorch的mean和std调查实例


Posted in Python onJanuary 02, 2020

如下所示:

# coding: utf-8

from __future__ import print_function
import copy
import click
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms

import matplotlib.pyplot as plt
import load_caffemodel
import scipy.io as sio

# if model has LSTM
# torch.backends.cudnn.enabled = False

imgpath = 'D:/ck/files_detected_face224/'   

imgname = 'S055_002_00000025.png' # anger
image_path = imgpath + imgname

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
raw_image = cv2.imread(image_path)[..., ::-1]
print(raw_image.shape)
raw_image = cv2.resize(raw_image, (224, ) * 2)
image = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(
    mean=mean_file,
    std =std_file,
    #mean = mean_file,
    #std = std_file,
  )
])(raw_image).unsqueeze(0)

print(image.shape)

convert_image1 = image.numpy()
convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W
convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))
convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C
print(convert_image1.shape)

convert_image1 = convert_image1 * 255

diff = raw_image - convert_image1
err = np.max(diff)
print(err)
plt.imshow(np.uint8(convert_image1))
plt.show()

结论:

input_image = (raw_image / 255 - mean) ./ std

下面调查均值文件和方差文件是如何生成的:

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
# coding: utf-8
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms

dataset_names = ('cifar10','cifar100','mnist')

parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,
          help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')

args = parser.parse_args()

data_dir = os.path.join('.', args.dataset)

print(args.dataset)
args.dataset = 'cifar10'
if args.dataset == "cifar10":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(train_set.train_data.mean(axis=(0,1,2))/255)
  print(train_set.train_data.std(axis=(0,1,2))/255)

  # imshow image
  train_data = train_set.train_data
  ind = 100
  img0 = train_data[ind,...]
  ## test channel number, in total , the correct channel is : RGB,not like BGR in caffe
  # error produce
  #b,g,r=cv2.split(img0)
  #img0=cv2.merge([r,g,b])

  print(img0.shape)
  print(type(img0))
  plt.imshow(img0)
  plt.show() # in ship in sea

  #img0 = cv2.resize(img0,(224,224))
  #cv2.imshow('img0',img0)
  #cv2.waitKey()

elif args.dataset == "cifar100":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(np.mean(train_set.train_data, axis=(0,1,2))/255)
  print(np.std(train_set.train_data, axis=(0,1,2))/255)

elif args.dataset == "mnist":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(list(train_set.train_data.size()))
  print(train_set.train_data.float().mean()/255)
  print(train_set.train_data.float().std()/255)

结果:

cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[ 0.49139968 0.48215841 0.44653091]
[ 0.24703223 0.24348513 0.26158784]
(32, 32, 3)
<class 'numpy.ndarray'>

使用matlab检测是如何计算mean_file和std_file的:

% load cifar10 dataset

data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));

temp = mean(train_data,1);
disp(size(temp));

train_data = double(train_data);

% compute mean_file 
mean_val = mean(mean(mean(train_data,1),2),3)/255;


% compute std_file 
temp1 = train_data(:,:,:,1);
std_val1 = std(temp1(:))/255;

temp2 = train_data(:,:,:,2);
std_val2 = std(temp2(:))/255;

temp3 = train_data(:,:,:,3);
std_val3 = std(temp3(:))/255;

mean_val = squeeze(mean_val);
std_val = [std_val1, std_val2, std_val3];

disp(mean_val);
disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]
%     std_val: [0.2470, 0.2435, 0.2616]

均值计算的过程也可以遵循标准差的计算过程。为 了简单,例如对于一个矩阵,所有元素的均值,等于两个方向上先后均值。所以会直接采用如下的形式:

mean_val = mean(mean(mean(train_data,1),2),3)/255;

标准差的计算是每一个通道的对所有样本的求标准差。然后再除以255。

以上这篇Pytorch的mean和std调查实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中处理时间的几种方法小结
Apr 09 Python
django在接受post请求时显示403forbidden实例解析
Jan 25 Python
JS设计模式之责任链模式实例详解
Feb 03 Python
python如何实现反向迭代
Mar 20 Python
python查看模块安装位置的方法
Oct 16 Python
Python实现的银行系统模拟程序完整案例
Apr 12 Python
Python实现带下标索引的遍历操作示例
May 30 Python
pytorch程序异常后删除占用的显存操作
Jan 13 Python
Python稀疏矩阵及参数保存代码实现
Apr 18 Python
Python实现在线批量美颜功能过程解析
Jun 10 Python
python实现马丁策略的实例详解
Jan 15 Python
Python if else条件语句形式详解
Mar 24 Python
pytorch 图像预处理之减去均值,除以方差的实例
Jan 02 #Python
Linux下升级安装python3.8并配置pip及yum的教程
Jan 02 #Python
pytorch实现focal loss的两种方式小结
Jan 02 #Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
You might like
PHP中的生成XML文件的4种方法分享
2012/10/06 PHP
PHP模板引擎Smarty自定义变量调解器用法
2016/04/11 PHP
PHP7新增运算符用法实例分析
2016/09/26 PHP
PHP通过curl获取接口URL的数据方法
2018/05/31 PHP
PHP接入支付宝接口失效流程详解
2020/11/10 PHP
javascript 命名规则 变量命名规则
2010/02/25 Javascript
jQuery中$.each使用详解
2015/01/29 Javascript
javascript中call和apply的用法示例分析
2015/04/02 Javascript
JavaScript中的原型prototype属性使用详解
2015/06/05 Javascript
jquery判断输入密码两次是否相等
2020/04/22 Javascript
js实现页面跳转的五种方法推荐
2016/03/10 Javascript
js调用父框架函数与弹窗调用父页面函数的简单方法
2016/11/01 Javascript
jQuery手指滑动轮播效果
2016/12/22 Javascript
详解nodejs 文本操作模块-fs模块(五)
2016/12/23 NodeJs
过期软件破解办法实例详解
2017/01/04 Javascript
vue2.0 与 bootstrap datetimepicker的结合使用实例
2017/05/22 Javascript
Angular.js前台传list数组由后台spring MVC接收数组示例代码
2017/07/31 Javascript
Vue2.0实现将页面中表格数据导出excel的实例
2017/08/09 Javascript
Layui点击图片弹框预览的实现方法
2019/09/16 Javascript
Windows上node.js的多版本管理工具用法实例分析
2019/11/06 Javascript
[02:04]完美世界城市挑战赛秋季赛报名开始 谁是solo路人王?
2019/10/10 DOTA
Python中类的继承代码实例
2014/10/28 Python
Python代码实现KNN算法
2017/12/20 Python
Python+tkinter使用80行代码实现一个计算器实例
2018/01/16 Python
Python饼状图的绘制实例
2019/01/15 Python
python将字符串转换成json的方法小结
2019/07/09 Python
Python解决pip install时出现的Could not fetch URL问题
2019/08/01 Python
css3学习之2D转换功能详解
2016/12/23 HTML / CSS
科尔士百货公司官网:Kohl’s
2016/07/11 全球购物
吉列剃须刀英国官网:Gillette英国
2019/03/28 全球购物
馥蕾诗美国官网:Fresh美国
2019/10/09 全球购物
俄罗斯苹果优质经销商商店:iPort
2020/05/27 全球购物
医务人员自我评价
2014/01/26 职场文书
黄继光的英雄事迹材料
2014/02/13 职场文书
升学宴家长致辞
2015/07/27 职场文书
JS高级程序设计之class继承重点详解
2022/07/07 Javascript