利用pytorch实现对CIFAR-10数据集的分类


Posted in Python onJanuary 14, 2020

步骤如下:

1.使用torchvision加载并预处理CIFAR-10数据集、

2.定义网络

3.定义损失函数和优化器

4.训练网络并更新网络参数

5.测试网络

运行环境:

windows+python3.6.3+pycharm+pytorch0.3.0

import torchvision as tv
import torchvision.transforms as transforms
import torch as t
from torchvision.transforms import ToPILImage
show=ToPILImage()    #把Tensor转成Image,方便可视化
import matplotlib.pyplot as plt
import torchvision
import numpy as np


###############数据加载与预处理
transform = transforms.Compose([transforms.ToTensor(),#转为tensor
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),#归一化
                ])
#训练集
trainset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=True,
               download=True,
               transform=transform)

trainloader=t.utils.data.DataLoader(trainset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)
#测试集
testset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=False,
               download=True,
               transform=transform)

testloader=t.utils.data.DataLoader(testset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)


classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

(data,label)=trainset[100]
print(classes[label])

show((data+1)/2).resize((100,100))

# dataiter=iter(trainloader)
# images,labels=dataiter.next()
# print(''.join('11%s'%classes[labels[j]] for j in range(4)))
# show(tv.utils.make_grid(images+1)/2).resize((400,100))
def imshow(img):
  img = img / 2 + 0.5
  npimg = img.numpy()
  plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(trainloader)
images, labels = dataiter.next()
print(images.size())
imshow(torchvision.utils.make_grid(images))
plt.show()#关掉图片才能往后继续算


#########################定义网络
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1=nn.Conv2d(3,6,5)
    self.conv2=nn.Conv2d(6,16,5)
    self.fc1=nn.Linear(16*5*5,120)
    self.fc2=nn.Linear(120,84)
    self.fc3=nn.Linear(84,10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv1(x)),2)
    x = F.max_pool2d(F.relu(self.conv2(x)),2)
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

net=Net()
print(net)

#############定义损失函数和优化器
from torch import optim
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

##############训练网络
from torch.autograd import Variable
import time

start_time = time.time()
for epoch in range(2):
  running_loss=0.0
  for i,data in enumerate(trainloader,0):
    #输入数据
    inputs,labels=data
    inputs,labels=Variable(inputs),Variable(labels)
    #梯度清零
    optimizer.zero_grad()

    outputs=net(inputs)
    loss=criterion(outputs,labels)
    loss.backward()
    #更新参数
    optimizer.step()

    # 打印log
    running_loss += loss.data[0]
    if i % 2000 == 1999:
      print('[%d,%5d] loss:%.3f' % (epoch + 1, i + 1, running_loss / 2000))
      running_loss = 0.0
print('finished training')
end_time = time.time()
print("Spend time:", end_time - start_time)

以上这篇利用pytorch实现对CIFAR-10数据集的分类就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Django框架中用context来解析模板的方法
Jul 20 Python
安装ElasticSearch搜索工具并配置Python驱动的方法
Dec 22 Python
python实现识别相似图片小结
Feb 22 Python
基于python yield机制的异步操作同步化编程模型
Mar 18 Python
Python中shutil模块的常用文件操作函数用法示例
Jul 05 Python
Django数据库操作的实例(增删改查)
Sep 04 Python
python基本语法练习实例
Sep 19 Python
基于Django用户认证系统详解
Feb 21 Python
Python selenium实现微博自动登录的示例代码
May 16 Python
详解python分布式进程
Oct 08 Python
Python使用pandas对数据进行差分运算的方法
Dec 22 Python
python识别文字(基于tesseract)代码实例
Aug 24 Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
PyTorch实现ResNet50、ResNet101和ResNet152示例
Jan 14 #Python
You might like
《PHP编程最快明白》第六讲:Mysql数据库操作
2010/11/01 PHP
smarty自定义函数用法示例
2016/05/20 PHP
PHP微信支付实例解析
2016/07/22 PHP
初学js 新节点的创建 删除 的步骤
2011/07/04 Javascript
JSON+HTML实现国家省市联动选择效果
2014/05/18 Javascript
jQuery+css实现的蓝色水平二级导航菜单效果代码
2015/09/11 Javascript
js使用cookie记录用户名的方法
2015/11/26 Javascript
微信小程序 生命周期详解
2016/10/12 Javascript
js实现省份下拉菜单效果
2017/02/15 Javascript
JavaScript无操作后屏保功能的实现方法
2017/07/04 Javascript
Angular6封装http请求的步骤详解
2018/08/13 Javascript
vue-test-utils初使用详解
2019/05/23 Javascript
layui动态表头的实现代码
2019/08/22 Javascript
解决layer.prompt无效的问题
2019/09/24 Javascript
Vue实现剪贴板复制功能
2019/12/31 Javascript
Node使用Nodemailer发送邮件的方法实现
2020/02/24 Javascript
如何使用JavaScript检测空闲的浏览器选项卡
2020/05/28 Javascript
Vue+element-ui添加自定义右键菜单的方法示例
2020/12/08 Vue.js
[03:56]显微镜下的DOTA2第十一期——鬼畜的死亡先知播音员
2014/06/23 DOTA
Python的iOS自动化打包实例代码
2018/11/22 Python
基于Python对数据shape的常见操作详解
2018/12/25 Python
Python脚本按照当前日期创建多级目录
2019/03/01 Python
Python多进程方式抓取基金网站内容的方法分析
2019/06/03 Python
对PyQt5基本窗口控件 QMainWindow的使用详解
2019/06/19 Python
django 基于中间件实现限制ip频繁访问过程详解
2019/07/30 Python
Python 动态变量名定义与调用方法
2020/02/09 Python
python读写数据读写csv文件(pandas用法)
2020/12/14 Python
html5 Canvas画图教程(2)—画直线与设置线条的样式如颜色/端点/交汇点
2013/01/09 HTML / CSS
经济与贸易专业应届生求职信
2013/11/19 职场文书
外贸销售员求职的自我评价
2013/11/23 职场文书
致跳远运动员加油稿
2014/02/11 职场文书
纪念九一八爱国演讲稿600字
2014/09/14 职场文书
毕业生捐书活动倡议书
2015/04/27 职场文书
2016党员学习《反对自由主义》心得体会
2016/01/22 职场文书
2019年大学生职业生涯规划书
2019/03/25 职场文书
《勇者辞职不干了》ED主题曲无字幕动画MV公开
2022/04/13 日漫