Pytorch的mean和std调查实例


Posted in Python onJanuary 02, 2020

如下所示:

# coding: utf-8

from __future__ import print_function
import copy
import click
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms

import matplotlib.pyplot as plt
import load_caffemodel
import scipy.io as sio

# if model has LSTM
# torch.backends.cudnn.enabled = False

imgpath = 'D:/ck/files_detected_face224/'   

imgname = 'S055_002_00000025.png' # anger
image_path = imgpath + imgname

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
raw_image = cv2.imread(image_path)[..., ::-1]
print(raw_image.shape)
raw_image = cv2.resize(raw_image, (224, ) * 2)
image = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(
    mean=mean_file,
    std =std_file,
    #mean = mean_file,
    #std = std_file,
  )
])(raw_image).unsqueeze(0)

print(image.shape)

convert_image1 = image.numpy()
convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W
convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))
convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C
print(convert_image1.shape)

convert_image1 = convert_image1 * 255

diff = raw_image - convert_image1
err = np.max(diff)
print(err)
plt.imshow(np.uint8(convert_image1))
plt.show()

结论:

input_image = (raw_image / 255 - mean) ./ std

下面调查均值文件和方差文件是如何生成的:

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
# coding: utf-8
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms

dataset_names = ('cifar10','cifar100','mnist')

parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,
          help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')

args = parser.parse_args()

data_dir = os.path.join('.', args.dataset)

print(args.dataset)
args.dataset = 'cifar10'
if args.dataset == "cifar10":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(train_set.train_data.mean(axis=(0,1,2))/255)
  print(train_set.train_data.std(axis=(0,1,2))/255)

  # imshow image
  train_data = train_set.train_data
  ind = 100
  img0 = train_data[ind,...]
  ## test channel number, in total , the correct channel is : RGB,not like BGR in caffe
  # error produce
  #b,g,r=cv2.split(img0)
  #img0=cv2.merge([r,g,b])

  print(img0.shape)
  print(type(img0))
  plt.imshow(img0)
  plt.show() # in ship in sea

  #img0 = cv2.resize(img0,(224,224))
  #cv2.imshow('img0',img0)
  #cv2.waitKey()

elif args.dataset == "cifar100":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(np.mean(train_set.train_data, axis=(0,1,2))/255)
  print(np.std(train_set.train_data, axis=(0,1,2))/255)

elif args.dataset == "mnist":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(list(train_set.train_data.size()))
  print(train_set.train_data.float().mean()/255)
  print(train_set.train_data.float().std()/255)

结果:

cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[ 0.49139968 0.48215841 0.44653091]
[ 0.24703223 0.24348513 0.26158784]
(32, 32, 3)
<class 'numpy.ndarray'>

使用matlab检测是如何计算mean_file和std_file的:

% load cifar10 dataset

data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));

temp = mean(train_data,1);
disp(size(temp));

train_data = double(train_data);

% compute mean_file 
mean_val = mean(mean(mean(train_data,1),2),3)/255;


% compute std_file 
temp1 = train_data(:,:,:,1);
std_val1 = std(temp1(:))/255;

temp2 = train_data(:,:,:,2);
std_val2 = std(temp2(:))/255;

temp3 = train_data(:,:,:,3);
std_val3 = std(temp3(:))/255;

mean_val = squeeze(mean_val);
std_val = [std_val1, std_val2, std_val3];

disp(mean_val);
disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]
%     std_val: [0.2470, 0.2435, 0.2616]

均值计算的过程也可以遵循标准差的计算过程。为 了简单,例如对于一个矩阵,所有元素的均值,等于两个方向上先后均值。所以会直接采用如下的形式:

mean_val = mean(mean(mean(train_data,1),2),3)/255;

标准差的计算是每一个通道的对所有样本的求标准差。然后再除以255。

以上这篇Pytorch的mean和std调查实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的中国剩余定理算法示例
Aug 05 Python
Pandas探索之高性能函数eval和query解析
Oct 28 Python
python-opencv颜色提取分割方法
Dec 08 Python
使用APScheduler3.0.1 实现定时任务的方法
Jul 22 Python
Python定义函数时参数有默认值问题解决
Dec 19 Python
Python实现银行账户资金交易管理系统
Jan 03 Python
40行Python代码实现天气预报和每日鸡汤推送功能
Feb 27 Python
Python如何使用bokeh包和geojson数据绘制地图
Mar 21 Python
python访问hdfs的操作
Jun 06 Python
Python偏函数Partial function使用方法实例详解
Jun 17 Python
详解python datetime模块
Aug 17 Python
使用Python获取爱奇艺电视剧弹幕数据的示例代码
Jan 12 Python
pytorch 图像预处理之减去均值,除以方差的实例
Jan 02 #Python
Linux下升级安装python3.8并配置pip及yum的教程
Jan 02 #Python
pytorch实现focal loss的两种方式小结
Jan 02 #Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
You might like
模拟xcopy的函数
2006/10/09 PHP
php 攻击方法之谈php+mysql注射语句构造
2009/10/30 PHP
ThinkPHP模板判断输出Present标签用法详解
2014/06/30 PHP
php注册系统和使用Xajax即时验证用户名是否被占用
2017/08/31 PHP
PHP面向对象程序设计重载(overloading)操作详解
2019/06/13 PHP
用JavaScrpt实现文件夹简单轻松加密的实现方法图文
2008/09/08 Javascript
Javascript无阻塞加载具体方式
2013/06/28 Javascript
javaScript中两个等于号和三个等于号之间的区别介绍
2014/06/27 Javascript
javascript在当前窗口关闭前检测窗口是否关闭
2014/09/29 Javascript
Eclipse配置Javascript开发环境图文教程
2015/01/29 Javascript
Vue.js第四天学习笔记
2016/12/02 Javascript
JQuery.validationEngine表单验证插件(推荐)
2016/12/10 Javascript
Angular使用$http.jsonp发送跨站请求的方法
2017/03/16 Javascript
使用ef6创建oracle数据库的实体模型遇到的问题及解决方案
2017/11/09 Javascript
p5.js入门教程之图片加载
2018/03/20 Javascript
详解VS Code使用之Vue工程配置format代码格式化
2019/03/20 Javascript
微信小程序图片加载失败时替换为默认图片的方法
2019/12/09 Javascript
详解微信小程序中var、let、const用法与区别
2020/01/11 Javascript
如何在JavaScript中正确处理变量
2020/12/25 Javascript
[01:01:52]DOTA2-DPC中国联赛正赛 iG vs LBZS BO3 第一场 3月4日
2021/03/11 DOTA
Python中声明只包含一个元素的元组数据方法
2014/08/25 Python
最大K个数问题的Python版解法总结
2016/06/16 Python
详解Python中heapq模块的用法
2016/06/28 Python
python微信撤回监测代码
2019/04/29 Python
Python字典深浅拷贝与循环方式方法详解
2020/02/09 Python
HTML5学习笔记之html5与传统html区别
2016/01/06 HTML / CSS
利用Storage Event实现页面间通信的示例代码
2018/07/26 HTML / CSS
德国狗狗用品在线商店:Schecker
2017/03/17 全球购物
无谷物狗粮:Pooch & Mutt
2018/05/23 全球购物
海南地接欢迎词
2014/01/14 职场文书
电气自动化求职信
2014/06/24 职场文书
法学院毕业生求职信
2014/06/25 职场文书
家庭教育的心得体会
2014/09/01 职场文书
竞选学习委员演讲稿
2014/09/01 职场文书
2014年食品安全工作总结
2014/12/04 职场文书
css3 文字断裂效果
2022/04/22 HTML / CSS