Pytorch的mean和std调查实例


Posted in Python onJanuary 02, 2020

如下所示:

# coding: utf-8

from __future__ import print_function
import copy
import click
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms

import matplotlib.pyplot as plt
import load_caffemodel
import scipy.io as sio

# if model has LSTM
# torch.backends.cudnn.enabled = False

imgpath = 'D:/ck/files_detected_face224/'   

imgname = 'S055_002_00000025.png' # anger
image_path = imgpath + imgname

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
raw_image = cv2.imread(image_path)[..., ::-1]
print(raw_image.shape)
raw_image = cv2.resize(raw_image, (224, ) * 2)
image = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(
    mean=mean_file,
    std =std_file,
    #mean = mean_file,
    #std = std_file,
  )
])(raw_image).unsqueeze(0)

print(image.shape)

convert_image1 = image.numpy()
convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W
convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))
convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C
print(convert_image1.shape)

convert_image1 = convert_image1 * 255

diff = raw_image - convert_image1
err = np.max(diff)
print(err)
plt.imshow(np.uint8(convert_image1))
plt.show()

结论:

input_image = (raw_image / 255 - mean) ./ std

下面调查均值文件和方差文件是如何生成的:

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
# coding: utf-8
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms

dataset_names = ('cifar10','cifar100','mnist')

parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,
          help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')

args = parser.parse_args()

data_dir = os.path.join('.', args.dataset)

print(args.dataset)
args.dataset = 'cifar10'
if args.dataset == "cifar10":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(train_set.train_data.mean(axis=(0,1,2))/255)
  print(train_set.train_data.std(axis=(0,1,2))/255)

  # imshow image
  train_data = train_set.train_data
  ind = 100
  img0 = train_data[ind,...]
  ## test channel number, in total , the correct channel is : RGB,not like BGR in caffe
  # error produce
  #b,g,r=cv2.split(img0)
  #img0=cv2.merge([r,g,b])

  print(img0.shape)
  print(type(img0))
  plt.imshow(img0)
  plt.show() # in ship in sea

  #img0 = cv2.resize(img0,(224,224))
  #cv2.imshow('img0',img0)
  #cv2.waitKey()

elif args.dataset == "cifar100":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(np.mean(train_set.train_data, axis=(0,1,2))/255)
  print(np.std(train_set.train_data, axis=(0,1,2))/255)

elif args.dataset == "mnist":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(list(train_set.train_data.size()))
  print(train_set.train_data.float().mean()/255)
  print(train_set.train_data.float().std()/255)

结果:

cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[ 0.49139968 0.48215841 0.44653091]
[ 0.24703223 0.24348513 0.26158784]
(32, 32, 3)
<class 'numpy.ndarray'>

使用matlab检测是如何计算mean_file和std_file的:

% load cifar10 dataset

data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));

temp = mean(train_data,1);
disp(size(temp));

train_data = double(train_data);

% compute mean_file 
mean_val = mean(mean(mean(train_data,1),2),3)/255;


% compute std_file 
temp1 = train_data(:,:,:,1);
std_val1 = std(temp1(:))/255;

temp2 = train_data(:,:,:,2);
std_val2 = std(temp2(:))/255;

temp3 = train_data(:,:,:,3);
std_val3 = std(temp3(:))/255;

mean_val = squeeze(mean_val);
std_val = [std_val1, std_val2, std_val3];

disp(mean_val);
disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]
%     std_val: [0.2470, 0.2435, 0.2616]

均值计算的过程也可以遵循标准差的计算过程。为 了简单,例如对于一个矩阵,所有元素的均值,等于两个方向上先后均值。所以会直接采用如下的形式:

mean_val = mean(mean(mean(train_data,1),2),3)/255;

标准差的计算是每一个通道的对所有样本的求标准差。然后再除以255。

以上这篇Pytorch的mean和std调查实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
PHP webshell检查工具 python实现代码
Sep 15 Python
python中查找excel某一列的重复数据 剔除之后打印
Feb 10 Python
Python编程中使用Pillow来处理图像的基础教程
Nov 20 Python
Python常用算法学习基础教程
Apr 13 Python
django实现同一个ip十分钟内只能注册一次的实例
Nov 03 Python
Python3结合Dlib实现人脸识别和剪切
Jan 24 Python
idea创建springMVC框架和配置小文件的教程图解
Sep 18 Python
python找出一个列表中相同元素的多个索引实例
Jun 11 Python
python爬虫开发之使用python爬虫库requests,urllib与今日头条搜索功能爬取搜索内容实例
Mar 10 Python
Python在字符串中处理html和xml的方法
Jul 31 Python
Python批量删除mysql中千万级大量数据的脚本分享
Dec 03 Python
如何用Python徒手写线性回归
Jan 25 Python
pytorch 图像预处理之减去均值,除以方差的实例
Jan 02 #Python
Linux下升级安装python3.8并配置pip及yum的教程
Jan 02 #Python
pytorch实现focal loss的两种方式小结
Jan 02 #Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
You might like
php查看session内容的函数
2008/08/27 PHP
PhpMyAdmin中无法导入sql文件的解决办法
2010/01/08 PHP
php中可能用来加密字符串的函数[base64_encode、urlencode、sha1]
2012/01/16 PHP
优化PHP代码技巧的小结
2013/06/02 PHP
php设计模式之单例、多例设计模式的应用分析
2013/06/30 PHP
PHP入门经历和学习过程分享
2014/04/11 PHP
jQuery bxCarousel实现图片滚动切换效果示例代码
2013/05/15 Javascript
Node.js中使用mongoskin操作mongoDB实例
2014/09/28 Javascript
BOOTSTRAP时间控件显示在模态框下面的bug修复
2015/02/05 Javascript
JavaScript实现跨浏览器的添加及删除事件绑定函数实例
2015/08/04 Javascript
ES6中Math对象的部分扩展
2017/02/20 Javascript
Bootstrap表单简单实现代码
2017/03/06 Javascript
JS基于正则表达式的替换操作(replace)用法示例
2017/04/28 Javascript
详解刷新页面vuex数据不消失和不跳转页面的解决
2018/01/30 Javascript
深入理解Vue 组件之间传值
2018/08/16 Javascript
Vue实现数据请求拦截
2019/10/23 Javascript
在Django中创建动态视图的教程
2015/07/15 Python
Pycharm远程调试openstack的方法
2017/11/21 Python
通过Py2exe将自己的python程序打包成.exe/.app的方法
2018/05/26 Python
Python实现自定义函数的5种常见形式分析
2018/06/16 Python
python绘制直线的方法
2018/06/30 Python
Python寻找路径和查找文件路径的示例
2019/07/10 Python
django框架模板语言使用方法详解
2019/07/18 Python
Python SQLAlchemy入门教程(基本用法)
2019/11/11 Python
windows、linux下打包Python3程序详细方法
2020/03/17 Python
Python3.9.0 a1安装pygame出错解决全过程(小结)
2021/02/02 Python
html5 冒号分隔符对齐的实现
2019/07/31 HTML / CSS
李维斯德国官方网上商店:Levi’s德国
2016/09/10 全球购物
英国领先的汽车轮胎和快速健康中心:Kwik Fit
2017/10/29 全球购物
捷克浴室和厨房设备购物网站:SIKO
2018/08/11 全球购物
英文简历中的自荐信范文
2013/12/14 职场文书
人力资源部培训专员岗位职责
2014/01/02 职场文书
2014政务公开实施方案
2014/02/19 职场文书
成龙霸王洗发水广告词
2014/03/14 职场文书
三月法制宣传月活动总结
2014/07/03 职场文书
投标承诺函格式
2015/01/21 职场文书