pytorch实现mnist分类的示例讲解


Posted in Python onJanuary 10, 2020

torchvision包 包含了目前流行的数据集,模型结构和常用的图片转换工具。

torchvision.datasets中包含了以下数据集

MNIST
COCO(用于图像标注和目标检测)(Captioning and Detection)
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10

torchvision.models

torchvision.models模块的 子模块中包含以下模型结构。
AlexNet
VGG
ResNet
SqueezeNet
DenseNet You can construct a model with random weights by calling its constructor:

pytorch torchvision transform

对PIL.Image进行变换

from __future__ import print_function
import argparse #Python 命令行解析工具
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torchvision import datasets, transforms

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.conv2_drop = nn.Dropout2d()
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

def train(args, model, device, train_loader, optimizer, epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % args.log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))

def test(args, model, device, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
      pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

def main():
  # Training settings
  parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  parser.add_argument('--batch-size', type=int, default=64, metavar='N',
            help='input batch size for training (default: 64)')
  parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
            help='input batch size for testing (default: 1000)')
  parser.add_argument('--epochs', type=int, default=10, metavar='N',
            help='number of epochs to train (default: 10)')
  parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
            help='learning rate (default: 0.01)')
  parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
            help='SGD momentum (default: 0.5)')
  parser.add_argument('--no-cuda', action='store_true', default=False,
            help='disables CUDA training')
  parser.add_argument('--seed', type=int, default=1, metavar='S',
            help='random seed (default: 1)')
  parser.add_argument('--log-interval', type=int, default=10, metavar='N',
            help='how many batches to wait before logging training status')
  args = parser.parse_args()
  use_cuda = not args.no_cuda and torch.cuda.is_available()

  torch.manual_seed(args.seed)

  device = torch.device("cuda" if use_cuda else "cpu")

  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
  train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
            transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
  test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


  model = Net().to(device)
  optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

  for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(args, model, device, test_loader)


if __name__ == '__main__':
  main()

以上这篇pytorch实现mnist分类的示例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中关于字符串对象的一些基础知识
Apr 08 Python
Pandas 对Dataframe结构排序的实现方法
Apr 10 Python
Python实现全排列的打印
Aug 18 Python
Python实现使用request模块下载图片demo示例
May 24 Python
Python流行ORM框架sqlalchemy安装与使用教程
Jun 04 Python
django自带调试服务器的使用详解
Aug 29 Python
wxPython绘图模块wxPyPlot实现数据可视化
Nov 19 Python
django连接mysql数据库及建表操作实例详解
Dec 10 Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 Python
Python networkx包的实现
Feb 14 Python
详解python tcp编程
Aug 24 Python
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
Jun 05 Python
pytorch 准备、训练和测试自己的图片数据的方法
Jan 10 #Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
You might like
PHP实现MySQL更新记录的代码
2008/06/07 PHP
一个简单的PHP验证码实现代码
2014/05/10 PHP
php输出xml属性的方法
2015/03/19 PHP
64位windows系统下安装Memcache缓存
2015/12/06 PHP
WordPress中用于获取文章作者与分类信息的方法整理
2015/12/17 PHP
php判断是否连接上网络的方法实例详解
2016/12/14 PHP
ThinkPHP框架实现的MySQL数据库备份功能示例
2018/05/24 PHP
PHP正则表达式函数preg_replace用法实例分析
2020/06/04 PHP
又一个小巧的图片预加载类
2007/05/05 Javascript
cookie丢失问题(认证失效) Authentication (用户验证信息)也会丢失
2009/06/04 Javascript
js 巧妙去除数组中的重复项
2010/01/25 Javascript
JQuery 1.3.2以上版本中出现pareseerror错误的解决方法
2011/01/11 Javascript
jQuery之网页换肤实现代码
2011/04/30 Javascript
用JS判断IE版本的代码 超管用!
2011/08/09 Javascript
基于jQuery的投票系统显示结果插件
2011/08/12 Javascript
背景图跟随鼠标移动的Mootools插件实现代码
2011/12/12 Javascript
jquery实现页面图片等比例放大缩小功能
2014/02/12 Javascript
js判断iframe内的网页是否滚动到底部触发事件
2014/03/18 Javascript
在javascript中随机数 math random如何生成指定范围数值的随机数
2015/10/21 Javascript
九种原生js动画效果
2015/11/11 Javascript
Javascript中函数名.length属性用法分析(对比arguments.length)
2016/09/16 Javascript
Angular.js基础学习之初始化
2017/03/10 Javascript
2种简单的js倒计时方式
2017/10/20 Javascript
vue 本地服务不能被外部IP访问的完美解决方法
2018/10/29 Javascript
JavaScript如何把两个数组对象合并过程解析
2019/10/10 Javascript
在vue-cli中引入lodash.js并使用详解
2019/11/13 Javascript
python正则表达式去除两个特殊字符间的内容方法
2018/12/24 Python
PyCharm搭建Spark开发环境实现第一个pyspark程序
2019/06/13 Python
Tensorflow卷积实现原理+手写python代码实现卷积教程
2020/05/22 Python
HTML5 WebSocket实现点对点聊天的示例代码
2018/01/31 HTML / CSS
印度在线杂货店:bigbasket
2018/08/23 全球购物
纬创Java面试题笔试题
2014/10/02 面试题
27个经典Linux面试题及答案,你知道几个?
2014/03/11 面试题
安全生产月标语
2014/10/07 职场文书
大学生考试作弊检讨书1000字
2014/10/14 职场文书
优秀员工演讲稿
2019/06/21 职场文书