利用pytorch实现对CIFAR-10数据集的分类


Posted in Python onJanuary 14, 2020

步骤如下:

1.使用torchvision加载并预处理CIFAR-10数据集、

2.定义网络

3.定义损失函数和优化器

4.训练网络并更新网络参数

5.测试网络

运行环境:

windows+python3.6.3+pycharm+pytorch0.3.0

import torchvision as tv
import torchvision.transforms as transforms
import torch as t
from torchvision.transforms import ToPILImage
show=ToPILImage()    #把Tensor转成Image,方便可视化
import matplotlib.pyplot as plt
import torchvision
import numpy as np


###############数据加载与预处理
transform = transforms.Compose([transforms.ToTensor(),#转为tensor
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),#归一化
                ])
#训练集
trainset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=True,
               download=True,
               transform=transform)

trainloader=t.utils.data.DataLoader(trainset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)
#测试集
testset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=False,
               download=True,
               transform=transform)

testloader=t.utils.data.DataLoader(testset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)


classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

(data,label)=trainset[100]
print(classes[label])

show((data+1)/2).resize((100,100))

# dataiter=iter(trainloader)
# images,labels=dataiter.next()
# print(''.join('11%s'%classes[labels[j]] for j in range(4)))
# show(tv.utils.make_grid(images+1)/2).resize((400,100))
def imshow(img):
  img = img / 2 + 0.5
  npimg = img.numpy()
  plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(trainloader)
images, labels = dataiter.next()
print(images.size())
imshow(torchvision.utils.make_grid(images))
plt.show()#关掉图片才能往后继续算


#########################定义网络
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1=nn.Conv2d(3,6,5)
    self.conv2=nn.Conv2d(6,16,5)
    self.fc1=nn.Linear(16*5*5,120)
    self.fc2=nn.Linear(120,84)
    self.fc3=nn.Linear(84,10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv1(x)),2)
    x = F.max_pool2d(F.relu(self.conv2(x)),2)
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

net=Net()
print(net)

#############定义损失函数和优化器
from torch import optim
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

##############训练网络
from torch.autograd import Variable
import time

start_time = time.time()
for epoch in range(2):
  running_loss=0.0
  for i,data in enumerate(trainloader,0):
    #输入数据
    inputs,labels=data
    inputs,labels=Variable(inputs),Variable(labels)
    #梯度清零
    optimizer.zero_grad()

    outputs=net(inputs)
    loss=criterion(outputs,labels)
    loss.backward()
    #更新参数
    optimizer.step()

    # 打印log
    running_loss += loss.data[0]
    if i % 2000 == 1999:
      print('[%d,%5d] loss:%.3f' % (epoch + 1, i + 1, running_loss / 2000))
      running_loss = 0.0
print('finished training')
end_time = time.time()
print("Spend time:", end_time - start_time)

以上这篇利用pytorch实现对CIFAR-10数据集的分类就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于wxpython开发的简单gui计算器实例
May 30 Python
Python抓取电影天堂电影信息的代码
Apr 07 Python
解决Python的str强转int时遇到的问题
Apr 09 Python
pytorch + visdom 处理简单分类问题的示例
Jun 04 Python
详解Python 爬取13个旅游城市,告诉你五一大家最爱去哪玩?
May 07 Python
Python实现RGB与HSI颜色空间的互换方式
Nov 27 Python
flask框架配置mysql数据库操作详解
Nov 29 Python
Django之form组件自动校验数据实现
Jan 14 Python
keras Lambda自定义层实现数据的切片方式,Lambda传参数
Jun 11 Python
Python3实现英文字母转换哥特式字体实例代码
Sep 01 Python
python中pickle模块浅析
Dec 29 Python
Pygame Event事件模块的详细示例
Nov 17 Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
PyTorch实现ResNet50、ResNet101和ResNet152示例
Jan 14 #Python
You might like
php调用方法mssql_fetch_row、mssql_fetch_array、mssql_fetch_assoc和mssql_fetch_objcect读取数据的区别
2012/08/08 PHP
PHP静态调用非静态方法的应用分析
2013/05/02 PHP
Prototype PeriodicalExecuter对象 学习
2009/07/19 Javascript
jQuery+CSS 实现的超Sexy下拉菜单
2010/01/17 Javascript
Extjs中ComboBoxTree实现的下拉框树效果(自写)
2013/05/28 Javascript
jquery常用特效方法使用示例
2014/04/25 Javascript
jquery单行文字向上滚动效果的实现代码
2014/09/05 Javascript
AngularJS 整理一些优化的小技巧
2016/08/18 Javascript
vue.js中指令Directives详解
2017/03/20 Javascript
vue中实现移动端的scroll滚动方法
2018/03/03 Javascript
Vue微信项目按需授权登录策略实践思路详解
2018/05/07 Javascript
详解Vue、element-ui、axios实现省市区三级联动
2019/05/07 Javascript
Vue 修改网站图标的方法
2020/12/31 Vue.js
[01:06:07]2014 DOTA2国际邀请赛中国区预选赛5.21 DT VS CIS
2014/05/22 DOTA
Python里隐藏的“禅”
2014/06/16 Python
Python2.x版本中maketrans()方法的使用介绍
2015/05/19 Python
Python聚类算法之凝聚层次聚类实例分析
2015/11/20 Python
用pickle存储Python的原生对象方法
2017/04/28 Python
Python判断两个对象相等的原理
2017/12/12 Python
python中数据爬虫requests库使用方法详解
2018/02/11 Python
对django的User模型和四种扩展/重写方法小结
2019/08/17 Python
浅谈Pycharm最有必要改的几个默认设置项
2020/02/14 Python
Python通过两个dataframe用for循环求笛卡尔积
2020/04/29 Python
Python 通过爬虫实现GitHub网页的模拟登录的示例代码
2020/08/17 Python
来自美国主售篮球鞋的零售商店:KICKSUSA
2017/11/28 全球购物
欧克利英国官网:Oakley英国
2019/08/24 全球购物
医院办公室主任职责
2013/12/29 职场文书
财务专业大学生职业生涯规划范文
2013/12/30 职场文书
签约仪式主持词
2014/03/19 职场文书
创建青年文明号材料
2014/05/09 职场文书
投标邀请书范本
2015/02/02 职场文书
政审证明材料
2015/06/19 职场文书
导游词之山西-五老峰
2019/10/07 职场文书
redis三种高可用方式部署的实现
2021/05/11 Redis
vue中利用mqtt服务端实现即时通讯的步骤记录
2021/07/01 Vue.js
Windows环境下实现批量执行Sql文件
2021/10/05 SQL Server