利用pytorch实现对CIFAR-10数据集的分类


Posted in Python onJanuary 14, 2020

步骤如下:

1.使用torchvision加载并预处理CIFAR-10数据集、

2.定义网络

3.定义损失函数和优化器

4.训练网络并更新网络参数

5.测试网络

运行环境:

windows+python3.6.3+pycharm+pytorch0.3.0

import torchvision as tv
import torchvision.transforms as transforms
import torch as t
from torchvision.transforms import ToPILImage
show=ToPILImage()    #把Tensor转成Image,方便可视化
import matplotlib.pyplot as plt
import torchvision
import numpy as np


###############数据加载与预处理
transform = transforms.Compose([transforms.ToTensor(),#转为tensor
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),#归一化
                ])
#训练集
trainset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=True,
               download=True,
               transform=transform)

trainloader=t.utils.data.DataLoader(trainset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)
#测试集
testset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=False,
               download=True,
               transform=transform)

testloader=t.utils.data.DataLoader(testset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)


classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

(data,label)=trainset[100]
print(classes[label])

show((data+1)/2).resize((100,100))

# dataiter=iter(trainloader)
# images,labels=dataiter.next()
# print(''.join('11%s'%classes[labels[j]] for j in range(4)))
# show(tv.utils.make_grid(images+1)/2).resize((400,100))
def imshow(img):
  img = img / 2 + 0.5
  npimg = img.numpy()
  plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(trainloader)
images, labels = dataiter.next()
print(images.size())
imshow(torchvision.utils.make_grid(images))
plt.show()#关掉图片才能往后继续算


#########################定义网络
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1=nn.Conv2d(3,6,5)
    self.conv2=nn.Conv2d(6,16,5)
    self.fc1=nn.Linear(16*5*5,120)
    self.fc2=nn.Linear(120,84)
    self.fc3=nn.Linear(84,10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv1(x)),2)
    x = F.max_pool2d(F.relu(self.conv2(x)),2)
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

net=Net()
print(net)

#############定义损失函数和优化器
from torch import optim
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

##############训练网络
from torch.autograd import Variable
import time

start_time = time.time()
for epoch in range(2):
  running_loss=0.0
  for i,data in enumerate(trainloader,0):
    #输入数据
    inputs,labels=data
    inputs,labels=Variable(inputs),Variable(labels)
    #梯度清零
    optimizer.zero_grad()

    outputs=net(inputs)
    loss=criterion(outputs,labels)
    loss.backward()
    #更新参数
    optimizer.step()

    # 打印log
    running_loss += loss.data[0]
    if i % 2000 == 1999:
      print('[%d,%5d] loss:%.3f' % (epoch + 1, i + 1, running_loss / 2000))
      running_loss = 0.0
print('finished training')
end_time = time.time()
print("Spend time:", end_time - start_time)

以上这篇利用pytorch实现对CIFAR-10数据集的分类就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python生成pdf文件的方法
Aug 04 Python
详解django中自定义标签和过滤器
Jul 03 Python
Python实现爬取百度贴吧帖子所有楼层图片的爬虫示例
Apr 26 Python
解决DataFrame排序sort的问题
Jun 07 Python
利用Python实现Shp格式向GeoJSON的转换方法
Jul 09 Python
Python 点击指定位置验证码破解的实现代码
Sep 11 Python
django model object序列化实例
Mar 13 Python
解决在keras中使用model.save()函数保存模型失败的问题
May 21 Python
Keras 快速解决OOM超内存的问题
Jun 11 Python
详解Python中的路径问题
Sep 02 Python
Python基于Tkinter开发一个爬取B站直播弹幕的工具
May 06 Python
详解Python牛顿插值法
May 11 Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
PyTorch实现ResNet50、ResNet101和ResNet152示例
Jan 14 #Python
You might like
在PHP中使用Sockets 从Usenet中获取文件
2008/01/10 PHP
linux下 C语言对 php 扩展
2008/12/14 PHP
基于HBase Thrift接口的一些使用问题及相关注意事项的详解
2013/06/03 PHP
php实现计数器方法小结
2015/01/05 PHP
Mac OS下配置PHP+MySql环境
2015/02/25 PHP
适用于初学者的简易PHP文件上传类
2015/10/29 PHP
Laravel 5.3 学习笔记之 配置
2016/08/28 PHP
PHP小白必须要知道的php基础知识(超实用)
2017/10/10 PHP
PHP聊天室简单实现方法详解
2018/12/08 PHP
YII2框架中使用RBAC对模块,控制器,方法的权限控制及规则的使用示例
2020/03/18 PHP
用JavaScript隐藏控件的方法
2009/09/21 Javascript
用函数模板,写一个简单高效的 JSON 查询器的方法介绍
2013/04/17 Javascript
JavaScript中字符串分割函数split用法实例
2015/04/07 Javascript
详谈javascript中的cookie
2015/06/03 Javascript
JS操作XML实例总结(加载与解析XML文件、字符串)
2015/12/08 Javascript
JavaScript电子时钟倒计时
2016/01/09 Javascript
基于javascript实现tab选项卡切换特效调试笔记
2016/03/30 Javascript
bootstrapfileinput实现文件自动上传
2016/11/08 Javascript
详解vee-validate的使用个人小结
2017/06/07 Javascript
Javascript格式化并高亮xml字符串的方法及注意事项
2018/08/13 Javascript
基于Vue 2.0 监听文本框内容变化及ref的使用说明介绍
2018/08/24 Javascript
移动端底部导航固定配合vue-router实现组件切换功能
2019/06/13 Javascript
python利用rsa库做公钥解密的方法教程
2017/12/10 Python
python多进程重复加载的解决方式
2019/12/13 Python
Python telnet登陆功能实现代码
2020/04/16 Python
Python基于pip实现离线打包过程详解
2020/05/15 Python
Myprotein瑞士官方网站:运动营养和健身网上商店
2019/09/25 全球购物
小学生竞选班长演讲稿
2014/04/24 职场文书
酒店管理求职信
2014/06/09 职场文书
九九重阳节标语
2014/10/07 职场文书
2014年学生会个人工作总结
2014/11/07 职场文书
党员思想汇报材料
2014/12/19 职场文书
庆六一开幕词
2015/01/29 职场文书
永远是春天观后感
2015/06/12 职场文书
MySQL 覆盖索引的优点
2021/05/19 MySQL
MySQL中datetime时间字段的四舍五入操作
2021/10/05 MySQL