利用pytorch实现对CIFAR-10数据集的分类


Posted in Python onJanuary 14, 2020

步骤如下:

1.使用torchvision加载并预处理CIFAR-10数据集、

2.定义网络

3.定义损失函数和优化器

4.训练网络并更新网络参数

5.测试网络

运行环境:

windows+python3.6.3+pycharm+pytorch0.3.0

import torchvision as tv
import torchvision.transforms as transforms
import torch as t
from torchvision.transforms import ToPILImage
show=ToPILImage()    #把Tensor转成Image,方便可视化
import matplotlib.pyplot as plt
import torchvision
import numpy as np


###############数据加载与预处理
transform = transforms.Compose([transforms.ToTensor(),#转为tensor
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),#归一化
                ])
#训练集
trainset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=True,
               download=True,
               transform=transform)

trainloader=t.utils.data.DataLoader(trainset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)
#测试集
testset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=False,
               download=True,
               transform=transform)

testloader=t.utils.data.DataLoader(testset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)


classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

(data,label)=trainset[100]
print(classes[label])

show((data+1)/2).resize((100,100))

# dataiter=iter(trainloader)
# images,labels=dataiter.next()
# print(''.join('11%s'%classes[labels[j]] for j in range(4)))
# show(tv.utils.make_grid(images+1)/2).resize((400,100))
def imshow(img):
  img = img / 2 + 0.5
  npimg = img.numpy()
  plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(trainloader)
images, labels = dataiter.next()
print(images.size())
imshow(torchvision.utils.make_grid(images))
plt.show()#关掉图片才能往后继续算


#########################定义网络
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1=nn.Conv2d(3,6,5)
    self.conv2=nn.Conv2d(6,16,5)
    self.fc1=nn.Linear(16*5*5,120)
    self.fc2=nn.Linear(120,84)
    self.fc3=nn.Linear(84,10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv1(x)),2)
    x = F.max_pool2d(F.relu(self.conv2(x)),2)
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

net=Net()
print(net)

#############定义损失函数和优化器
from torch import optim
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

##############训练网络
from torch.autograd import Variable
import time

start_time = time.time()
for epoch in range(2):
  running_loss=0.0
  for i,data in enumerate(trainloader,0):
    #输入数据
    inputs,labels=data
    inputs,labels=Variable(inputs),Variable(labels)
    #梯度清零
    optimizer.zero_grad()

    outputs=net(inputs)
    loss=criterion(outputs,labels)
    loss.backward()
    #更新参数
    optimizer.step()

    # 打印log
    running_loss += loss.data[0]
    if i % 2000 == 1999:
      print('[%d,%5d] loss:%.3f' % (epoch + 1, i + 1, running_loss / 2000))
      running_loss = 0.0
print('finished training')
end_time = time.time()
print("Spend time:", end_time - start_time)

以上这篇利用pytorch实现对CIFAR-10数据集的分类就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之通过Python连接数据库
Oct 28 Python
Python中的rjust()方法使用详解
May 19 Python
Django中对数据查询结果进行排序的方法
Jul 17 Python
Python实现截屏的函数
Jul 26 Python
Python实现注册、登录小程序功能
Sep 21 Python
在Python 中实现图片加框和加字的方法
Jan 26 Python
django的分页器Paginator 从django中导入类
Jul 25 Python
Anaconda3+tensorflow2.0.0+PyCharm安装与环境搭建(图文)
Feb 18 Python
DRF框架API版本管理实现方法解析
Aug 21 Python
python 如何利用argparse解析命令行参数
Sep 11 Python
Django haystack实现全文搜索代码示例
Nov 28 Python
pycharm 实现复制一行的快捷键
Jan 15 Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
PyTorch实现ResNet50、ResNet101和ResNet152示例
Jan 14 #Python
You might like
php&java(一)
2006/10/09 PHP
php中常见的sql攻击正则表达式汇总
2014/11/06 PHP
php检测url是否存在的方法
2015/04/14 PHP
php调用淘宝开放API实现根据卖家昵称获取卖家店铺ID的方法
2015/07/29 PHP
Yii安装与使用Excel扩展的方法
2016/07/13 PHP
yii2 commands模式以及配置crontab定时任务的方法
2017/08/19 PHP
PHP htmlspecialchars()函数用法与实例讲解
2019/03/08 PHP
php post换行的方法
2020/02/03 PHP
AutoSave/自动存储功能实现
2007/03/24 Javascript
javascript写的一个链表实现代码
2009/10/25 Javascript
JavaScript验证图片类型(扩展名)的函数分享
2014/05/05 Javascript
浅谈$(document)和$(window)的区别
2015/07/15 Javascript
js实现tab切换效果实例
2015/09/16 Javascript
基于Node.js的强大爬虫 能直接发布抓取的文章哦
2016/01/10 Javascript
微信小程序开发教程-手势解锁实例
2017/01/06 Javascript
javascript 产生随机数的几种方法总结
2017/09/26 Javascript
浅谈Vue的加载顺序探讨
2017/10/25 Javascript
JS获取本地地址及天气的方法实例小结
2019/05/10 Javascript
Vue formData实现图片上传
2019/08/20 Javascript
详解Vue.js 响应接口
2020/07/04 Javascript
nuxt.js 在middleware(中间件)中实现路由鉴权操作
2020/11/06 Javascript
Python中使用SAX解析xml实例
2014/11/21 Python
六个窍门助你提高Python运行效率
2015/06/09 Python
Python的socket模块源码中的一些实现要点分析
2016/06/06 Python
浅谈python中scipy.misc.logsumexp函数的运用场景
2016/06/23 Python
Python 登录网站详解及实例
2017/04/11 Python
使用PyQtGraph绘制精美的股票行情K线图的示例代码
2019/03/14 Python
Python如何使用k-means方法将列表中相似的句子归类
2019/08/08 Python
Python Dict找出value大于某值或key大于某值的所有项方式
2020/06/05 Python
lookfantastic荷兰:在线购买奢华护肤、护发和化妆品
2018/11/27 全球购物
New delete 与malloc free 的联系与区别
2013/02/04 面试题
宣传保护环境的公益广告词
2014/03/13 职场文书
大学本科生职业生涯规划书范文
2014/09/14 职场文书
企业法人授权委托书范本
2014/09/23 职场文书
区政府领导班子个人对照检查材料
2014/09/25 职场文书
以权谋私检举信范文
2015/03/02 职场文书