Pytorch的mean和std调查实例


Posted in Python onJanuary 02, 2020

如下所示:

# coding: utf-8

from __future__ import print_function
import copy
import click
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms

import matplotlib.pyplot as plt
import load_caffemodel
import scipy.io as sio

# if model has LSTM
# torch.backends.cudnn.enabled = False

imgpath = 'D:/ck/files_detected_face224/'   

imgname = 'S055_002_00000025.png' # anger
image_path = imgpath + imgname

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
raw_image = cv2.imread(image_path)[..., ::-1]
print(raw_image.shape)
raw_image = cv2.resize(raw_image, (224, ) * 2)
image = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(
    mean=mean_file,
    std =std_file,
    #mean = mean_file,
    #std = std_file,
  )
])(raw_image).unsqueeze(0)

print(image.shape)

convert_image1 = image.numpy()
convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W
convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))
convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C
print(convert_image1.shape)

convert_image1 = convert_image1 * 255

diff = raw_image - convert_image1
err = np.max(diff)
print(err)
plt.imshow(np.uint8(convert_image1))
plt.show()

结论:

input_image = (raw_image / 255 - mean) ./ std

下面调查均值文件和方差文件是如何生成的:

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
# coding: utf-8
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms

dataset_names = ('cifar10','cifar100','mnist')

parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,
          help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')

args = parser.parse_args()

data_dir = os.path.join('.', args.dataset)

print(args.dataset)
args.dataset = 'cifar10'
if args.dataset == "cifar10":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(train_set.train_data.mean(axis=(0,1,2))/255)
  print(train_set.train_data.std(axis=(0,1,2))/255)

  # imshow image
  train_data = train_set.train_data
  ind = 100
  img0 = train_data[ind,...]
  ## test channel number, in total , the correct channel is : RGB,not like BGR in caffe
  # error produce
  #b,g,r=cv2.split(img0)
  #img0=cv2.merge([r,g,b])

  print(img0.shape)
  print(type(img0))
  plt.imshow(img0)
  plt.show() # in ship in sea

  #img0 = cv2.resize(img0,(224,224))
  #cv2.imshow('img0',img0)
  #cv2.waitKey()

elif args.dataset == "cifar100":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(np.mean(train_set.train_data, axis=(0,1,2))/255)
  print(np.std(train_set.train_data, axis=(0,1,2))/255)

elif args.dataset == "mnist":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(list(train_set.train_data.size()))
  print(train_set.train_data.float().mean()/255)
  print(train_set.train_data.float().std()/255)

结果:

cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[ 0.49139968 0.48215841 0.44653091]
[ 0.24703223 0.24348513 0.26158784]
(32, 32, 3)
<class 'numpy.ndarray'>

使用matlab检测是如何计算mean_file和std_file的:

% load cifar10 dataset

data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));

temp = mean(train_data,1);
disp(size(temp));

train_data = double(train_data);

% compute mean_file 
mean_val = mean(mean(mean(train_data,1),2),3)/255;


% compute std_file 
temp1 = train_data(:,:,:,1);
std_val1 = std(temp1(:))/255;

temp2 = train_data(:,:,:,2);
std_val2 = std(temp2(:))/255;

temp3 = train_data(:,:,:,3);
std_val3 = std(temp3(:))/255;

mean_val = squeeze(mean_val);
std_val = [std_val1, std_val2, std_val3];

disp(mean_val);
disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]
%     std_val: [0.2470, 0.2435, 0.2616]

均值计算的过程也可以遵循标准差的计算过程。为 了简单,例如对于一个矩阵,所有元素的均值,等于两个方向上先后均值。所以会直接采用如下的形式:

mean_val = mean(mean(mean(train_data,1),2),3)/255;

标准差的计算是每一个通道的对所有样本的求标准差。然后再除以255。

以上这篇Pytorch的mean和std调查实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的石头剪子布代码分享
Aug 22 Python
浅谈python中的数字类型与处理工具
Aug 02 Python
python利用paramiko连接远程服务器执行命令的方法
Oct 16 Python
Python中Proxypool库的安装与配置
Oct 19 Python
Python过滤txt文件内重复内容的方法
Oct 21 Python
Python2和Python3.6环境解决共存问题
Nov 09 Python
python之信息加密题目详解
Jun 26 Python
Python3 venv搭建轻量级虚拟环境的步骤(图文)
Aug 09 Python
python编写计算器功能
Oct 25 Python
在python中使用nohup命令说明
Apr 16 Python
python 提高开发效率的5个小技巧
Oct 19 Python
python的列表生成式,生成器和generator对象你了解吗
Mar 16 Python
pytorch 图像预处理之减去均值,除以方差的实例
Jan 02 #Python
Linux下升级安装python3.8并配置pip及yum的教程
Jan 02 #Python
pytorch实现focal loss的两种方式小结
Jan 02 #Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
You might like
MVC模式的PHP实现
2006/10/09 PHP
CI框架中zip类应用示例
2014/06/17 PHP
Thinkphp 在api开发中异常返回依然是html的解决方式
2019/10/16 PHP
一文看懂PHP进程管理器php-fpm
2020/06/01 PHP
event.srcElement+表格应用
2006/08/29 Javascript
js 父窗口控制子窗口的行为-打开,关闭,重定位,回复
2010/04/20 Javascript
基于jquery实现控制经纬度显示地图与卫星
2013/05/20 Javascript
HTML页面滚动时获取离页面顶部的距离2种实现方法
2013/09/05 Javascript
JavaScript中九种常用排序算法
2014/09/02 Javascript
jQuery中animate用法实例分析
2015/03/09 Javascript
JavaScript包装对象使用详解
2015/07/09 Javascript
JavaScript学习笔记之创建对象
2016/03/25 Javascript
js将滚动条滚动到指定位置的简单实现方法
2016/06/25 Javascript
BootStrap Validator对于隐藏域验证和程序赋值即时验证的问题浅析
2016/12/01 Javascript
JS实现的tab切换选项卡效果示例
2017/02/28 Javascript
jQuery插件FusionCharts绘制2D柱状图和折线图的组合图效果示例【附demo源码】
2017/04/10 jQuery
javascript基于牛顿迭代法实现求浮点数的平方根【递归原理】
2017/09/28 Javascript
解决v-for中使用v-if或者v-bind:class失效的问题
2018/09/25 Javascript
JS中使用new Option()实现时间联动效果
2018/12/10 Javascript
vue.js封装switch开关组件的操作
2020/10/26 Javascript
Python实现微信表情包炸群功能
2021/01/28 Python
css3发光搜索表单分享
2014/04/11 HTML / CSS
利用CSS3制作简单的3d半透明立方体图片展示
2017/03/25 HTML / CSS
欧洲顶级的童装奢侈品购物网站:Bambini Fashion(面向全球)
2018/04/24 全球购物
同学聚会老师邀请函
2014/01/28 职场文书
学雷锋演讲稿
2014/03/04 职场文书
公司合作协议书范本
2014/04/18 职场文书
毕业生应聘求职信
2014/07/10 职场文书
社区学习党的群众路线教育实践活动心得体会
2014/11/03 职场文书
个人事迹材料怎么写
2014/12/30 职场文书
助学感谢信范文
2015/01/21 职场文书
结婚通知短信怎么写
2015/04/17 职场文书
劳动仲裁调解书
2015/05/20 职场文书
导游词之蜀山胜景瓦屋山
2019/11/29 职场文书
MySQL之PXC集群搭建的方法步骤
2021/05/25 MySQL
详解OpenCV获取高动态范围(HDR)成像
2022/04/29 Python