利用pytorch实现对CIFAR-10数据集的分类


Posted in Python onJanuary 14, 2020

步骤如下:

1.使用torchvision加载并预处理CIFAR-10数据集、

2.定义网络

3.定义损失函数和优化器

4.训练网络并更新网络参数

5.测试网络

运行环境:

windows+python3.6.3+pycharm+pytorch0.3.0

import torchvision as tv
import torchvision.transforms as transforms
import torch as t
from torchvision.transforms import ToPILImage
show=ToPILImage()    #把Tensor转成Image,方便可视化
import matplotlib.pyplot as plt
import torchvision
import numpy as np


###############数据加载与预处理
transform = transforms.Compose([transforms.ToTensor(),#转为tensor
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),#归一化
                ])
#训练集
trainset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=True,
               download=True,
               transform=transform)

trainloader=t.utils.data.DataLoader(trainset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)
#测试集
testset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=False,
               download=True,
               transform=transform)

testloader=t.utils.data.DataLoader(testset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)


classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

(data,label)=trainset[100]
print(classes[label])

show((data+1)/2).resize((100,100))

# dataiter=iter(trainloader)
# images,labels=dataiter.next()
# print(''.join('11%s'%classes[labels[j]] for j in range(4)))
# show(tv.utils.make_grid(images+1)/2).resize((400,100))
def imshow(img):
  img = img / 2 + 0.5
  npimg = img.numpy()
  plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(trainloader)
images, labels = dataiter.next()
print(images.size())
imshow(torchvision.utils.make_grid(images))
plt.show()#关掉图片才能往后继续算


#########################定义网络
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1=nn.Conv2d(3,6,5)
    self.conv2=nn.Conv2d(6,16,5)
    self.fc1=nn.Linear(16*5*5,120)
    self.fc2=nn.Linear(120,84)
    self.fc3=nn.Linear(84,10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv1(x)),2)
    x = F.max_pool2d(F.relu(self.conv2(x)),2)
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

net=Net()
print(net)

#############定义损失函数和优化器
from torch import optim
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

##############训练网络
from torch.autograd import Variable
import time

start_time = time.time()
for epoch in range(2):
  running_loss=0.0
  for i,data in enumerate(trainloader,0):
    #输入数据
    inputs,labels=data
    inputs,labels=Variable(inputs),Variable(labels)
    #梯度清零
    optimizer.zero_grad()

    outputs=net(inputs)
    loss=criterion(outputs,labels)
    loss.backward()
    #更新参数
    optimizer.step()

    # 打印log
    running_loss += loss.data[0]
    if i % 2000 == 1999:
      print('[%d,%5d] loss:%.3f' % (epoch + 1, i + 1, running_loss / 2000))
      running_loss = 0.0
print('finished training')
end_time = time.time()
print("Spend time:", end_time - start_time)

以上这篇利用pytorch实现对CIFAR-10数据集的分类就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅谈用VSCode写python的正确姿势
Dec 16 Python
python实现支付宝当面付(扫码支付)功能
May 30 Python
Python多线程编程之多线程加锁操作示例
Sep 06 Python
Django管理员账号和密码忘记的完美解决方法
Dec 06 Python
对pandas处理json数据的方法详解
Feb 08 Python
Python安装selenium包详细过程
Jul 23 Python
树莓派极简安装OpenCv的方法步骤
Oct 10 Python
TensorFlow MNIST手写数据集的实现方法
Feb 05 Python
Python3 元组tuple入门基础
Feb 09 Python
Pytest框架之fixture的详细使用教程
Apr 07 Python
keras训练曲线,混淆矩阵,CNN层输出可视化实例
Jun 15 Python
python 用opencv实现霍夫线变换
Nov 27 Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
PyTorch实现ResNet50、ResNet101和ResNet152示例
Jan 14 #Python
You might like
php遍历删除整个目录及文件的方法
2015/03/13 PHP
IE浏览器PNG图片透明效果代码
2008/09/02 Javascript
jquery ui dialog里调用datepicker的问题
2009/08/06 Javascript
七个很有意思的PHP函数
2014/05/12 Javascript
JavaScript判断前缀、后缀是否是空格的方法
2015/04/15 Javascript
jQuery原生的动画效果
2015/07/10 Javascript
JavaScript实现显示函数调用堆栈的方法
2016/04/21 Javascript
微信小程序实现图片预加载组件
2017/01/18 Javascript
基于Vuejs和Element的注册插件的编写方法
2017/07/03 Javascript
jQuery EasyUI开发技巧总结
2017/09/26 jQuery
重新认识vue之事件阻止冒泡的实现
2018/08/02 Javascript
angular 实现下拉列表组件的示例代码
2019/03/09 Javascript
小程序显示弹窗时禁止下层的内容滚动实现方法
2019/03/20 Javascript
[50:24]VGJ.S vs Pain 2018国际邀请赛小组赛BO2 第二场 8.17
2018/08/20 DOTA
从零学python系列之新版本导入httplib模块报ImportError解决方案
2014/05/23 Python
在Python中使用SimpleParse模块进行解析的教程
2015/04/11 Python
python批量制作雷达图的实现方法
2016/07/26 Python
对Python中列表和数组的赋值,浅拷贝和深拷贝的实例讲解
2018/06/28 Python
python实现控制电脑鼠标和键盘,登录QQ的方法示例
2019/07/06 Python
pandas中遍历dataframe的每一个元素的实现
2019/10/23 Python
Python使用进程Process模块管理资源
2020/03/05 Python
如何基于python实现年会抽奖工具
2020/10/20 Python
关于python中remove的一些坑小结
2021/01/04 Python
python解包概念及实例
2021/02/17 Python
世界上最受欢迎的钓鱼诱饵:Rapala
2019/05/02 全球购物
sealed修饰符是干什么的
2012/10/23 面试题
一道Delphi上机题
2012/06/04 面试题
应届专科生个人的自我评价
2014/01/05 职场文书
甜美蛋糕店创业计划书
2014/01/30 职场文书
文化与传播毕业生求职信
2014/03/09 职场文书
高中英语演讲稿范文
2014/04/24 职场文书
房屋出售授权委托书
2014/10/12 职场文书
2015年办公室工作总结范文
2015/03/31 职场文书
小学运动会开幕词
2016/03/04 职场文书
Django cookie和session的应用场景及如何使用
2021/04/29 Python
MYSQL如何查看操作日志详解
2022/05/30 MySQL