Pytorch的mean和std调查实例


Posted in Python onJanuary 02, 2020

如下所示:

# coding: utf-8

from __future__ import print_function
import copy
import click
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms

import matplotlib.pyplot as plt
import load_caffemodel
import scipy.io as sio

# if model has LSTM
# torch.backends.cudnn.enabled = False

imgpath = 'D:/ck/files_detected_face224/'   

imgname = 'S055_002_00000025.png' # anger
image_path = imgpath + imgname

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
raw_image = cv2.imread(image_path)[..., ::-1]
print(raw_image.shape)
raw_image = cv2.resize(raw_image, (224, ) * 2)
image = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(
    mean=mean_file,
    std =std_file,
    #mean = mean_file,
    #std = std_file,
  )
])(raw_image).unsqueeze(0)

print(image.shape)

convert_image1 = image.numpy()
convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W
convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))
convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C
print(convert_image1.shape)

convert_image1 = convert_image1 * 255

diff = raw_image - convert_image1
err = np.max(diff)
print(err)
plt.imshow(np.uint8(convert_image1))
plt.show()

结论:

input_image = (raw_image / 255 - mean) ./ std

下面调查均值文件和方差文件是如何生成的:

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
# coding: utf-8
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms

dataset_names = ('cifar10','cifar100','mnist')

parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,
          help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')

args = parser.parse_args()

data_dir = os.path.join('.', args.dataset)

print(args.dataset)
args.dataset = 'cifar10'
if args.dataset == "cifar10":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(train_set.train_data.mean(axis=(0,1,2))/255)
  print(train_set.train_data.std(axis=(0,1,2))/255)

  # imshow image
  train_data = train_set.train_data
  ind = 100
  img0 = train_data[ind,...]
  ## test channel number, in total , the correct channel is : RGB,not like BGR in caffe
  # error produce
  #b,g,r=cv2.split(img0)
  #img0=cv2.merge([r,g,b])

  print(img0.shape)
  print(type(img0))
  plt.imshow(img0)
  plt.show() # in ship in sea

  #img0 = cv2.resize(img0,(224,224))
  #cv2.imshow('img0',img0)
  #cv2.waitKey()

elif args.dataset == "cifar100":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(np.mean(train_set.train_data, axis=(0,1,2))/255)
  print(np.std(train_set.train_data, axis=(0,1,2))/255)

elif args.dataset == "mnist":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(list(train_set.train_data.size()))
  print(train_set.train_data.float().mean()/255)
  print(train_set.train_data.float().std()/255)

结果:

cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[ 0.49139968 0.48215841 0.44653091]
[ 0.24703223 0.24348513 0.26158784]
(32, 32, 3)
<class 'numpy.ndarray'>

使用matlab检测是如何计算mean_file和std_file的:

% load cifar10 dataset

data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));

temp = mean(train_data,1);
disp(size(temp));

train_data = double(train_data);

% compute mean_file 
mean_val = mean(mean(mean(train_data,1),2),3)/255;


% compute std_file 
temp1 = train_data(:,:,:,1);
std_val1 = std(temp1(:))/255;

temp2 = train_data(:,:,:,2);
std_val2 = std(temp2(:))/255;

temp3 = train_data(:,:,:,3);
std_val3 = std(temp3(:))/255;

mean_val = squeeze(mean_val);
std_val = [std_val1, std_val2, std_val3];

disp(mean_val);
disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]
%     std_val: [0.2470, 0.2435, 0.2616]

均值计算的过程也可以遵循标准差的计算过程。为 了简单,例如对于一个矩阵,所有元素的均值,等于两个方向上先后均值。所以会直接采用如下的形式:

mean_val = mean(mean(mean(train_data,1),2),3)/255;

标准差的计算是每一个通道的对所有样本的求标准差。然后再除以255。

以上这篇Pytorch的mean和std调查实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
零基础写python爬虫之抓取百度贴吧代码分享
Nov 06 Python
python&amp;MongoDB爬取图书馆借阅记录
Feb 05 Python
基于python 字符编码的理解
Sep 02 Python
Python使用正则表达式过滤或替换HTML标签的方法详解
Sep 25 Python
Python操作Sql Server 2008数据库的方法详解
May 17 Python
Python使用pymongo库操作MongoDB数据库的方法实例
Feb 22 Python
详解pandas.DataFrame中删除包涵特定字符串所在的行
Apr 04 Python
python集合删除多种方法详解
Feb 10 Python
From CSV to SQLite3 by python 导入csv到sqlite实例
Feb 14 Python
python实现吃苹果小游戏
Mar 21 Python
django配置app中的静态文件步骤
Mar 27 Python
Pandas实现DataFrame的简单运算、统计与排序
Mar 31 Python
pytorch 图像预处理之减去均值,除以方差的实例
Jan 02 #Python
Linux下升级安装python3.8并配置pip及yum的教程
Jan 02 #Python
pytorch实现focal loss的两种方式小结
Jan 02 #Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
You might like
php抓取网站图片并保存的实现方法
2015/10/29 PHP
Laravel框架数据库迁移操作实例详解
2020/04/06 PHP
JQuery跨Iframe选择实现代码
2010/08/19 Javascript
JQuery自适应IFrame高度(支持嵌套 兼容IE,ff,safafi,chrome)
2011/03/28 Javascript
加载 Javascript 最佳实践
2011/10/30 Javascript
Js 导出table内容到Excel的简单实例
2013/11/19 Javascript
jQuery表格插件datatables用法总结
2014/09/05 Javascript
最新最热最实用的15个jQuery插件汇总
2015/07/05 Javascript
谈谈JavaScript中function多重理解
2015/08/28 Javascript
18个非常棒的jQuery代码片段
2015/11/02 Javascript
判断数组是否包含某个元素的js函数实现方法
2016/05/19 Javascript
js与jquery正则验证电子邮箱、手机号、邮政编码的方法
2016/07/04 Javascript
jQuery模拟淘宝购物车功能
2017/02/27 Javascript
vue webuploader 文件上传组件开发
2017/09/23 Javascript
移动前端图片压缩上传的实例
2017/12/06 Javascript
nodejs取得当前执行路径的方法
2018/05/13 NodeJs
深入理解JS异步编程-Promise
2019/06/03 Javascript
深入理解javascript prototype的相关知识
2019/09/19 Javascript
Python中__new__与__init__方法的区别详解
2015/05/04 Python
python中利用Future对象异步返回结果示例代码
2017/09/07 Python
Python生成器定义与简单用法实例分析
2018/04/30 Python
如何用python写一个简单的词法分析器
2018/12/18 Python
django+echart数据动态显示的例子
2019/08/12 Python
Pytorch的mean和std调查实例
2020/01/02 Python
Python实现网络聊天室的示例代码(支持多人聊天与私聊)
2021/01/27 Python
HTML5中使用postMessage实现两个网页间传递数据
2016/06/22 HTML / CSS
Hotter Shoes英国官网:英伦风格,舒适的鞋子
2017/12/28 全球购物
英国在线电子和小工具商店:TecoBuy
2018/10/06 全球购物
土木工程专业自荐信
2013/10/04 职场文书
我的小天地教学反思
2014/04/30 职场文书
咖啡店创业计划书
2014/08/15 职场文书
2015年大学生实习评语
2015/03/25 职场文书
项目合作意向书
2015/05/08 职场文书
学校安全管理制度
2015/08/06 职场文书
MySQL 自定义变量的概念及特点
2021/05/13 MySQL
Vue鼠标滚轮滚动切换路由效果的实现方法
2021/08/04 Vue.js