Pytorch的mean和std调查实例


Posted in Python onJanuary 02, 2020

如下所示:

# coding: utf-8

from __future__ import print_function
import copy
import click
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms

import matplotlib.pyplot as plt
import load_caffemodel
import scipy.io as sio

# if model has LSTM
# torch.backends.cudnn.enabled = False

imgpath = 'D:/ck/files_detected_face224/'   

imgname = 'S055_002_00000025.png' # anger
image_path = imgpath + imgname

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
raw_image = cv2.imread(image_path)[..., ::-1]
print(raw_image.shape)
raw_image = cv2.resize(raw_image, (224, ) * 2)
image = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(
    mean=mean_file,
    std =std_file,
    #mean = mean_file,
    #std = std_file,
  )
])(raw_image).unsqueeze(0)

print(image.shape)

convert_image1 = image.numpy()
convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W
convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))
convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C
print(convert_image1.shape)

convert_image1 = convert_image1 * 255

diff = raw_image - convert_image1
err = np.max(diff)
print(err)
plt.imshow(np.uint8(convert_image1))
plt.show()

结论:

input_image = (raw_image / 255 - mean) ./ std

下面调查均值文件和方差文件是如何生成的:

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
# coding: utf-8
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms

dataset_names = ('cifar10','cifar100','mnist')

parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,
          help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')

args = parser.parse_args()

data_dir = os.path.join('.', args.dataset)

print(args.dataset)
args.dataset = 'cifar10'
if args.dataset == "cifar10":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(train_set.train_data.mean(axis=(0,1,2))/255)
  print(train_set.train_data.std(axis=(0,1,2))/255)

  # imshow image
  train_data = train_set.train_data
  ind = 100
  img0 = train_data[ind,...]
  ## test channel number, in total , the correct channel is : RGB,not like BGR in caffe
  # error produce
  #b,g,r=cv2.split(img0)
  #img0=cv2.merge([r,g,b])

  print(img0.shape)
  print(type(img0))
  plt.imshow(img0)
  plt.show() # in ship in sea

  #img0 = cv2.resize(img0,(224,224))
  #cv2.imshow('img0',img0)
  #cv2.waitKey()

elif args.dataset == "cifar100":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(np.mean(train_set.train_data, axis=(0,1,2))/255)
  print(np.std(train_set.train_data, axis=(0,1,2))/255)

elif args.dataset == "mnist":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(list(train_set.train_data.size()))
  print(train_set.train_data.float().mean()/255)
  print(train_set.train_data.float().std()/255)

结果:

cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[ 0.49139968 0.48215841 0.44653091]
[ 0.24703223 0.24348513 0.26158784]
(32, 32, 3)
<class 'numpy.ndarray'>

使用matlab检测是如何计算mean_file和std_file的:

% load cifar10 dataset

data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));

temp = mean(train_data,1);
disp(size(temp));

train_data = double(train_data);

% compute mean_file 
mean_val = mean(mean(mean(train_data,1),2),3)/255;


% compute std_file 
temp1 = train_data(:,:,:,1);
std_val1 = std(temp1(:))/255;

temp2 = train_data(:,:,:,2);
std_val2 = std(temp2(:))/255;

temp3 = train_data(:,:,:,3);
std_val3 = std(temp3(:))/255;

mean_val = squeeze(mean_val);
std_val = [std_val1, std_val2, std_val3];

disp(mean_val);
disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]
%     std_val: [0.2470, 0.2435, 0.2616]

均值计算的过程也可以遵循标准差的计算过程。为 了简单,例如对于一个矩阵,所有元素的均值,等于两个方向上先后均值。所以会直接采用如下的形式:

mean_val = mean(mean(mean(train_data,1),2),3)/255;

标准差的计算是每一个通道的对所有样本的求标准差。然后再除以255。

以上这篇Pytorch的mean和std调查实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的简单窗口倒计时界面实例
May 05 Python
python下setuptools的安装详解及No module named setuptools的解决方法
Jul 06 Python
python可视化爬虫界面之天气查询
Jul 03 Python
python自带tkinter库实现棋盘覆盖图形界面
Jul 17 Python
Django的Modelforms用法简介
Jul 27 Python
Python实现队列的方法示例小结【数组,链表】
Feb 22 Python
PyQt5事件处理之定时在控件上显示信息的代码
Mar 25 Python
Django用户身份验证完成示例代码
Apr 03 Python
Python中zip函数如何使用
Jun 04 Python
pytorch中的weight-initilzation用法
Jun 24 Python
浅谈优化Django ORM中的性能问题
Jul 09 Python
python如何绘制疫情图
Sep 16 Python
pytorch 图像预处理之减去均值,除以方差的实例
Jan 02 #Python
Linux下升级安装python3.8并配置pip及yum的教程
Jan 02 #Python
pytorch实现focal loss的两种方式小结
Jan 02 #Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
You might like
php数组函数序列之array_intersect() 返回两个或多个数组的交集数组
2011/11/10 PHP
PHP数字和字符串ID互转函数(类似优酷ID)
2014/06/30 PHP
destoon实现首页显示供应、企业、资讯条数的方法
2014/07/15 PHP
PHP中魔术变量__METHOD__与__FUNCTION__的区别
2014/09/29 PHP
JQuery给元素绑定click事件多次执行的解决方法
2014/05/29 Javascript
JavaScript实现的一个计算数字步数的算法分享
2014/12/06 Javascript
js实时获取窗口大小变化的实例代码
2016/11/18 Javascript
jQuery事件绑定方法学习总结(推荐)
2016/11/21 Javascript
详解vue跨组件通信的几种方法
2017/06/15 Javascript
微信小程序canvas实现刮刮乐效果
2018/07/09 Javascript
简述JS控制台的使用
2018/07/15 Javascript
jQuery实现的响应鼠标移动方向插件用法示例【附源码下载】
2018/08/28 jQuery
vue+axios实现文件下载及vue中使用axios的实例
2018/09/21 Javascript
Element-ui DatePicker显示周数的方法示例
2019/07/19 Javascript
Jquery Datatables的使用详解
2020/01/30 jQuery
JavaScript将数组转换为链表的方法
2020/02/16 Javascript
OpenLayers3加载常用控件使用方法详解
2020/09/25 Javascript
[01:24]2014DOTA2 TI第二日 YYF表示这届谁赢都有可能
2014/07/11 DOTA
python3+PyQt5实现文档打印功能
2018/04/24 Python
python环形单链表的约瑟夫问题详解
2018/09/27 Python
Opencv+Python 色彩通道拆分及合并的示例
2018/12/08 Python
Python HTML解析模块HTMLParser用法分析【爬虫工具】
2019/04/05 Python
使用PYTHON解析Wireshark的PCAP文件方法
2019/07/23 Python
wxPython实现绘图小例子
2019/11/19 Python
python中导入 train_test_split提示错误的解决
2020/06/19 Python
SmartBuyGlasses比利时:购买品牌太阳镜和眼镜
2019/08/09 全球购物
施华洛世奇新加坡官网:SWAROVSKI新加坡
2020/10/06 全球购物
列车长先进事迹材料
2014/01/25 职场文书
加拿大留学自荐信
2014/01/28 职场文书
大家访活动实施方案
2014/03/10 职场文书
学习党的群众路线实践活动思想汇报
2014/09/12 职场文书
贫困证明模板(3篇)
2014/09/16 职场文书
学习朴航瑛老师爱岗敬业先进事迹思想汇报
2014/09/17 职场文书
防卫过当辩护词
2015/05/21 职场文书
2015年“我们的节日·中秋节”活动总结
2015/07/30 职场文书
2016护理专业求职自荐书
2016/01/28 职场文书