Pytorch的mean和std调查实例


Posted in Python onJanuary 02, 2020

如下所示:

# coding: utf-8

from __future__ import print_function
import copy
import click
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms

import matplotlib.pyplot as plt
import load_caffemodel
import scipy.io as sio

# if model has LSTM
# torch.backends.cudnn.enabled = False

imgpath = 'D:/ck/files_detected_face224/'   

imgname = 'S055_002_00000025.png' # anger
image_path = imgpath + imgname

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
raw_image = cv2.imread(image_path)[..., ::-1]
print(raw_image.shape)
raw_image = cv2.resize(raw_image, (224, ) * 2)
image = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(
    mean=mean_file,
    std =std_file,
    #mean = mean_file,
    #std = std_file,
  )
])(raw_image).unsqueeze(0)

print(image.shape)

convert_image1 = image.numpy()
convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W
convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))
convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C
print(convert_image1.shape)

convert_image1 = convert_image1 * 255

diff = raw_image - convert_image1
err = np.max(diff)
print(err)
plt.imshow(np.uint8(convert_image1))
plt.show()

结论:

input_image = (raw_image / 255 - mean) ./ std

下面调查均值文件和方差文件是如何生成的:

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
# coding: utf-8
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms

dataset_names = ('cifar10','cifar100','mnist')

parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,
          help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')

args = parser.parse_args()

data_dir = os.path.join('.', args.dataset)

print(args.dataset)
args.dataset = 'cifar10'
if args.dataset == "cifar10":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(train_set.train_data.mean(axis=(0,1,2))/255)
  print(train_set.train_data.std(axis=(0,1,2))/255)

  # imshow image
  train_data = train_set.train_data
  ind = 100
  img0 = train_data[ind,...]
  ## test channel number, in total , the correct channel is : RGB,not like BGR in caffe
  # error produce
  #b,g,r=cv2.split(img0)
  #img0=cv2.merge([r,g,b])

  print(img0.shape)
  print(type(img0))
  plt.imshow(img0)
  plt.show() # in ship in sea

  #img0 = cv2.resize(img0,(224,224))
  #cv2.imshow('img0',img0)
  #cv2.waitKey()

elif args.dataset == "cifar100":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(np.mean(train_set.train_data, axis=(0,1,2))/255)
  print(np.std(train_set.train_data, axis=(0,1,2))/255)

elif args.dataset == "mnist":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(list(train_set.train_data.size()))
  print(train_set.train_data.float().mean()/255)
  print(train_set.train_data.float().std()/255)

结果:

cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[ 0.49139968 0.48215841 0.44653091]
[ 0.24703223 0.24348513 0.26158784]
(32, 32, 3)
<class 'numpy.ndarray'>

使用matlab检测是如何计算mean_file和std_file的:

% load cifar10 dataset

data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));

temp = mean(train_data,1);
disp(size(temp));

train_data = double(train_data);

% compute mean_file 
mean_val = mean(mean(mean(train_data,1),2),3)/255;


% compute std_file 
temp1 = train_data(:,:,:,1);
std_val1 = std(temp1(:))/255;

temp2 = train_data(:,:,:,2);
std_val2 = std(temp2(:))/255;

temp3 = train_data(:,:,:,3);
std_val3 = std(temp3(:))/255;

mean_val = squeeze(mean_val);
std_val = [std_val1, std_val2, std_val3];

disp(mean_val);
disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]
%     std_val: [0.2470, 0.2435, 0.2616]

均值计算的过程也可以遵循标准差的计算过程。为 了简单,例如对于一个矩阵,所有元素的均值,等于两个方向上先后均值。所以会直接采用如下的形式:

mean_val = mean(mean(mean(train_data,1),2),3)/255;

标准差的计算是每一个通道的对所有样本的求标准差。然后再除以255。

以上这篇Pytorch的mean和std调查实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
处理Python中的URLError异常的方法
Apr 30 Python
整理Python 常用string函数(收藏)
May 30 Python
Python实现改变与矩形橡胶的线条的颜色代码示例
Jan 05 Python
pandas多级分组实现排序的方法
Apr 20 Python
python更改已存在excel文件的方法
May 03 Python
python dataframe常见操作方法:实现取行、列、切片、统计特征值
Jun 09 Python
Python面向对象之反射/自省机制实例分析
Aug 24 Python
Django REST framework内置路由用法
Jul 26 Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 Python
200行python代码实现贪吃蛇游戏
Apr 24 Python
在pytorch中动态调整优化器的学习率方式
Jun 24 Python
python中pow函数用法及功能说明
Dec 04 Python
pytorch 图像预处理之减去均值,除以方差的实例
Jan 02 #Python
Linux下升级安装python3.8并配置pip及yum的教程
Jan 02 #Python
pytorch实现focal loss的两种方式小结
Jan 02 #Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
You might like
PHP 和 XML: 使用expat函数(一)
2006/10/09 PHP
php生成数组的使用示例 php全组合算法
2014/01/16 PHP
微信公众号点击菜单即可打开并登录微站的实现方法
2014/11/14 PHP
深入浅析php json 格式控制
2015/12/24 PHP
微信开发之网页授权获取用户信息(二)
2016/01/08 PHP
php微信公众平台开发(三)订阅事件处理
2016/12/06 PHP
使用Javascript和DOM Interfaces来处理HTML
2006/10/09 Javascript
Extjs学习笔记之一 初识Extjs之MessageBox
2010/01/07 Javascript
javascript字符串替换及字符串分割示例代码
2013/12/12 Javascript
js用Date对象的setDate()函数对日期进行加减操作
2014/09/18 Javascript
jQuery实现彩带延伸效果的网页加载条loading动画
2015/10/29 Javascript
JS Array创建及concat()split()slice()的使用方法
2016/06/03 Javascript
Js 获取当前函数参数对象的实现代码
2016/06/20 Javascript
jQuery实现点击查看大图并以弹框的形式居中
2016/08/08 Javascript
JS判断form内所有表单是否为空的简单实例
2016/09/09 Javascript
Node.js的特点详解
2017/02/03 Javascript
Bootstrap导航条学习使用(一)
2017/02/08 Javascript
ThinkPHP+jquery实现“加载更多”功能代码
2017/03/11 Javascript
使用Nodejs连接mongodb数据库的实现代码
2017/08/21 NodeJs
VUE脚手架的下载和配置步骤详解
2019/04/01 Javascript
微信小程序自定义纯净模态框(弹出框)的实例代码
2020/03/09 Javascript
微信小程序实现多图上传
2020/06/19 Javascript
python binascii 进制转换实例
2019/06/12 Python
使用python将mysql数据库的数据转换为json数据的方法
2019/07/01 Python
python 基于UDP协议套接字通信的实现
2021/01/22 Python
手把手教你用Django执行原生SQL的方法
2021/02/18 Python
HTML5新增的表单元素和属性实例解析
2014/07/07 HTML / CSS
7 For All Mankind官网:美国加州洛杉矶的高级牛仔服装品牌
2018/12/20 全球购物
婚鞋、新娘鞋、礼服鞋、童鞋:Nina Shoes
2019/09/04 全球购物
实习单位接收函
2014/01/11 职场文书
协议书样本
2014/04/23 职场文书
文明单位汇报材料
2014/12/24 职场文书
导游词之河北白洋淀
2020/01/15 职场文书
Java中try catch处理异常示例
2021/12/06 Java/Android
排查Tomcat进程假死的问题
2022/05/06 Servers
一篇文章带你掌握SQLite3基本用法
2022/06/14 数据库