Pytorch的mean和std调查实例


Posted in Python onJanuary 02, 2020

如下所示:

# coding: utf-8

from __future__ import print_function
import copy
import click
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms

import matplotlib.pyplot as plt
import load_caffemodel
import scipy.io as sio

# if model has LSTM
# torch.backends.cudnn.enabled = False

imgpath = 'D:/ck/files_detected_face224/'   

imgname = 'S055_002_00000025.png' # anger
image_path = imgpath + imgname

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
raw_image = cv2.imread(image_path)[..., ::-1]
print(raw_image.shape)
raw_image = cv2.resize(raw_image, (224, ) * 2)
image = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(
    mean=mean_file,
    std =std_file,
    #mean = mean_file,
    #std = std_file,
  )
])(raw_image).unsqueeze(0)

print(image.shape)

convert_image1 = image.numpy()
convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W
convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))
convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C
print(convert_image1.shape)

convert_image1 = convert_image1 * 255

diff = raw_image - convert_image1
err = np.max(diff)
print(err)
plt.imshow(np.uint8(convert_image1))
plt.show()

结论:

input_image = (raw_image / 255 - mean) ./ std

下面调查均值文件和方差文件是如何生成的:

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
# coding: utf-8
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms

dataset_names = ('cifar10','cifar100','mnist')

parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,
          help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')

args = parser.parse_args()

data_dir = os.path.join('.', args.dataset)

print(args.dataset)
args.dataset = 'cifar10'
if args.dataset == "cifar10":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(train_set.train_data.mean(axis=(0,1,2))/255)
  print(train_set.train_data.std(axis=(0,1,2))/255)

  # imshow image
  train_data = train_set.train_data
  ind = 100
  img0 = train_data[ind,...]
  ## test channel number, in total , the correct channel is : RGB,not like BGR in caffe
  # error produce
  #b,g,r=cv2.split(img0)
  #img0=cv2.merge([r,g,b])

  print(img0.shape)
  print(type(img0))
  plt.imshow(img0)
  plt.show() # in ship in sea

  #img0 = cv2.resize(img0,(224,224))
  #cv2.imshow('img0',img0)
  #cv2.waitKey()

elif args.dataset == "cifar100":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(np.mean(train_set.train_data, axis=(0,1,2))/255)
  print(np.std(train_set.train_data, axis=(0,1,2))/255)

elif args.dataset == "mnist":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(list(train_set.train_data.size()))
  print(train_set.train_data.float().mean()/255)
  print(train_set.train_data.float().std()/255)

结果:

cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[ 0.49139968 0.48215841 0.44653091]
[ 0.24703223 0.24348513 0.26158784]
(32, 32, 3)
<class 'numpy.ndarray'>

使用matlab检测是如何计算mean_file和std_file的:

% load cifar10 dataset

data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));

temp = mean(train_data,1);
disp(size(temp));

train_data = double(train_data);

% compute mean_file 
mean_val = mean(mean(mean(train_data,1),2),3)/255;


% compute std_file 
temp1 = train_data(:,:,:,1);
std_val1 = std(temp1(:))/255;

temp2 = train_data(:,:,:,2);
std_val2 = std(temp2(:))/255;

temp3 = train_data(:,:,:,3);
std_val3 = std(temp3(:))/255;

mean_val = squeeze(mean_val);
std_val = [std_val1, std_val2, std_val3];

disp(mean_val);
disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]
%     std_val: [0.2470, 0.2435, 0.2616]

均值计算的过程也可以遵循标准差的计算过程。为 了简单,例如对于一个矩阵,所有元素的均值,等于两个方向上先后均值。所以会直接采用如下的形式:

mean_val = mean(mean(mean(train_data,1),2),3)/255;

标准差的计算是每一个通道的对所有样本的求标准差。然后再除以255。

以上这篇Pytorch的mean和std调查实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现将目录中TXT合并成一个大TXT文件的方法
Jul 15 Python
python目录与文件名操作例子
Aug 28 Python
Pycharm编辑器技巧之自动导入模块详解
Jul 18 Python
TensorFlow 合并/连接数组的方法
Jul 27 Python
python爬虫中多线程的使用详解
Sep 23 Python
使用Python串口实时显示数据并绘图的例子
Dec 26 Python
python 实现人和电脑猜拳的示例代码
Mar 02 Python
python selenium自动化测试框架搭建的方法步骤
Jun 14 Python
Python面向对象程序设计之类和对象、实例变量、类变量用法分析
Mar 23 Python
Pycharm操作Git及GitHub的步骤详解
Oct 27 Python
python输出国际象棋棋盘的实例分享
Nov 26 Python
pycharm无法导入lxml的解决办法
Mar 31 Python
pytorch 图像预处理之减去均值,除以方差的实例
Jan 02 #Python
Linux下升级安装python3.8并配置pip及yum的教程
Jan 02 #Python
pytorch实现focal loss的两种方式小结
Jan 02 #Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
You might like
ASP和PHP都是可以删除自身的
2007/04/09 PHP
基于PHP的cURL快速入门教程 (小偷采集程序)
2011/06/02 PHP
Fatal error: Allowed memory size of 134217728 bytes exhausted (tried to allocate 2611816 bytes)
2014/11/08 PHP
PHP基于DOMDocument解析和生成xml的方法分析
2017/07/17 PHP
PHP使用file_get_contents发送http请求功能简单示例
2018/04/29 PHP
php ajax confirm 删除实例详解
2019/03/06 PHP
些很实用且必用的小脚本代码
2006/06/26 Javascript
javascript面向对象之Javascript 继承
2010/05/04 Javascript
13个绚丽的Jquery 界面设计网站推荐
2010/09/28 Javascript
JS解析json数据并将json字符串转化为数组的实现方法
2012/12/25 Javascript
a标签的href与onclick事件的区别详解
2014/11/12 Javascript
JS实现浏览器状态栏文字从右向左弹出效果代码
2015/10/27 Javascript
JavaScript动态插入CSS的方法
2015/12/10 Javascript
js实现纯前端的图片预览
2016/04/27 Javascript
如何利用Promises编写更优雅的JavaScript代码
2016/05/17 Javascript
Vuejs第一篇之入门教程详解(单向绑定、双向绑定、列表渲染、响应函数)
2016/09/09 Javascript
JavaScript文件的同步和异步加载的实现代码
2017/08/19 Javascript
如何在项目中使用log4.js的方法步骤
2019/07/16 Javascript
[45:46]2014 DOTA2国际邀请赛中国区预选赛5.21 HGT VS DT
2014/05/23 DOTA
详解Python3中yield生成器的用法
2015/08/20 Python
python模块之paramiko实例代码
2018/01/31 Python
Django如何将URL映射到视图
2019/07/29 Python
浅谈Python type的使用
2019/11/19 Python
Python函数的定义方式与函数参数问题实例分析
2019/12/26 Python
Windows下PyCharm配置Anaconda环境(超详细教程)
2020/07/31 Python
Python 的 __str__ 和 __repr__ 方法对比
2020/09/02 Python
乌克兰设计师和品牌的服装:Love&Live
2020/04/14 全球购物
介绍一下.NET构架下remoting和webservice
2014/05/08 面试题
大学生护理专业自荐信
2013/10/03 职场文书
六月份红领巾广播稿
2014/02/03 职场文书
春季运动会广播稿大全
2014/02/19 职场文书
高考备战决心书
2014/03/11 职场文书
入党自我鉴定
2014/03/25 职场文书
学历公证委托书
2014/04/09 职场文书
公司的门卫岗位职责
2014/09/09 职场文书
怎样写好演讲稿题目?
2019/08/21 职场文书