利用pytorch实现对CIFAR-10数据集的分类


Posted in Python onJanuary 14, 2020

步骤如下:

1.使用torchvision加载并预处理CIFAR-10数据集、

2.定义网络

3.定义损失函数和优化器

4.训练网络并更新网络参数

5.测试网络

运行环境:

windows+python3.6.3+pycharm+pytorch0.3.0

import torchvision as tv
import torchvision.transforms as transforms
import torch as t
from torchvision.transforms import ToPILImage
show=ToPILImage()    #把Tensor转成Image,方便可视化
import matplotlib.pyplot as plt
import torchvision
import numpy as np


###############数据加载与预处理
transform = transforms.Compose([transforms.ToTensor(),#转为tensor
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),#归一化
                ])
#训练集
trainset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=True,
               download=True,
               transform=transform)

trainloader=t.utils.data.DataLoader(trainset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)
#测试集
testset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=False,
               download=True,
               transform=transform)

testloader=t.utils.data.DataLoader(testset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)


classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

(data,label)=trainset[100]
print(classes[label])

show((data+1)/2).resize((100,100))

# dataiter=iter(trainloader)
# images,labels=dataiter.next()
# print(''.join('11%s'%classes[labels[j]] for j in range(4)))
# show(tv.utils.make_grid(images+1)/2).resize((400,100))
def imshow(img):
  img = img / 2 + 0.5
  npimg = img.numpy()
  plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(trainloader)
images, labels = dataiter.next()
print(images.size())
imshow(torchvision.utils.make_grid(images))
plt.show()#关掉图片才能往后继续算


#########################定义网络
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1=nn.Conv2d(3,6,5)
    self.conv2=nn.Conv2d(6,16,5)
    self.fc1=nn.Linear(16*5*5,120)
    self.fc2=nn.Linear(120,84)
    self.fc3=nn.Linear(84,10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv1(x)),2)
    x = F.max_pool2d(F.relu(self.conv2(x)),2)
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

net=Net()
print(net)

#############定义损失函数和优化器
from torch import optim
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

##############训练网络
from torch.autograd import Variable
import time

start_time = time.time()
for epoch in range(2):
  running_loss=0.0
  for i,data in enumerate(trainloader,0):
    #输入数据
    inputs,labels=data
    inputs,labels=Variable(inputs),Variable(labels)
    #梯度清零
    optimizer.zero_grad()

    outputs=net(inputs)
    loss=criterion(outputs,labels)
    loss.backward()
    #更新参数
    optimizer.step()

    # 打印log
    running_loss += loss.data[0]
    if i % 2000 == 1999:
      print('[%d,%5d] loss:%.3f' % (epoch + 1, i + 1, running_loss / 2000))
      running_loss = 0.0
print('finished training')
end_time = time.time()
print("Spend time:", end_time - start_time)

以上这篇利用pytorch实现对CIFAR-10数据集的分类就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python批量导出导入MySQL用户的方法
Nov 15 Python
Python中使用第三方库xlrd来读取Excel示例
Apr 05 Python
Python3 操作符重载方法示例
Nov 23 Python
不管你的Python报什么错,用这个模块就能正常运行
Sep 14 Python
Python+OpenCV感兴趣区域ROI提取方法
Jan 10 Python
PYQT5实现控制台显示功能的方法
Jun 25 Python
在PyCharm的 Terminal(终端)切换Python版本的方法
Aug 02 Python
python根据时间获取周数代码实例
Sep 30 Python
Python pickle模块实现对象序列化
Nov 22 Python
PyQt5中多线程模块QThread使用方法的实现
Jan 31 Python
PyCharm 2020 激活到 2100 年的教程
Mar 25 Python
Python利用Faiss库实现ANN近邻搜索的方法详解
Aug 03 Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
PyTorch实现ResNet50、ResNet101和ResNet152示例
Jan 14 #Python
You might like
如何在WIN2K下安装PHP4.04
2006/10/09 PHP
详解HTTP Cookie状态管理机制
2016/01/14 PHP
php实现遍历文件夹的方法汇总
2017/03/02 PHP
js获取变量
2006/08/24 Javascript
js获取元素外链样式的方法
2015/01/27 Javascript
jQuery实现自定义事件的方法
2015/04/17 Javascript
javascript排序函数实现数字排序
2015/06/26 Javascript
URL中“#” “?” &“”号的作用浅析
2017/02/04 Javascript
js CSS3实现卡牌旋转切换效果
2017/07/04 Javascript
详解基于Vue+Koa的pm2配置
2017/10/24 Javascript
js技巧之十几行的代码实现vue.watch代码
2018/06/09 Javascript
微信小程序实现星级评分和展示
2018/07/05 Javascript
js 解析 JSON 数据简单示例
2020/04/21 Javascript
详解elementUI中input框无法输入的问题
2020/04/27 Javascript
vue.js实现双击放大预览功能
2020/06/23 Javascript
[02:15]2015国际邀请赛选手档案IG.Ferrari 430
2015/07/30 DOTA
[06:06]2018DOTA2亚洲邀请赛主赛事第四日战况回顾 全明星赛欢乐上演
2018/04/07 DOTA
Python错误提示:[Errno 24] Too many open files的分析与解决
2017/02/16 Python
Python之web模板应用
2017/12/26 Python
Python matplotlib绘图可视化知识点整理(小结)
2018/03/16 Python
pytorch训练imagenet分类的方法
2018/07/27 Python
python json.loads兼容单引号数据的方法
2018/12/19 Python
深入理解Python异常处理的哲学
2019/02/01 Python
浅谈python常用程序算法
2019/03/22 Python
Django重置migrations文件的方法步骤
2019/05/01 Python
你应该知道的Python3.6、3.7、3.8新特性小结
2020/05/12 Python
python 5个顶级异步框架推荐
2020/09/09 Python
工业自动化毕业生自荐信范文
2014/01/04 职场文书
采购部部长岗位职责
2014/02/06 职场文书
酒店保安领班职务说明书
2014/03/04 职场文书
学习新党章心得体会2016
2016/01/15 职场文书
攻击最高的10只幽灵系神奇宝贝,坚盾剑怪排第一,第五最为可怕
2022/03/18 日漫
MySQL创建管理HASH分区
2022/04/13 MySQL
MySQL 数据库范式化设计理论
2022/04/22 MySQL
windows系统安装配置nginx环境
2022/06/28 Servers
mysql sock文件存储了什么信息
2022/07/15 MySQL