pytorch实现mnist分类的示例讲解


Posted in Python onJanuary 10, 2020

torchvision包 包含了目前流行的数据集,模型结构和常用的图片转换工具。

torchvision.datasets中包含了以下数据集

MNIST
COCO(用于图像标注和目标检测)(Captioning and Detection)
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10

torchvision.models

torchvision.models模块的 子模块中包含以下模型结构。
AlexNet
VGG
ResNet
SqueezeNet
DenseNet You can construct a model with random weights by calling its constructor:

pytorch torchvision transform

对PIL.Image进行变换

from __future__ import print_function
import argparse #Python 命令行解析工具
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torchvision import datasets, transforms

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.conv2_drop = nn.Dropout2d()
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

def train(args, model, device, train_loader, optimizer, epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % args.log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))

def test(args, model, device, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
      pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

def main():
  # Training settings
  parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  parser.add_argument('--batch-size', type=int, default=64, metavar='N',
            help='input batch size for training (default: 64)')
  parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
            help='input batch size for testing (default: 1000)')
  parser.add_argument('--epochs', type=int, default=10, metavar='N',
            help='number of epochs to train (default: 10)')
  parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
            help='learning rate (default: 0.01)')
  parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
            help='SGD momentum (default: 0.5)')
  parser.add_argument('--no-cuda', action='store_true', default=False,
            help='disables CUDA training')
  parser.add_argument('--seed', type=int, default=1, metavar='S',
            help='random seed (default: 1)')
  parser.add_argument('--log-interval', type=int, default=10, metavar='N',
            help='how many batches to wait before logging training status')
  args = parser.parse_args()
  use_cuda = not args.no_cuda and torch.cuda.is_available()

  torch.manual_seed(args.seed)

  device = torch.device("cuda" if use_cuda else "cpu")

  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
  train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
            transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
  test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


  model = Net().to(device)
  optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

  for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(args, model, device, test_loader)


if __name__ == '__main__':
  main()

以上这篇pytorch实现mnist分类的示例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python socket 超时设置 errno 10054
Jul 01 Python
用Python从零实现贝叶斯分类器的机器学习的教程
Mar 31 Python
使用Django Form解决表单数据无法动态刷新的两种方法
Jul 14 Python
python针对不定分隔符切割提取字符串的方法
Oct 26 Python
Python爬虫文件下载图文教程
Dec 23 Python
Python OpenCV之图片缩放的实现(cv2.resize)
Jun 28 Python
python多线程使用方法实例详解
Dec 30 Python
浅谈django 重载str 方法
May 19 Python
Python 跨.py文件调用自定义函数说明
Jun 01 Python
解决redis与Python交互取出来的是bytes类型的问题
Jul 16 Python
pycharm 多行批量缩进和反向缩进快捷键介绍
Jan 15 Python
OpenCV-Python实现油画效果的实例
Jun 08 Python
pytorch 准备、训练和测试自己的图片数据的方法
Jan 10 #Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
You might like
PHP中file_get_contents高?用法实例
2014/09/24 PHP
2014最热门的24个php类库汇总
2014/12/18 PHP
php计算两个整数的最大公约数常用算法小结
2015/03/05 PHP
php跨服务器访问方法小结
2015/05/12 PHP
详解PHP对象的串行化与反串行化
2016/01/24 PHP
php设计模式之委托模式
2016/02/13 PHP
PHP内存溢出优化代码详解
2021/02/26 PHP
JS实现下拉框的动态添加(附效果)
2013/04/03 Javascript
jquery prop的使用介绍及与attr的区别
2013/12/19 Javascript
jquery常用方法及使用示例汇总
2014/11/08 Javascript
基于BootStrap Metronic开发框架经验小结【六】对话框及提示框的处理和优化
2016/05/12 Javascript
JS for...in 遍历语句用法实例分析
2016/08/24 Javascript
利用jquery实现瀑布流3种案例
2016/09/18 Javascript
详解打造 Vue.js 可复用组件
2017/03/24 Javascript
详解node HTTP请求客户端 - Request
2017/05/05 Javascript
js字符串与Unicode编码互相转换
2017/05/17 Javascript
JS实现发送短信验证后按钮倒计时功能(防止刷新倒计时失效)
2017/07/07 Javascript
使用Vue-scroller页面input框不能触发滑动的问题及解决方法
2020/08/08 Javascript
[14:20]刀塔大凶女神互压各路奇葩屌丝
2014/05/16 DOTA
通过代码实例展示Python中列表生成式的用法
2015/03/31 Python
Python实现简单遗传算法(SGA)
2018/01/29 Python
python实现拓扑排序的基本教程
2018/03/11 Python
python通过ffmgep从视频中抽帧的方法
2018/12/05 Python
如何利用python给图片添加半透明水印
2019/09/06 Python
Django之使用内置函数和celery发邮件的方法示例
2019/09/16 Python
django框架创建应用操作示例
2019/09/26 Python
使用纯 CSS 创作一个脉动 loader效果的源码
2018/09/28 HTML / CSS
为奢侈时尚带来了慈善元素:Olivela
2018/09/29 全球购物
在C语言中"指针和数组等价"到底是什么意思?
2014/03/24 面试题
综合测评自我鉴定
2013/10/08 职场文书
生产文员岗位职责
2014/04/05 职场文书
纪律教育月活动总结
2014/08/26 职场文书
2014年机关党建工作总结
2014/11/11 职场文书
2015年幼儿园元旦游艺活动策划书
2014/12/09 职场文书
解决Django transaction进行事务管理踩过的坑
2021/04/24 Python
Java各种比较对象的方式的对比总结
2021/06/20 Java/Android