pytorch实现mnist分类的示例讲解


Posted in Python onJanuary 10, 2020

torchvision包 包含了目前流行的数据集,模型结构和常用的图片转换工具。

torchvision.datasets中包含了以下数据集

MNIST
COCO(用于图像标注和目标检测)(Captioning and Detection)
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10

torchvision.models

torchvision.models模块的 子模块中包含以下模型结构。
AlexNet
VGG
ResNet
SqueezeNet
DenseNet You can construct a model with random weights by calling its constructor:

pytorch torchvision transform

对PIL.Image进行变换

from __future__ import print_function
import argparse #Python 命令行解析工具
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torchvision import datasets, transforms

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.conv2_drop = nn.Dropout2d()
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

def train(args, model, device, train_loader, optimizer, epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % args.log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))

def test(args, model, device, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
      pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

def main():
  # Training settings
  parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  parser.add_argument('--batch-size', type=int, default=64, metavar='N',
            help='input batch size for training (default: 64)')
  parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
            help='input batch size for testing (default: 1000)')
  parser.add_argument('--epochs', type=int, default=10, metavar='N',
            help='number of epochs to train (default: 10)')
  parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
            help='learning rate (default: 0.01)')
  parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
            help='SGD momentum (default: 0.5)')
  parser.add_argument('--no-cuda', action='store_true', default=False,
            help='disables CUDA training')
  parser.add_argument('--seed', type=int, default=1, metavar='S',
            help='random seed (default: 1)')
  parser.add_argument('--log-interval', type=int, default=10, metavar='N',
            help='how many batches to wait before logging training status')
  args = parser.parse_args()
  use_cuda = not args.no_cuda and torch.cuda.is_available()

  torch.manual_seed(args.seed)

  device = torch.device("cuda" if use_cuda else "cpu")

  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
  train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
            transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
  test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


  model = Net().to(device)
  optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

  for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(args, model, device, test_loader)


if __name__ == '__main__':
  main()

以上这篇pytorch实现mnist分类的示例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python自动重试HTTP连接装饰器
Apr 28 Python
Python 快速实现CLI 应用程序的脚手架
Dec 05 Python
PyQt5每天必学之关闭窗口
Apr 19 Python
Python计算开方、立方、圆周率,精确到小数点后任意位的方法
Jul 17 Python
Python列表生成式与生成器操作示例
Aug 01 Python
对Python3.x版本print函数左右对齐详解
Dec 22 Python
python3利用Socket实现通信的方法示例
May 06 Python
python多进程读图提取特征存npy
May 21 Python
python创建子类的方法分析
Nov 28 Python
Python时间差中seconds和total_seconds的区别详解
Dec 26 Python
pandas的相关系数与协方差实例
Dec 27 Python
python中使用paramiko模块并实现远程连接服务器执行上传下载功能
Feb 29 Python
pytorch 准备、训练和测试自己的图片数据的方法
Jan 10 #Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
You might like
php将文件夹打包成zip文件的简单实现方法
2016/10/04 PHP
php事件驱动化设计详解
2016/11/10 PHP
php 遍历目录,生成目录下每个文件的md5值并写入到结果文件中
2016/12/12 PHP
PHP/ThinkPHP实现批量打包下载文件的方法示例
2017/07/31 PHP
JavaScript Event学习第十一章 按键的检测
2010/02/10 Javascript
javascript感应鼠标图片透明度显示的方法
2015/02/24 Javascript
jQuery头像裁剪工具jcrop用法实例(附演示与demo源码下载)
2016/01/22 Javascript
微信小程序实战之顶部导航栏(选项卡)(1)
2020/06/19 Javascript
JS模拟实现哈希表及应用详解
2018/05/04 Javascript
自定义vue组件发布到npm的方法
2018/05/09 Javascript
Vue 通过自定义指令回顾v-内置指令(小结)
2018/09/03 Javascript
详解如何使用nvm管理Node.js多版本
2019/05/06 Javascript
vue中keep-alive、activated的探讨和使用详解
2020/07/26 Javascript
[02:53]DOTA2英雄基础教程 山岭巨人小小
2013/12/09 DOTA
python encode和decode的妙用
2009/09/02 Python
python 算法 排序实现快速排序
2012/06/05 Python
Python中利用函数装饰器实现备忘功能
2015/03/30 Python
使用python进行文本预处理和提取特征的实例
2018/06/05 Python
梅尔频率倒谱系数(mfcc)及Python实现
2019/06/18 Python
使用WingPro 7 设置Python路径的方法
2019/07/24 Python
基于django传递数据到后端的例子
2019/08/16 Python
python 多维高斯分布数据生成方式
2019/12/09 Python
详解python算法常用技巧与内置库
2020/10/17 Python
python文件路径操作方法总结
2020/12/21 Python
CSS3实现线性渐变用法示例代码详解
2020/08/07 HTML / CSS
CSS3制作3D立方体loading特效
2020/11/09 HTML / CSS
财务副总经理工作职责
2013/11/25 职场文书
党员学习十八大感想
2014/01/17 职场文书
初中科学教学反思
2014/01/21 职场文书
物理研修随笔感言
2014/02/14 职场文书
学习党的群众路线实践活动思想汇报
2014/09/12 职场文书
2014年党的群众路线教育实践活动整改措施(个人版)
2014/09/25 职场文书
起诉离婚协议书样本
2014/11/25 职场文书
2015商场元旦促销活动策划方案
2014/12/09 职场文书
毛主席纪念堂观后感
2015/06/17 职场文书
考生诚信考试承诺书(2016版)
2016/03/25 职场文书