pytorch实现mnist分类的示例讲解


Posted in Python onJanuary 10, 2020

torchvision包 包含了目前流行的数据集,模型结构和常用的图片转换工具。

torchvision.datasets中包含了以下数据集

MNIST
COCO(用于图像标注和目标检测)(Captioning and Detection)
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10

torchvision.models

torchvision.models模块的 子模块中包含以下模型结构。
AlexNet
VGG
ResNet
SqueezeNet
DenseNet You can construct a model with random weights by calling its constructor:

pytorch torchvision transform

对PIL.Image进行变换

from __future__ import print_function
import argparse #Python 命令行解析工具
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torchvision import datasets, transforms

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.conv2_drop = nn.Dropout2d()
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

def train(args, model, device, train_loader, optimizer, epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % args.log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))

def test(args, model, device, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
      pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

def main():
  # Training settings
  parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  parser.add_argument('--batch-size', type=int, default=64, metavar='N',
            help='input batch size for training (default: 64)')
  parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
            help='input batch size for testing (default: 1000)')
  parser.add_argument('--epochs', type=int, default=10, metavar='N',
            help='number of epochs to train (default: 10)')
  parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
            help='learning rate (default: 0.01)')
  parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
            help='SGD momentum (default: 0.5)')
  parser.add_argument('--no-cuda', action='store_true', default=False,
            help='disables CUDA training')
  parser.add_argument('--seed', type=int, default=1, metavar='S',
            help='random seed (default: 1)')
  parser.add_argument('--log-interval', type=int, default=10, metavar='N',
            help='how many batches to wait before logging training status')
  args = parser.parse_args()
  use_cuda = not args.no_cuda and torch.cuda.is_available()

  torch.manual_seed(args.seed)

  device = torch.device("cuda" if use_cuda else "cpu")

  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
  train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
            transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
  test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


  model = Net().to(device)
  optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

  for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(args, model, device, test_loader)


if __name__ == '__main__':
  main()

以上这篇pytorch实现mnist分类的示例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
wxpython 最小化到托盘与欢迎图片的实现方法
Jun 09 Python
Python赋值语句后逗号的作用分析
Jun 08 Python
Python的Flask框架中配置多个子域名的方法讲解
Jun 07 Python
python实现读取excel写入mysql的小工具详解
Nov 20 Python
python中将\\uxxxx转换为Unicode字符串的方法
Sep 06 Python
python 判断矩阵中每行非零个数的方法
Jan 26 Python
解决Numpy中sum函数求和结果维度的问题
Dec 06 Python
Python识别html主要文本框过程解析
Feb 18 Python
pandas中的ExcelWriter和ExcelFile的实现方法
Apr 24 Python
python 元组的使用方法
Jun 09 Python
Tensorflow之MNIST CNN实现并保存、加载模型
Jun 17 Python
python使用建议技巧分享(三)
Aug 18 Python
pytorch 准备、训练和测试自己的图片数据的方法
Jan 10 #Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
You might like
PHP 代码规范小结
2012/03/08 PHP
PHP5.3与5.5废弃与过期函数整理汇总
2014/07/10 PHP
laravel 5.5 关闭token的3种实现方式
2019/10/24 PHP
浅说js变量
2011/05/25 Javascript
浅谈JavaScript中null和undefined
2015/07/09 Javascript
怎么通过onclick事件获取js函数返回值(代码少)
2015/07/28 Javascript
angular2使用简单介绍
2016/03/01 Javascript
Javascript将数字转化成为货币格式字符串
2016/06/22 Javascript
jQuery实现的无限级下拉菜单功能示例
2016/09/12 Javascript
原生JS实现网络彩票投注效果
2016/09/25 Javascript
jQuery设置图片等比例缩小的方法
2017/04/29 jQuery
微信小程序获取微信运动步数的实例代码
2017/07/20 Javascript
详解Nodejs 通过 fs.createWriteStream 保存文件
2017/10/10 NodeJs
Vue中computed与methods的区别详解
2018/03/24 Javascript
node.js利用socket.io实现多人在线匹配联机五子棋
2018/05/31 Javascript
js中的闭包实例展示
2018/11/01 Javascript
vue下canvas裁剪图片实例讲解
2020/04/16 Javascript
Vue中父子组件的值传递与方法传递
2020/09/28 Javascript
Python中的生成器和yield详细介绍
2015/01/09 Python
利用python写个下载teahour音频的小脚本
2017/05/08 Python
Python实现读取json文件到excel表
2017/11/18 Python
python爬虫 基于requests模块的get请求实现详解
2019/08/20 Python
Python实现桌面翻译工具【新手必学】
2020/02/12 Python
为什么说python更适合树莓派编程
2020/07/20 Python
前端实现弹幕效果的方法总结(包含css3和canvas的实现方式)
2018/07/12 HTML / CSS
瑞典首都斯德哥尔摩的多元奢侈时尚品牌:Acne Studios
2017/07/09 全球购物
瑞典廉价机票预订网站:Seat24
2018/06/19 全球购物
Fox Racing英国官网:越野摩托车和山地自行车服装
2020/02/26 全球购物
介绍一下sql server的安全性
2014/08/10 面试题
公司股东合作协议书
2014/09/14 职场文书
对外汉语专业大学生职业生涯规划书
2014/10/11 职场文书
老公出轨后的保证书
2015/05/08 职场文书
新闻稿件写作范文
2015/07/18 职场文书
vue实现锚点定位功能
2021/06/29 Vue.js
【海涛DOTA】D-cup邀请赛NV.cn vs DT.Love
2022/04/01 DOTA
Go 内联优化让程序员爱不释手
2022/06/21 Golang