pytorch实现mnist分类的示例讲解


Posted in Python onJanuary 10, 2020

torchvision包 包含了目前流行的数据集,模型结构和常用的图片转换工具。

torchvision.datasets中包含了以下数据集

MNIST
COCO(用于图像标注和目标检测)(Captioning and Detection)
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10

torchvision.models

torchvision.models模块的 子模块中包含以下模型结构。
AlexNet
VGG
ResNet
SqueezeNet
DenseNet You can construct a model with random weights by calling its constructor:

pytorch torchvision transform

对PIL.Image进行变换

from __future__ import print_function
import argparse #Python 命令行解析工具
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torchvision import datasets, transforms

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.conv2_drop = nn.Dropout2d()
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

def train(args, model, device, train_loader, optimizer, epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % args.log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))

def test(args, model, device, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
      pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

def main():
  # Training settings
  parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  parser.add_argument('--batch-size', type=int, default=64, metavar='N',
            help='input batch size for training (default: 64)')
  parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
            help='input batch size for testing (default: 1000)')
  parser.add_argument('--epochs', type=int, default=10, metavar='N',
            help='number of epochs to train (default: 10)')
  parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
            help='learning rate (default: 0.01)')
  parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
            help='SGD momentum (default: 0.5)')
  parser.add_argument('--no-cuda', action='store_true', default=False,
            help='disables CUDA training')
  parser.add_argument('--seed', type=int, default=1, metavar='S',
            help='random seed (default: 1)')
  parser.add_argument('--log-interval', type=int, default=10, metavar='N',
            help='how many batches to wait before logging training status')
  args = parser.parse_args()
  use_cuda = not args.no_cuda and torch.cuda.is_available()

  torch.manual_seed(args.seed)

  device = torch.device("cuda" if use_cuda else "cpu")

  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
  train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
            transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
  test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


  model = Net().to(device)
  optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

  for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(args, model, device, test_loader)


if __name__ == '__main__':
  main()

以上这篇pytorch实现mnist分类的示例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Nginx+uWsgi实现Python的Django框架站点动静分离
Mar 21 Python
Python中的os.path路径模块中的操作方法总结
Jul 07 Python
mysql 之通过配置文件链接数据库
Aug 12 Python
对Python 3.2 迭代器的next函数实例讲解
Oct 18 Python
Python开发网站目录扫描器的实现
Feb 21 Python
python使用wxpy轻松实现微信防撤回的方法
Feb 21 Python
Python基础教程之if判断,while循环,循环嵌套
Apr 25 Python
python opencv 读取图片 返回图片某像素点的b,g,r值的实现方法
Jul 03 Python
python实现beta分布概率密度函数的方法
Jul 08 Python
利用python3 的pygame模块实现塔防游戏
Dec 30 Python
Pytorch提取模型特征向量保存至csv的例子
Jan 03 Python
30行Python代码实现高分辨率图像导航的方法
May 22 Python
pytorch 准备、训练和测试自己的图片数据的方法
Jan 10 #Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
You might like
短波的认识
2021/03/01 无线电
一步一步学习PHP(6) 面向对象
2010/02/16 PHP
php判断是否为json格式的方法
2014/03/04 PHP
php查询whois信息的方法
2015/06/08 PHP
2017年最新PHP经典面试题目汇总(上篇)
2017/03/17 PHP
实例分析PHP将字符串转换成数字的方法
2019/01/27 PHP
jQuery获取Select选择的Text和Value(详细汇总)
2013/01/25 Javascript
jQuery 获取URL的GET参数值的小例子
2013/04/18 Javascript
JS截取字符串常用方法详细整理
2013/10/28 Javascript
javascript中创建对象的几种方法总结
2013/11/01 Javascript
js解析json读取List中的实体对象示例
2014/03/11 Javascript
处理文本部分内容的TextRange对象应用实例
2014/07/29 Javascript
jQuery层次选择器用法示例
2016/09/09 Javascript
ES6入门教程之Iterator与for...of循环详解
2017/05/17 Javascript
ES6中新增的Object.assign()方法详解
2017/09/22 Javascript
基于滚动条位置判断的简单实例
2017/12/14 Javascript
vue2.0 根据状态值进行样式的改变展示方法
2018/03/13 Javascript
学习Vue组件实例
2018/04/28 Javascript
vue 实现在函数中触发路由跳转的示例
2018/09/01 Javascript
vue解决一个方法同时发送多个请求的问题
2018/09/25 Javascript
微信小程序获取用户openid的实现
2018/12/24 Javascript
仅用50行Python代码实现一个简单的代理服务器
2015/04/08 Python
python 移除字符串尾部的数字方法
2018/07/17 Python
Django基础知识 web框架的本质详解
2019/07/18 Python
Python 使用 docopt 解析json参数文件过程讲解
2019/08/13 Python
Python 实现文件读写、坐标寻址、查找替换功能
2019/09/11 Python
django商品分类及商品数据建模实例详解
2020/01/03 Python
python批量替换文件名中的共同字符实例
2020/03/05 Python
使用已经得到的keras模型识别自己手写的数字方式
2020/06/29 Python
详解纯CSS3制作的20种loading动效
2017/07/05 HTML / CSS
吉尔德利巧克力公司:Ghirardelli Chocolate Company
2019/03/27 全球购物
创业计划书的写作技巧及要点
2014/01/31 职场文书
学生操行评语大全
2014/04/24 职场文书
应届生找工作求职信
2014/06/24 职场文书
审计局班子四风对照检查材料思想汇报
2014/10/07 职场文书
Python 匹配文本并在其上一行追加文本
2022/05/11 Python