pytorch实现mnist分类的示例讲解


Posted in Python onJanuary 10, 2020

torchvision包 包含了目前流行的数据集,模型结构和常用的图片转换工具。

torchvision.datasets中包含了以下数据集

MNIST
COCO(用于图像标注和目标检测)(Captioning and Detection)
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10

torchvision.models

torchvision.models模块的 子模块中包含以下模型结构。
AlexNet
VGG
ResNet
SqueezeNet
DenseNet You can construct a model with random weights by calling its constructor:

pytorch torchvision transform

对PIL.Image进行变换

from __future__ import print_function
import argparse #Python 命令行解析工具
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torchvision import datasets, transforms

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.conv2_drop = nn.Dropout2d()
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

def train(args, model, device, train_loader, optimizer, epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % args.log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))

def test(args, model, device, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
      pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

def main():
  # Training settings
  parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  parser.add_argument('--batch-size', type=int, default=64, metavar='N',
            help='input batch size for training (default: 64)')
  parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
            help='input batch size for testing (default: 1000)')
  parser.add_argument('--epochs', type=int, default=10, metavar='N',
            help='number of epochs to train (default: 10)')
  parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
            help='learning rate (default: 0.01)')
  parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
            help='SGD momentum (default: 0.5)')
  parser.add_argument('--no-cuda', action='store_true', default=False,
            help='disables CUDA training')
  parser.add_argument('--seed', type=int, default=1, metavar='S',
            help='random seed (default: 1)')
  parser.add_argument('--log-interval', type=int, default=10, metavar='N',
            help='how many batches to wait before logging training status')
  args = parser.parse_args()
  use_cuda = not args.no_cuda and torch.cuda.is_available()

  torch.manual_seed(args.seed)

  device = torch.device("cuda" if use_cuda else "cpu")

  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
  train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
            transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
  test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


  model = Net().to(device)
  optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

  for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(args, model, device, test_loader)


if __name__ == '__main__':
  main()

以上这篇pytorch实现mnist分类的示例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python从MP3文件获取id3的方法
Jun 15 Python
python类:class创建、数据方法属性及访问控制详解
Jul 25 Python
Python中实现switch功能实例解析
Jan 11 Python
Python使用gRPC传输协议教程
Oct 16 Python
破解安装Pycharm的方法
Oct 19 Python
在PyCharm中实现关闭一个死循环程序的方法
Nov 29 Python
在python中利用opencv简单做图片比对的方法
Jan 24 Python
python查询文件夹下excel的sheet名代码实例
Apr 02 Python
python用match()函数爬数据方法详解
Jul 23 Python
基于Python实现签到脚本过程解析
Oct 25 Python
python 8种必备的gui库
Aug 27 Python
PyTorch 如何设置随机数种子使结果可复现
May 12 Python
pytorch 准备、训练和测试自己的图片数据的方法
Jan 10 #Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
You might like
php Smarty 字符比较代码
2011/02/27 PHP
Apache中php.ini的设置方法
2013/02/28 PHP
PHP连接操作access数据库实例
2015/03/30 PHP
PHP中key和current,next的联合运用实例分析
2016/03/29 PHP
PHP实现根据数组的值进行分组的方法
2017/04/20 PHP
PHP实现的折半查询算法示例
2017/10/09 PHP
php使用curl模拟浏览器表单上传文件或者图片的方法
2018/11/10 PHP
php用户名的密码加密更安全的方法
2019/06/21 PHP
jquery.autocomplete修改实现键盘上下键自动填充示例
2013/11/19 Javascript
教你如何在 Javascript 文件里使用 .Net MVC Razor 语法
2014/07/23 Javascript
Node.js 学习笔记之简介、安装及配置
2015/03/03 Javascript
JavaScript操作class和style样式代码详解
2016/02/13 Javascript
jstree的简单实例
2016/12/01 Javascript
两种简单的跨域方法(jsonp、php)
2017/01/02 Javascript
vue实现的上传图片到数据库并显示到页面功能示例
2018/03/17 Javascript
Vue多系统切换实现方案
2018/06/05 Javascript
JS 音频可视化插件Wavesurfer.js的使用教程
2018/10/31 Javascript
Vue实现一个无限加载列表功能
2018/11/13 Javascript
vue中将html字符串转换成html后遇到的问题小结
2018/12/10 Javascript
Webpack4+Babel7+ES6兼容IE8的实现
2019/04/10 Javascript
javascript this指向相关问题及改变方法
2020/11/19 Javascript
[05:41]2014DOTA2西雅图国际邀请赛 小组赛7月10日TOPPLAY
2014/07/10 DOTA
[47:50]Secret vs VP 2018国际邀请赛小组赛BO2 第二场 8.17
2018/08/20 DOTA
Python抓取框架 Scrapy的架构
2016/08/12 Python
详解python分布式进程
2018/10/08 Python
python实现文件的备份流程详解
2019/06/18 Python
django 自定义过滤器(filter)处理较为复杂的变量方法
2019/08/12 Python
scrapy利用selenium爬取豆瓣阅读的全步骤
2020/09/20 Python
CSS3属性使网站设计增强同时不消弱可用性
2009/08/29 HTML / CSS
瑞典廉价机票预订网站:Seat24
2018/06/19 全球购物
一份软件工程师的面试试题
2016/02/01 面试题
活动总结怎么写啊
2014/05/07 职场文书
篮球比赛拉拉队口号
2014/06/10 职场文书
2014年材料员工作总结
2014/11/19 职场文书
2015年乡镇党务公开工作总结
2015/05/19 职场文书
python实现学生信息管理系统(面向对象)
2022/06/05 Python