Pytorch的mean和std调查实例


Posted in Python onJanuary 02, 2020

如下所示:

# coding: utf-8

from __future__ import print_function
import copy
import click
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms

import matplotlib.pyplot as plt
import load_caffemodel
import scipy.io as sio

# if model has LSTM
# torch.backends.cudnn.enabled = False

imgpath = 'D:/ck/files_detected_face224/'   

imgname = 'S055_002_00000025.png' # anger
image_path = imgpath + imgname

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
raw_image = cv2.imread(image_path)[..., ::-1]
print(raw_image.shape)
raw_image = cv2.resize(raw_image, (224, ) * 2)
image = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(
    mean=mean_file,
    std =std_file,
    #mean = mean_file,
    #std = std_file,
  )
])(raw_image).unsqueeze(0)

print(image.shape)

convert_image1 = image.numpy()
convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W
convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))
convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C
print(convert_image1.shape)

convert_image1 = convert_image1 * 255

diff = raw_image - convert_image1
err = np.max(diff)
print(err)
plt.imshow(np.uint8(convert_image1))
plt.show()

结论:

input_image = (raw_image / 255 - mean) ./ std

下面调查均值文件和方差文件是如何生成的:

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
# coding: utf-8
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms

dataset_names = ('cifar10','cifar100','mnist')

parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,
          help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')

args = parser.parse_args()

data_dir = os.path.join('.', args.dataset)

print(args.dataset)
args.dataset = 'cifar10'
if args.dataset == "cifar10":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(train_set.train_data.mean(axis=(0,1,2))/255)
  print(train_set.train_data.std(axis=(0,1,2))/255)

  # imshow image
  train_data = train_set.train_data
  ind = 100
  img0 = train_data[ind,...]
  ## test channel number, in total , the correct channel is : RGB,not like BGR in caffe
  # error produce
  #b,g,r=cv2.split(img0)
  #img0=cv2.merge([r,g,b])

  print(img0.shape)
  print(type(img0))
  plt.imshow(img0)
  plt.show() # in ship in sea

  #img0 = cv2.resize(img0,(224,224))
  #cv2.imshow('img0',img0)
  #cv2.waitKey()

elif args.dataset == "cifar100":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(np.mean(train_set.train_data, axis=(0,1,2))/255)
  print(np.std(train_set.train_data, axis=(0,1,2))/255)

elif args.dataset == "mnist":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(list(train_set.train_data.size()))
  print(train_set.train_data.float().mean()/255)
  print(train_set.train_data.float().std()/255)

结果:

cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[ 0.49139968 0.48215841 0.44653091]
[ 0.24703223 0.24348513 0.26158784]
(32, 32, 3)
<class 'numpy.ndarray'>

使用matlab检测是如何计算mean_file和std_file的:

% load cifar10 dataset

data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));

temp = mean(train_data,1);
disp(size(temp));

train_data = double(train_data);

% compute mean_file 
mean_val = mean(mean(mean(train_data,1),2),3)/255;


% compute std_file 
temp1 = train_data(:,:,:,1);
std_val1 = std(temp1(:))/255;

temp2 = train_data(:,:,:,2);
std_val2 = std(temp2(:))/255;

temp3 = train_data(:,:,:,3);
std_val3 = std(temp3(:))/255;

mean_val = squeeze(mean_val);
std_val = [std_val1, std_val2, std_val3];

disp(mean_val);
disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]
%     std_val: [0.2470, 0.2435, 0.2616]

均值计算的过程也可以遵循标准差的计算过程。为 了简单,例如对于一个矩阵,所有元素的均值,等于两个方向上先后均值。所以会直接采用如下的形式:

mean_val = mean(mean(mean(train_data,1),2),3)/255;

标准差的计算是每一个通道的对所有样本的求标准差。然后再除以255。

以上这篇Pytorch的mean和std调查实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python遍历目录中的所有文件的方法
Jul 08 Python
Python 窗体(tkinter)按钮 位置实例
Jun 13 Python
python使用 zip 同时迭代多个序列示例
Jul 06 Python
Python学习笔记之While循环用法分析
Aug 14 Python
python 爬取古诗文存入mysql数据库的方法
Jan 08 Python
Python3打包exe代码2种方法实例解析
Feb 17 Python
Python importlib动态导入模块实现代码
Apr 16 Python
聊聊python中的异常嵌套
Sep 01 Python
浅析python函数式编程
Sep 26 Python
python将图片转为矢量图的方法步骤
Mar 30 Python
Python torch.flatten()函数案例详解
Aug 30 Python
python和C/C++混合编程之使用ctypes调用 C/C++的dll
Apr 29 Python
pytorch 图像预处理之减去均值,除以方差的实例
Jan 02 #Python
Linux下升级安装python3.8并配置pip及yum的教程
Jan 02 #Python
pytorch实现focal loss的两种方式小结
Jan 02 #Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
You might like
我的论坛源代码(三)
2006/10/09 PHP
php mssql 分页SQL语句优化 持续影响
2009/04/26 PHP
UCenter Home二次开发指南
2009/05/28 PHP
PHP设计模式 注册表模式(多个类的注册)
2012/02/05 PHP
基于PHP创建Cookie数组的详解
2013/07/03 PHP
PHPMailer的主要功能特点和简单使用说明
2014/02/17 PHP
浅析ThinkPHP的模板输出功能
2014/07/01 PHP
图文介绍PHP添加Redis模块及连接
2015/07/28 PHP
ThinkPHP中Widget扩展的两种写法及调用方法详解
2017/05/04 PHP
PHP实现一个限制实例化次数的类示例
2019/09/16 PHP
PHP call_user_func和call_user_func_array函数的简单理解与应用分析
2019/11/25 PHP
ExtJS 2.0 GridPanel基本表格简明教程
2010/05/25 Javascript
javascript中call和apply方法浅谈
2013/09/27 Javascript
jQuery封装的获取Url中的Get参数示例
2013/11/26 Javascript
给事件响应函数传参数的四种方式小结
2013/12/05 Javascript
jQuery Migrate 1.1.0 Released 注意事项
2014/06/14 Javascript
基于jquery实现的鼠标悬停提示案例
2016/12/11 Javascript
JS实现unicode和UTF-8之间的互相转换互转
2017/07/05 Javascript
基于js中的存储键值对以及注意事项介绍
2018/03/30 Javascript
解决layui追加或者动态修改的表单元素“没效果”的问题
2019/09/18 Javascript
基于Vue全局组件与局部组件的区别说明
2020/08/11 Javascript
简单了解Python生成器是什么
2019/07/02 Python
在django中实现choices字段获取对应字段值
2020/07/12 Python
Python编写单元测试代码实例
2020/09/10 Python
html5跳转小程序wx-open-launch-weapp踩坑
2020/12/02 HTML / CSS
美国修容界大佬创建的个人美妆品牌:Kevyn Aucoin Beauty
2018/12/12 全球购物
英国领先的维生素和营养补充剂直接供应商:Healthspan
2019/04/22 全球购物
常用UNIX 命令(Linux的常用命令)
2015/12/26 面试题
这76道Java面试题及答案,祝你能成功通过面试
2016/04/16 面试题
企业为何需要商业计划书
2013/12/26 职场文书
竞选班长自荐书范文
2014/03/09 职场文书
大学生全国两会报告感想
2014/03/17 职场文书
白岩松演讲
2014/05/21 职场文书
2015年办公室个人工作总结
2015/04/20 职场文书
股权投资协议书
2016/03/23 职场文书
使用scrapy实现增量式爬取方式
2022/06/21 Python