利用pytorch实现对CIFAR-10数据集的分类


Posted in Python onJanuary 14, 2020

步骤如下:

1.使用torchvision加载并预处理CIFAR-10数据集、

2.定义网络

3.定义损失函数和优化器

4.训练网络并更新网络参数

5.测试网络

运行环境:

windows+python3.6.3+pycharm+pytorch0.3.0

import torchvision as tv
import torchvision.transforms as transforms
import torch as t
from torchvision.transforms import ToPILImage
show=ToPILImage()    #把Tensor转成Image,方便可视化
import matplotlib.pyplot as plt
import torchvision
import numpy as np


###############数据加载与预处理
transform = transforms.Compose([transforms.ToTensor(),#转为tensor
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),#归一化
                ])
#训练集
trainset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=True,
               download=True,
               transform=transform)

trainloader=t.utils.data.DataLoader(trainset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)
#测试集
testset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=False,
               download=True,
               transform=transform)

testloader=t.utils.data.DataLoader(testset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)


classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

(data,label)=trainset[100]
print(classes[label])

show((data+1)/2).resize((100,100))

# dataiter=iter(trainloader)
# images,labels=dataiter.next()
# print(''.join('11%s'%classes[labels[j]] for j in range(4)))
# show(tv.utils.make_grid(images+1)/2).resize((400,100))
def imshow(img):
  img = img / 2 + 0.5
  npimg = img.numpy()
  plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(trainloader)
images, labels = dataiter.next()
print(images.size())
imshow(torchvision.utils.make_grid(images))
plt.show()#关掉图片才能往后继续算


#########################定义网络
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1=nn.Conv2d(3,6,5)
    self.conv2=nn.Conv2d(6,16,5)
    self.fc1=nn.Linear(16*5*5,120)
    self.fc2=nn.Linear(120,84)
    self.fc3=nn.Linear(84,10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv1(x)),2)
    x = F.max_pool2d(F.relu(self.conv2(x)),2)
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

net=Net()
print(net)

#############定义损失函数和优化器
from torch import optim
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

##############训练网络
from torch.autograd import Variable
import time

start_time = time.time()
for epoch in range(2):
  running_loss=0.0
  for i,data in enumerate(trainloader,0):
    #输入数据
    inputs,labels=data
    inputs,labels=Variable(inputs),Variable(labels)
    #梯度清零
    optimizer.zero_grad()

    outputs=net(inputs)
    loss=criterion(outputs,labels)
    loss.backward()
    #更新参数
    optimizer.step()

    # 打印log
    running_loss += loss.data[0]
    if i % 2000 == 1999:
      print('[%d,%5d] loss:%.3f' % (epoch + 1, i + 1, running_loss / 2000))
      running_loss = 0.0
print('finished training')
end_time = time.time()
print("Spend time:", end_time - start_time)

以上这篇利用pytorch实现对CIFAR-10数据集的分类就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在pycharm上mongodb配置及可视化设置方法
Nov 30 Python
python实现简单五子棋游戏
Jun 18 Python
python如何以表格形式打印输出的方法示例
Jun 21 Python
Python调用C语言的实现
Jul 26 Python
大家都说好用的Python命令行库click的使用
Nov 07 Python
python绘制雪景图
Dec 16 Python
Pytorch GPU显存充足却显示out of memory的解决方式
Jan 13 Python
VSCODE配置Markdown及Markdown基础语法详解
Jan 19 Python
Django后端按照日期查询的方法教程
Feb 28 Python
pycharm配置安装autopep8自动规范代码的实现
Mar 02 Python
Python基础之教你怎么在M1系统上使用pandas
May 08 Python
python和Appium的移动端多设备自动化测试框架
Apr 26 Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
PyTorch实现ResNet50、ResNet101和ResNet152示例
Jan 14 #Python
You might like
十天学会php之第十天
2006/10/09 PHP
Http 1.1 Etag 与 Last-Modified提高php效率
2008/01/10 PHP
php $_SERVER["REQUEST_URI"]获取值的通用解决方法
2010/06/21 PHP
PHP循环语句笔记(foreach,list)
2011/11/29 PHP
PHP计算当前坐标3公里内4个角落的最大最小经纬度实例
2016/02/26 PHP
PHP精确计算功能示例
2016/11/29 PHP
PHP实现Redis单据锁以及防止并发重复写入
2018/04/10 PHP
PHP实现的pdo连接数据库并插入数据功能简单示例
2019/03/30 PHP
关于javascript中this关键字(翻译+自我理解)
2010/10/20 Javascript
理解JAVASCRIPT中hasOwnProperty()的作用
2013/06/05 Javascript
javascript去掉代码里面的注释
2015/07/24 Javascript
全面解析Bootstrap中scrollspy(滚动监听)的使用方法
2016/06/06 Javascript
Bootstrap零基础学习第一课之模板
2016/07/18 Javascript
js实现颜色阶梯渐变效果(Gradient算法)
2017/03/21 Javascript
详解如何快速配置webpack多入口脚手架
2018/12/28 Javascript
vue-cli3 项目从搭建优化到docker部署的方法
2019/01/28 Javascript
js实现页面多个日期时间倒计时效果
2019/06/20 Javascript
jquery弹窗时禁止body滚动条滚动的例子
2019/09/21 jQuery
vue-drawer-layout实现手势滑出菜单栏
2020/11/19 Vue.js
[52:26]完美世界DOTA2联赛决赛 FTD vs Phoenix 第一场 11.08
2020/11/11 DOTA
python MysqlDb模块安装及其使用详解
2018/02/23 Python
python如何实现异步调用函数执行
2019/07/08 Python
Python 将json序列化后的字符串转换成字典(推荐)
2020/01/06 Python
tensorflow模型转ncnn的操作方式
2020/05/25 Python
Python实现Appium端口检测与释放的实现
2020/12/31 Python
聊聊Python pandas 中loc函数的使用,及跟iloc的区别说明
2021/03/03 Python
中学生校园广播稿
2014/01/16 职场文书
教师一岗双责责任书
2014/04/16 职场文书
六年级学生评语
2014/04/22 职场文书
二年级学生评语大全
2014/04/23 职场文书
教师年度考核自我评鉴
2015/08/11 职场文书
物业管理交接协议书
2016/03/24 职场文书
《中国古代诗歌散文欣赏》高中语文教材
2019/08/20 职场文书
创业的9条正确思考方式
2019/08/26 职场文书
Appium中scroll和drag_and_drop根据元素位置滑动
2022/02/15 Python
疑《守望先锋2》A测截图泄露 或将推出新模式、新界面
2022/04/03 其他游戏