pytorch实现mnist分类的示例讲解


Posted in Python onJanuary 10, 2020

torchvision包 包含了目前流行的数据集,模型结构和常用的图片转换工具。

torchvision.datasets中包含了以下数据集

MNIST
COCO(用于图像标注和目标检测)(Captioning and Detection)
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10

torchvision.models

torchvision.models模块的 子模块中包含以下模型结构。
AlexNet
VGG
ResNet
SqueezeNet
DenseNet You can construct a model with random weights by calling its constructor:

pytorch torchvision transform

对PIL.Image进行变换

from __future__ import print_function
import argparse #Python 命令行解析工具
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torchvision import datasets, transforms

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.conv2_drop = nn.Dropout2d()
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

def train(args, model, device, train_loader, optimizer, epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % args.log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))

def test(args, model, device, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
      pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

def main():
  # Training settings
  parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  parser.add_argument('--batch-size', type=int, default=64, metavar='N',
            help='input batch size for training (default: 64)')
  parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
            help='input batch size for testing (default: 1000)')
  parser.add_argument('--epochs', type=int, default=10, metavar='N',
            help='number of epochs to train (default: 10)')
  parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
            help='learning rate (default: 0.01)')
  parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
            help='SGD momentum (default: 0.5)')
  parser.add_argument('--no-cuda', action='store_true', default=False,
            help='disables CUDA training')
  parser.add_argument('--seed', type=int, default=1, metavar='S',
            help='random seed (default: 1)')
  parser.add_argument('--log-interval', type=int, default=10, metavar='N',
            help='how many batches to wait before logging training status')
  args = parser.parse_args()
  use_cuda = not args.no_cuda and torch.cuda.is_available()

  torch.manual_seed(args.seed)

  device = torch.device("cuda" if use_cuda else "cpu")

  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
  train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
            transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
  test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


  model = Net().to(device)
  optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

  for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(args, model, device, test_loader)


if __name__ == '__main__':
  main()

以上这篇pytorch实现mnist分类的示例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
利用Python命令行传递实例化对象的方法
Nov 02 Python
利用Python如何制作好玩的GIF动图详解
Jul 11 Python
Python 2.7中文显示与处理方法
Jul 16 Python
利用python循环创建多个文件的方法
Oct 25 Python
Python打开文件,将list、numpy数组内容写入txt文件中的方法
Oct 26 Python
Linux 修改Python命令的方法示例
Dec 03 Python
python提取具有某种特定字符串的行数据方法
Dec 11 Python
Python代码块及缓存机制原理详解
Dec 13 Python
pytorch 实现在预训练模型的 input上增减通道
Jan 06 Python
python zip,lambda,map函数代码实例
Apr 04 Python
一文详述 Python 中的 property 语法
Sep 01 Python
基于Python实现一个春节倒计时脚本
Jan 22 Python
pytorch 准备、训练和测试自己的图片数据的方法
Jan 10 #Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
You might like
PHP函数之error_reporting(E_ALL ^ E_NOTICE)详细说明
2011/07/01 PHP
如何用php获取文件名后缀
2013/06/09 PHP
PHP6 中可能会出现的新特性预览
2014/04/04 PHP
php与flash as3 socket通信传送文件实现代码
2014/08/16 PHP
JavaScript 数组循环引起的思考
2010/01/01 Javascript
浅谈tudou土豆网首页图片延迟加载的效果
2010/06/23 Javascript
事件绑定之小测试  onclick && addEventListener
2011/07/31 Javascript
js 动态为textbox添加下拉框数据源的方法
2014/04/24 Javascript
jQuery常用数据处理方法小结
2015/02/20 Javascript
jQuery弹出层插件Lightbox_me使用指南
2015/04/21 Javascript
javascript实现状态栏中文字动态显示的方法
2015/10/20 Javascript
Angular2 (RC5) 路由与导航详解
2016/09/21 Javascript
jquery 仿锚点跳转到页面指定位置的实例
2017/02/14 Javascript
Vue2 Vue-cli中使用Typescript的配置详解
2017/07/24 Javascript
浅谈Vue.js 1.x 和 2.x 实例的生命周期
2017/07/25 Javascript
Angular2 组件间通过@Input @Output通讯示例
2017/08/24 Javascript
Angular实现下拉框模糊查询功能示例
2018/01/03 Javascript
深入了解响应式React Native Echarts组件
2019/05/29 Javascript
javascript中innerHTML 获取或替换html内容的实现代码
2020/03/17 Javascript
python利用dir函数查看类中所有成员函数示例代码
2017/09/08 Python
python实现从pdf文件中提取文本,并自动翻译的方法
2018/11/28 Python
在Python中字典根据多项规则排序的方法
2019/01/21 Python
详解python多线程之间的同步(一)
2019/04/03 Python
tensorflow 保存模型和取出中间权重例子
2020/01/24 Python
Python实现捕获异常发生的文件和具体行数
2020/04/25 Python
Python3+PyCharm+Django+Django REST framework配置与简单开发教程
2021/02/16 Python
荷兰网上买鞋:MooieSchoenen.nl
2017/09/12 全球购物
电话客服工作职责
2014/07/27 职场文书
高一军训的心得体会
2014/09/01 职场文书
个人授权委托书模板
2014/09/14 职场文书
作风整顿个人剖析材料
2014/10/06 职场文书
标会主持词应该怎么写?
2019/08/15 职场文书
SQLServer中exists和except用法介绍
2021/12/04 SQL Server
python获取字符串中的email
2022/03/31 Python
Win11软件图标固定到任务栏
2022/04/19 数码科技
GPU服务器的多用户配置方法
2022/07/07 Servers