利用pytorch实现对CIFAR-10数据集的分类


Posted in Python onJanuary 14, 2020

步骤如下:

1.使用torchvision加载并预处理CIFAR-10数据集、

2.定义网络

3.定义损失函数和优化器

4.训练网络并更新网络参数

5.测试网络

运行环境:

windows+python3.6.3+pycharm+pytorch0.3.0

import torchvision as tv
import torchvision.transforms as transforms
import torch as t
from torchvision.transforms import ToPILImage
show=ToPILImage()    #把Tensor转成Image,方便可视化
import matplotlib.pyplot as plt
import torchvision
import numpy as np


###############数据加载与预处理
transform = transforms.Compose([transforms.ToTensor(),#转为tensor
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),#归一化
                ])
#训练集
trainset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=True,
               download=True,
               transform=transform)

trainloader=t.utils.data.DataLoader(trainset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)
#测试集
testset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=False,
               download=True,
               transform=transform)

testloader=t.utils.data.DataLoader(testset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)


classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

(data,label)=trainset[100]
print(classes[label])

show((data+1)/2).resize((100,100))

# dataiter=iter(trainloader)
# images,labels=dataiter.next()
# print(''.join('11%s'%classes[labels[j]] for j in range(4)))
# show(tv.utils.make_grid(images+1)/2).resize((400,100))
def imshow(img):
  img = img / 2 + 0.5
  npimg = img.numpy()
  plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(trainloader)
images, labels = dataiter.next()
print(images.size())
imshow(torchvision.utils.make_grid(images))
plt.show()#关掉图片才能往后继续算


#########################定义网络
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1=nn.Conv2d(3,6,5)
    self.conv2=nn.Conv2d(6,16,5)
    self.fc1=nn.Linear(16*5*5,120)
    self.fc2=nn.Linear(120,84)
    self.fc3=nn.Linear(84,10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv1(x)),2)
    x = F.max_pool2d(F.relu(self.conv2(x)),2)
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

net=Net()
print(net)

#############定义损失函数和优化器
from torch import optim
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

##############训练网络
from torch.autograd import Variable
import time

start_time = time.time()
for epoch in range(2):
  running_loss=0.0
  for i,data in enumerate(trainloader,0):
    #输入数据
    inputs,labels=data
    inputs,labels=Variable(inputs),Variable(labels)
    #梯度清零
    optimizer.zero_grad()

    outputs=net(inputs)
    loss=criterion(outputs,labels)
    loss.backward()
    #更新参数
    optimizer.step()

    # 打印log
    running_loss += loss.data[0]
    if i % 2000 == 1999:
      print('[%d,%5d] loss:%.3f' % (epoch + 1, i + 1, running_loss / 2000))
      running_loss = 0.0
print('finished training')
end_time = time.time()
print("Spend time:", end_time - start_time)

以上这篇利用pytorch实现对CIFAR-10数据集的分类就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python ip正则式
May 07 Python
python实现删除文件与目录的方法
Nov 10 Python
Python批量创建迅雷任务及创建多个文件
Feb 13 Python
Python中遇到的小问题及解决方法汇总
Jan 11 Python
对python中return和print的一些理解
Aug 18 Python
深入学习Python中的上下文管理器与else块
Aug 27 Python
Python logging管理不同级别log打印和存储实例
Jan 19 Python
pyqt5简介及安装方法介绍
Jan 31 Python
在python中bool函数的取值方法
Nov 01 Python
五分钟带你搞懂python 迭代器与生成器
Aug 30 Python
基于Python实现将列表数据生成折线图
Mar 23 Python
如何用六步教会你使用python爬虫爬取数据
Apr 06 Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
PyTorch实现ResNet50、ResNet101和ResNet152示例
Jan 14 #Python
You might like
php判断当前操作系统类型
2015/10/28 PHP
php实现统计目录文件大小的函数
2015/12/25 PHP
PHP读取文件内容的五种方式
2015/12/28 PHP
php文件上传类完整实例
2016/05/14 PHP
popdiv
2006/07/14 Javascript
利用JQuery的load函数动态加载其它页面的内容的实现代码
2010/12/14 Javascript
JavaScript中获取样式的原生方法小结
2014/10/08 Javascript
node-webkit打包成exe文件被360误报木马的解决方法
2015/03/11 Javascript
jquery实现的淡入淡出下拉菜单效果
2015/08/25 Javascript
jQuery 中的 DOM 操作
2016/04/26 Javascript
Angular.JS学习之依赖注入$injector详析
2016/10/20 Javascript
Mac下通过brew安装指定版本的nodejs教程
2018/05/17 NodeJs
原生JS实现的放大镜特效示例【测试可用】
2018/12/08 Javascript
layui type2 通过url给iframe子页面传值的例子
2019/09/06 Javascript
vue实现修改图片后实时更新
2019/11/14 Javascript
JSONObject与JSONArray使用方法解析
2020/09/28 Javascript
uniapp实现可以左右滑动导航栏
2020/10/21 Javascript
[02:35]DOTA2超级联赛专访XB 难忘一年九冠称王
2013/06/20 DOTA
[01:11]辉夜杯战队访谈宣传片—CDEC.Y
2015/12/26 DOTA
简单的Python2.7编程初学经验总结
2015/04/01 Python
Python3读取UTF-8文件及统计文件行数的方法
2015/05/22 Python
TensorFlow深度学习之卷积神经网络CNN
2018/03/09 Python
PyQt5每天必学之工具提示功能
2018/04/19 Python
python脚本监控logstash进程并邮件告警实例
2020/04/28 Python
python和php哪个更适合写爬虫
2020/06/22 Python
Pyecharts 中Geo函数常用参数的用法说明
2021/02/01 Python
CSS3 简单又实用的5个属性
2010/03/04 HTML / CSS
CSS去掉A标签(链接)虚线框的方法
2014/04/01 HTML / CSS
HTML5触摸事件实现移动端简易进度条的实现方法
2018/05/04 HTML / CSS
大学生演讲稿范文
2014/01/11 职场文书
采购部部长岗位职责
2014/02/06 职场文书
电子工程专业毕业生求职信
2014/03/14 职场文书
幼儿园教师演讲稿
2014/05/06 职场文书
中学生关于梦想的演讲稿
2014/08/22 职场文书
车间质检员岗位职责
2015/04/08 职场文书
javaScript Array api梳理
2021/03/31 Javascript