pytorch实现mnist分类的示例讲解


Posted in Python onJanuary 10, 2020

torchvision包 包含了目前流行的数据集,模型结构和常用的图片转换工具。

torchvision.datasets中包含了以下数据集

MNIST
COCO(用于图像标注和目标检测)(Captioning and Detection)
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10

torchvision.models

torchvision.models模块的 子模块中包含以下模型结构。
AlexNet
VGG
ResNet
SqueezeNet
DenseNet You can construct a model with random weights by calling its constructor:

pytorch torchvision transform

对PIL.Image进行变换

from __future__ import print_function
import argparse #Python 命令行解析工具
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torchvision import datasets, transforms

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.conv2_drop = nn.Dropout2d()
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

def train(args, model, device, train_loader, optimizer, epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % args.log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))

def test(args, model, device, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
      pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

def main():
  # Training settings
  parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  parser.add_argument('--batch-size', type=int, default=64, metavar='N',
            help='input batch size for training (default: 64)')
  parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
            help='input batch size for testing (default: 1000)')
  parser.add_argument('--epochs', type=int, default=10, metavar='N',
            help='number of epochs to train (default: 10)')
  parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
            help='learning rate (default: 0.01)')
  parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
            help='SGD momentum (default: 0.5)')
  parser.add_argument('--no-cuda', action='store_true', default=False,
            help='disables CUDA training')
  parser.add_argument('--seed', type=int, default=1, metavar='S',
            help='random seed (default: 1)')
  parser.add_argument('--log-interval', type=int, default=10, metavar='N',
            help='how many batches to wait before logging training status')
  args = parser.parse_args()
  use_cuda = not args.no_cuda and torch.cuda.is_available()

  torch.manual_seed(args.seed)

  device = torch.device("cuda" if use_cuda else "cpu")

  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
  train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
            transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
  test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


  model = Net().to(device)
  optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

  for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(args, model, device, test_loader)


if __name__ == '__main__':
  main()

以上这篇pytorch实现mnist分类的示例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python多线程扫描端口示例
Jan 16 Python
Python Trie树实现字典排序
Mar 28 Python
python常见排序算法基础教程
Apr 13 Python
python学习笔记之列表(list)与元组(tuple)详解
Nov 23 Python
Python获取本机所有网卡ip,掩码和广播地址实例代码
Jan 22 Python
python下解压缩zip文件并删除文件的实例
Apr 24 Python
python 实现求解字符串集的最长公共前缀方法
Jul 20 Python
Flask框架URL管理操作示例【基于@app.route】
Jul 23 Python
python几种常用功能实现代码实例
Dec 25 Python
pandas之分组groupby()的使用整理与总结
Jun 18 Python
Python入门之使用pandas分析excel数据
May 12 Python
Python利用zhdate模块实现农历日期处理
Mar 31 Python
pytorch 准备、训练和测试自己的图片数据的方法
Jan 10 #Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
You might like
如何用C语言编写PHP扩展的详解
2013/06/13 PHP
PHP图片处理之使用imagecopy函数添加图片水印实例
2014/11/19 PHP
php实现使用正则将文本中的网址转换成链接标签
2014/12/03 PHP
PHP+ajax分页实例简析
2015/12/07 PHP
javascript 函数速查表
2010/02/07 Javascript
json2.js的初步学习与了解
2011/10/06 Javascript
9款2014最热门jQuery实用特效推荐
2014/12/07 Javascript
使用jQuery+EasyUI实现CheckBoxTree的级联选中特效
2015/12/06 Javascript
基于javascript实现简单的抽奖系统
2020/04/15 Javascript
JavaScript数组实现数据结构中的队列与堆栈
2016/05/26 Javascript
Bootstrap3 内联单选和多选框
2016/12/29 Javascript
关于Javascript中document.cookie的使用
2017/03/08 Javascript
jQuery is not defined 错误原因与解决方法小结
2017/03/19 Javascript
微信小程序商城项目之侧栏分类效果(1)
2017/04/17 Javascript
详解JavaScript调用栈、尾递归和手动优化
2017/06/03 Javascript
js使用formData实现批量上传
2020/03/27 Javascript
jQuery AJAX与jQuery事件的分析讲解
2019/02/18 jQuery
Python实现栈的方法
2015/05/26 Python
Python实现处理逆波兰表达式示例
2018/07/30 Python
Python标准库shutil用法实例详解
2018/08/13 Python
django foreignkey(外键)的实现
2019/07/29 Python
python+Selenium自动化测试——输入,点击操作
2020/03/06 Python
详解使用scrapy进行模拟登陆三种方式
2021/02/21 Python
携程旅行网:中国领先的在线旅行服务公司
2017/02/17 全球购物
商务英语应届生自我鉴定
2013/12/08 职场文书
大学总结自我鉴定
2014/01/18 职场文书
设计专业毕业生求职信
2014/06/25 职场文书
计算机实训报告总结
2014/11/05 职场文书
肖申克救赎观后感
2015/06/02 职场文书
2015年教务处干事工作总结
2015/07/22 职场文书
运输公司工作总结
2015/08/11 职场文书
大学学习委员竞选稿
2015/11/20 职场文书
少先大队干部竞选稿
2015/11/20 职场文书
Python3.8官网文档之类的基础语法阅读
2021/09/04 Python
手写实现JS中的new
2021/11/07 Javascript
Python&Matlab实现灰狼优化算法的示例代码
2022/03/21 Python