pytorch实现mnist分类的示例讲解


Posted in Python onJanuary 10, 2020

torchvision包 包含了目前流行的数据集,模型结构和常用的图片转换工具。

torchvision.datasets中包含了以下数据集

MNIST
COCO(用于图像标注和目标检测)(Captioning and Detection)
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10

torchvision.models

torchvision.models模块的 子模块中包含以下模型结构。
AlexNet
VGG
ResNet
SqueezeNet
DenseNet You can construct a model with random weights by calling its constructor:

pytorch torchvision transform

对PIL.Image进行变换

from __future__ import print_function
import argparse #Python 命令行解析工具
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torchvision import datasets, transforms

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.conv2_drop = nn.Dropout2d()
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

def train(args, model, device, train_loader, optimizer, epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % args.log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))

def test(args, model, device, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
      pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

def main():
  # Training settings
  parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  parser.add_argument('--batch-size', type=int, default=64, metavar='N',
            help='input batch size for training (default: 64)')
  parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
            help='input batch size for testing (default: 1000)')
  parser.add_argument('--epochs', type=int, default=10, metavar='N',
            help='number of epochs to train (default: 10)')
  parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
            help='learning rate (default: 0.01)')
  parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
            help='SGD momentum (default: 0.5)')
  parser.add_argument('--no-cuda', action='store_true', default=False,
            help='disables CUDA training')
  parser.add_argument('--seed', type=int, default=1, metavar='S',
            help='random seed (default: 1)')
  parser.add_argument('--log-interval', type=int, default=10, metavar='N',
            help='how many batches to wait before logging training status')
  args = parser.parse_args()
  use_cuda = not args.no_cuda and torch.cuda.is_available()

  torch.manual_seed(args.seed)

  device = torch.device("cuda" if use_cuda else "cpu")

  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
  train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
            transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
  test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


  model = Net().to(device)
  optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

  for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(args, model, device, test_loader)


if __name__ == '__main__':
  main()

以上这篇pytorch实现mnist分类的示例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python爬虫利用cookie实现模拟登陆实例详解
Jan 12 Python
windows下python安装paramiko模块和pycrypto模块(简单三步)
Jul 06 Python
Python编程之黑板上排列组合,你舍得解开吗
Oct 30 Python
Python模拟脉冲星伪信号频率实例代码
Jan 03 Python
使用python爬虫实现网络股票信息爬取的demo
Jan 05 Python
在Python 2.7即将停止支持时,我们为你带来了一份python 3.x迁移指南
Jan 30 Python
Python列表推导式与生成器表达式用法示例
Feb 08 Python
python版本单链表实现代码
Sep 28 Python
详解opencv Python特征检测及K-最近邻匹配
Jan 21 Python
python 读取文件并把矩阵转成numpy的两种方法
Feb 12 Python
Python实现随机生成任意数量车牌号
Jan 21 Python
举例讲解Python装饰器
Dec 24 Python
pytorch 准备、训练和测试自己的图片数据的方法
Jan 10 #Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
You might like
php缓冲 output_buffering的使用详解
2013/06/13 PHP
PHP数组和explode函数示例总结
2015/05/08 PHP
phpMyAdmin通过密码漏洞留后门文件
2018/11/20 PHP
xml转json的js代码
2012/08/28 Javascript
JavaScript学习心得之概述
2015/01/20 Javascript
JQuery中的事件及动画用法实例
2015/01/26 Javascript
JavaScript采用递归算法计算阶乘实例
2015/08/04 Javascript
javascript实现显示和隐藏div方法汇总
2015/08/14 Javascript
JQuery日历插件My97DatePicker日期范围限制
2016/01/20 Javascript
chrome下判断点击input上标签还是其余标签的实现方法
2016/09/18 Javascript
JavaScript算法教程之sku(库存量单位)详解
2017/06/29 Javascript
JavaScript中document.referrer的用法详解
2017/07/04 Javascript
详解vue mixins和extends的巧妙用法
2017/12/20 Javascript
手写Node静态资源服务器的实现方法
2018/03/20 Javascript
vue-cli的工程模板与构建工具详解
2018/09/27 Javascript
D3.js(v3)+react 实现带坐标与比例尺的柱形图 (V3版本)
2019/05/09 Javascript
layui写后台表格思路和赋值用法详解
2019/11/14 Javascript
微信小程序点击item使之滚动到屏幕中间位置
2020/03/25 Javascript
[27:08]完美世界DOTA2联赛PWL S2 SZ vs Rebirth 第二场 11.21
2020/11/23 DOTA
使用Python发送邮件附件以定时备份MySQL的教程
2015/04/25 Python
python使用socket进行简单网络连接的方法
2015/04/29 Python
python3 图片referer防盗链的实现方法
2018/03/12 Python
基于python实现简单日历
2018/07/28 Python
搞定这套Python爬虫面试题(面试会so easy)
2019/04/03 Python
Python 监测文件是否更新的方法
2019/06/10 Python
python 装饰器的实际作用有哪些
2020/09/07 Python
详解Pytorch显存动态分配规律探索
2020/11/17 Python
意大利婴儿产品网上商店:Mukako
2018/10/14 全球购物
利物浦足球俱乐部官方商店(美国):Liverpool FC US
2019/10/09 全球购物
初中学生评语大全
2014/04/24 职场文书
趣味运动会广播稿
2014/09/13 职场文书
2014年小学体育工作总结
2014/12/11 职场文书
房屋认购协议书
2015/01/29 职场文书
销售督导岗位职责
2015/04/10 职场文书
python基础入门之普通操作与函数(三)
2021/06/13 Python
MySQL数据库事务的四大特性
2022/04/20 MySQL