简单易懂Pytorch实战实例VGG深度网络


Posted in Python on August 27, 2019

模型为VGG,数据集为CIFAR-10。对照这份代码走一遍,大概就知道整个PyTorch的运行机制。

来源

定义模型:

'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn
from torch.autograd import Variable


# Layer layout of every supported VGG variant. An integer adds a 3x3
# convolution with that many output channels (followed by BatchNorm and ReLU);
# 'M' adds a 2x2 max-pooling layer that halves the spatial size.
cfg = {
  'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
  'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

# A model must inherit from nn.Module
class VGG(nn.Module):
  '''VGG network with batch normalisation for 3-channel, 32x32 images.

  Args:
    vgg_name: key into ``cfg`` ('VGG11', 'VGG13', 'VGG16' or 'VGG19').
      An unknown name raises KeyError.
    num_classes: number of outputs of the final linear classifier.
      Defaults to 10.
  '''

  # Build the sub-modules (the learnable parameters live in them)
  def __init__(self, vgg_name, num_classes=10):
    super(VGG, self).__init__()
    self.features = self._make_layers(cfg[vgg_name])
    # Every layout ends with 512 channels after five 2x2 poolings, so a
    # 32x32 input is reduced to 512 x 1 x 1 = 512 flattened features.
    self.classifier = nn.Linear(512, num_classes)

  # Forward pass: the computation performed when the model is called
  def forward(self, x):
    '''Return the class scores, shape (batch, num_classes), for a batch x.'''
    out = self.features(x)
    # Flatten everything except the batch dimension for the linear layer.
    out = out.view(out.size(0), -1)
    out = self.classifier(out)
    return out

  def _make_layers(self, cfg):
    '''Turn a layout list (see the module-level ``cfg``) into nn.Sequential.

    The ``cfg`` parameter is one layout list and shadows the module-level
    dict of the same name inside this method.
    '''
    layers = []
    in_channels = 3
    for x in cfg:
      if x == 'M':
        layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
      else:
        layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
              nn.BatchNorm2d(x),
              nn.ReLU(inplace=True)]
        in_channels = x
    # Average pooling with a 1x1 window and stride 1 keeps the spatial size.
    layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
    return nn.Sequential(*layers)

# net = VGG('VGG11')
# x = torch.randn(2,3,32,32)
# print(net(Variable(x)).size())

定义训练过程:

'''Train CIFAR10 with PyTorch.'''
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models import *
from utils import progress_bar
from torch.autograd import Variable

# Command-line options
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument(
  '--lr',
  type=float,
  default=0.1,
  help='learning rate',
)
parser.add_argument(
  '--resume', '-r',
  action='store_true',
  help='resume from checkpoint',
)
args = parser.parse_args()

# best_acc: best test accuracy seen so far
# start_epoch: first epoch to run (0, or the epoch stored in the checkpoint)
best_acc, start_epoch = 0, 0
use_cuda = torch.cuda.is_available()

# Load the datasets; every image is preprocessed on the fly
print('==> Preparing data..')

# Per-channel mean / std applied to both splits
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Training images are additionally augmented (random crop + horizontal flip)
transform_train = transforms.Compose([
  transforms.RandomCrop(32, padding=4),
  transforms.RandomHorizontalFlip(),
  transforms.ToTensor(),
  transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Test images are only converted to tensors and normalised
transform_test = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

trainset = torchvision.datasets.CIFAR10(
  root='./data',
  train=True,
  download=True,
  transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
  trainset,
  batch_size=128,
  shuffle=True,
  num_workers=2,
)

testset = torchvision.datasets.CIFAR10(
  root='./data',
  train=False,
  download=True,
  transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=100,
  shuffle=False,
  num_workers=2,
)

classes = (
  'plane', 'car', 'bird', 'cat', 'deer',
  'dog', 'frog', 'horse', 'ship', 'truck',
)

# Resume training from a checkpoint, or build a new model
if args.resume:
  # Load checkpoint.
  print('==> Resuming from checkpoint..')
  assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
  # Deserialise onto the CPU so that a checkpoint written on a GPU machine
  # also loads where CUDA is unavailable; the model is moved to the GPU
  # further below when use_cuda is set.
  # NOTE: the checkpoint stores the whole pickled model object (see test()),
  # and unpickling can execute arbitrary code - only load trusted files.
  checkpoint = torch.load('./checkpoint/ckpt.t7', map_location=lambda storage, loc: storage)
  net = checkpoint['net']
  best_acc = checkpoint['acc']
  start_epoch = checkpoint['epoch']
else:
  print('==> Building model..')
  net = VGG('VGG16')
  # net = ResNet18()
  # net = PreActResNet18()
  # net = GoogLeNet()
  # net = DenseNet121()
  # net = ResNeXt29_2x64d()
  # net = MobileNet()
  # net = MobileNetV2()
  # net = DPN92()
  # net = ShuffleNetG2()
  # net = SENet18()

# Use the GPU(s) when available
if use_cuda:
  # move param and buffer to GPU
  net.cuda()
  # Replicate the model on every visible GPU. The device list must not be
  # empty, so it has to cover range(device_count()); subtracting one would
  # leave a single-GPU machine with no device at all.
  net = torch.nn.DataParallel(net, device_ids=list(range(torch.cuda.device_count())))
  # speed up slightly
  cudnn.benchmark = True


# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

# Training phase
def train(epoch):
  '''Train ``net`` for one pass over ``trainloader``.

  Prints a progress bar with the running mean loss and accuracy.

  Args:
    epoch: epoch number, used only for the printed header.
  '''
  print('\nEpoch: %d' % epoch)
  # switch to train mode
  net.train()
  train_loss = 0
  correct = 0
  total = 0
  # iterate over mini-batches
  for batch_idx, (inputs, targets) in enumerate(trainloader):
    # move the batch to the GPU
    if use_cuda:
      inputs, targets = inputs.cuda(), targets.cuda()
    # clear the gradients accumulated by the previous step
    optimizer.zero_grad()
    # forward pass (tensors no longer need a Variable wrapper since PyTorch 0.4)
    outputs = net(inputs)
    # loss: the end point of the computation graph
    loss = criterion(outputs, targets)
    # back-propagate to compute the gradients
    loss.backward()
    # update the parameters
    optimizer.step()
    # Accumulate the loss as a Python float. Adding the loss tensor itself
    # would keep every batch's graph alive, and indexing a 0-dim tensor
    # (loss.data[0]) is an error in current PyTorch.
    train_loss += loss.item()
    # running statistics, kept as plain Python numbers
    _, predicted = torch.max(outputs.data, 1)
    total += targets.size(0)
    correct += predicted.eq(targets.data).sum().item()

    progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

# Evaluation phase
def test(epoch):
  '''Evaluate ``net`` on ``testloader`` and checkpoint it on a new best accuracy.

  Prints a progress bar with the running mean loss and accuracy. When the
  accuracy exceeds the global ``best_acc``, the model, accuracy and epoch are
  saved to ./checkpoint/ckpt.t7 and ``best_acc`` is updated.

  Args:
    epoch: epoch number stored in the checkpoint.
  '''
  global best_acc
  # switch to evaluation mode first
  net.eval()
  test_loss = 0
  correct = 0
  total = 0
  # No gradients are needed for evaluation. torch.no_grad() replaces the
  # removed Variable(..., volatile=True) flag and avoids building a graph.
  with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testloader):
      if use_cuda:
        inputs, targets = inputs.cuda(), targets.cuda()
      outputs = net(inputs)
      loss = criterion(outputs, targets)
      # accumulate a Python float rather than the loss tensor
      test_loss += loss.item()
      _, predicted = torch.max(outputs.data, 1)
      total += targets.size(0)
      correct += predicted.eq(targets.data).sum().item()

      progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
        % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

  # Save checkpoint.
  acc = 100.*correct/total
  if acc > best_acc:
    print('Saving..')
    state = {
      # unwrap DataParallel so the checkpoint holds the bare model
      'net': net.module if use_cuda else net,
      'acc': acc,
      'epoch': epoch,
    }
    if not os.path.isdir('checkpoint'):
      os.mkdir('checkpoint')
    torch.save(state, './checkpoint/ckpt.t7')
    best_acc = acc

# Run training and evaluation.
# The __main__ guard matters because the DataLoaders use worker processes:
# on platforms that start workers by spawning, each worker re-imports this
# module and must not start the training loop again.
if __name__ == '__main__':
  for epoch in range(start_epoch, start_epoch+200):
    train(epoch)
    test(epoch)
    # release cached, currently unused GPU memory (only relevant with CUDA)
    if use_cuda:
      torch.cuda.empty_cache()

运行:

新模型:
python main.py --lr=0.01
旧模型继续训练:
python main.py --resume --lr=0.01

一些utility:

'''Some helper functions for PyTorch, including:
  - get_mean_and_std: calculate the mean and std value of dataset.
  - init_params: net parameter initialization.
  - progress_bar: progress bar mimic xlua.progress.
'''
import os
import sys
import time
import math

import torch.nn as nn
import torch.nn.init as init


def get_mean_and_std(dataset, num_workers=2):
  '''Compute the per-channel mean and std value of a 3-channel image dataset.

  Samples are read one at a time. The returned mean is the average of the
  per-sample channel means, and the returned std is the average of the
  per-sample channel stds (not the std pooled over the whole dataset).

  Args:
    dataset: map-style dataset yielding (image, target) pairs, where the
      image is indexed as (channel, height, width) with 3 channels.
    num_workers: number of DataLoader worker processes; 0 loads in the
      calling process.

  Returns:
    (mean, std): two tensors of 3 elements each.
  '''
  # This module only imports torch.nn / torch.nn.init at file level, which
  # does not bind the name `torch`, so import it here.
  import torch
  import torch.utils.data

  dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=num_workers)
  mean = torch.zeros(3)
  std = torch.zeros(3)
  print('==> Computing mean and std..')
  for inputs, targets in dataloader:
    for i in range(3):
      mean[i] += inputs[:,i,:,:].mean()
      std[i] += inputs[:,i,:,:].std()
  # batch_size is 1, so the number of batches equals len(dataset)
  mean.div_(len(dataset))
  std.div_(len(dataset))
  return mean, std

def init_params(net):
  '''Init layer parameters in place.

  - nn.Conv2d: Kaiming-normal weights (mode='fan_out'), zero bias.
  - nn.BatchNorm2d: weight 1, bias 0.
  - nn.Linear: normal weights with std 1e-3, zero bias.

  Layers created without a bias (or without affine parameters) are handled:
  a missing parameter is simply skipped.

  Args:
    net: module whose sub-modules (``net.modules()``) are initialised.
  '''
  for m in net.modules():
    if isinstance(m, nn.Conv2d):
      init.kaiming_normal_(m.weight, mode='fan_out')
      # `if m.bias:` would evaluate the truth value of a multi-element
      # tensor, which raises; test for presence instead.
      if m.bias is not None:
        init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
      if m.weight is not None:
        init.constant_(m.weight, 1)
      if m.bias is not None:
        init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
      init.normal_(m.weight, std=1e-3)
      if m.bias is not None:
        init.constant_(m.bias, 0)


# Width of the terminal in columns, used by progress_bar for padding.
# 'stty size' prints "rows cols". Without an attached terminal (output
# redirected, job run in the background) or where stty is unavailable, it
# prints nothing; fall back to 80 columns instead of failing at import time.
try:
  with os.popen('stty size', 'r') as _stty:
    _, term_width = _stty.read().split()
  term_width = int(term_width)
except ValueError:
  term_width = 80

# Number of characters in the bar drawn by progress_bar.
TOTAL_BAR_LENGTH = 65.
# Timestamps used by progress_bar: time of the previous call and of the
# start of the current bar.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
  '''Draw a one-line text progress bar on stdout.

  Args:
    current: zero-based index of the step just finished; 0 restarts the
      total-time clock.
    total: total number of steps.
    msg: optional text appended after the timing information.
  '''
  global last_time, begin_time
  if current == 0:
    begin_time = time.time() # Reset for new bar.

  done = int(TOTAL_BAR_LENGTH*current/total)
  todo = int(TOTAL_BAR_LENGTH - done) - 1

  out = sys.stdout
  # The bar itself: '=' for finished steps, '>' as the head, '.' for the rest.
  out.write(' [')
  out.write('=' * done)
  out.write('>')
  out.write('.' * todo)
  out.write(']')

  # Time since the previous call and since the start of this bar.
  now = time.time()
  step_time = now - last_time
  last_time = now
  tot_time = now - begin_time

  pieces = [
    ' Step: %s' % format_time(step_time),
    ' | Tot: %s' % format_time(tot_time),
  ]
  if msg:
    pieces.append(' | ' + msg)

  msg = ''.join(pieces)
  out.write(msg)
  # Pad with spaces up to the terminal width.
  out.write(' ' * (term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3))

  # Go back to the center of the bar.
  out.write('\b' * (term_width-int(TOTAL_BAR_LENGTH/2)+2))
  out.write(' %d/%d ' % (current+1, total))

  # Stay on the same line until the last step, then move to the next line.
  out.write('\r' if current < total-1 else '\n')
  out.flush()

def format_time(seconds):
  '''Format a duration given in seconds as a short string.

  The duration is split into days (D), hours (h), minutes (m), seconds (s)
  and milliseconds (ms). Only the first two non-zero units are shown, largest
  first. When no unit is positive the result is '0ms'.
  '''
  days = int(seconds / 3600/24)
  rest = seconds - days*3600*24

  counts = [days]
  # Peel off whole hours, minutes and seconds in turn.
  for unit_seconds in (3600, 60, 1):
    whole = int(rest / unit_seconds)
    rest = rest - whole*unit_seconds
    counts.append(whole)
  # Whatever is left is a fraction of a second.
  counts.append(int(rest*1000))

  suffixes = ('D', 'h', 'm', 's', 'ms')
  shown = [str(count) + suffix for count, suffix in zip(counts, suffixes) if count > 0]
  return ''.join(shown[:2]) or '0ms'

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python文件和目录操作方法大全(含实例)
Mar 12 Python
Python中使用SAX解析xml实例
Nov 21 Python
python检测是文件还是目录的方法
Jul 03 Python
Python连接SQLServer2000的方法详解
Apr 19 Python
在CentOS6上安装Python2.7的解决方法
Jan 09 Python
python判断文件是否存在,不存在就创建一个的实例
Feb 18 Python
python中metaclass原理与用法详解
Jun 25 Python
在python中修改.properties文件的操作
Apr 08 Python
pycharm导入源码的具体步骤
Aug 04 Python
python批量修改文件名的示例
Sep 27 Python
Python运算符+与+=的方法实例
Feb 18 Python
Python使用PyYAML库读写yaml文件的方法
Apr 06 Python
selenium+PhantomJS爬取豆瓣读书
Aug 26 #Python
python多任务之协程的使用详解
Aug 26 #Python
python数组循环处理方法
Aug 26 #Python
python中利用numpy.array()实现俩个数值列表的对应相加方法
Aug 26 #Python
关于Python中的向量相加和numpy中的向量相加效率对比
Aug 26 #Python
python中sort和sorted排序的实例方法
Aug 26 #Python
对Python 中矩阵或者数组相减的法则详解
Aug 26 #Python
You might like
SONY SRF-40W电路分析
2021/03/02 无线电
PHP Socket 编程
2010/04/09 PHP
PHP IE中下载附件问题解决方法
2014/01/07 PHP
php通过文件流方式复制文件的方法
2015/03/13 PHP
PHP5.4起内置web服务器使用方法
2016/08/09 PHP
[Web]防止用户复制页面内容和另存页面的方法
2009/02/06 Javascript
使用自定义setTimeout和setInterval使之可以传递参数和对象参数
2009/04/24 Javascript
自己的js工具 Event封装
2009/08/21 Javascript
不使用中间变量,交换int型的 a, b两个变量的值。
2010/10/29 Javascript
Javascript实现单张图片浏览
2014/12/18 Javascript
JavaScript给url网址进行encode编码的方法
2015/03/18 Javascript
【经典源码收藏】jQuery实用代码片段(筛选,搜索,样式,清除默认值,多选等)
2016/06/07 Javascript
js canvas实现擦除动画
2016/07/16 Javascript
js实现图片切换(动画版)
2016/12/25 Javascript
微信小程序 列表的上拉加载和下拉刷新的实现
2017/04/01 Javascript
vue.js 实现输入框动态添加功能
2018/06/25 Javascript
jquery的$().each和$.each的区别
2019/01/18 jQuery
js中apply和call的理解与使用方法
2019/11/27 Javascript
解决echarts 一条柱状图显示两个值,类似进度条的问题
2020/07/20 Javascript
VUE实现吸底按钮
2021/03/04 Vue.js
浅谈Python 的枚举 Enum
2017/06/12 Python
flask框架自定义过滤器示例【markdown文件读取和展示功能】
2019/11/08 Python
Python要求O(n)复杂度求无序列表中第K的大元素实例
2020/04/02 Python
Python docutils文档编译过程方法解析
2020/06/23 Python
python 两种方法删除空文件夹
2020/09/29 Python
python爬虫筛选工作实例讲解
2020/11/23 Python
日本7net购物网:书籍、漫画、杂志、DVD、游戏邮购
2017/02/17 全球购物
台湾母婴用品购物网站:Infant婴之房
2018/06/15 全球购物
会计毕业自我鉴定
2014/02/05 职场文书
某某同志考察材料
2014/05/28 职场文书
迎七一演讲稿
2014/09/12 职场文书
2015年七年级班主任工作总结
2015/05/21 职场文书
庆七一晚会主持词
2015/06/30 职场文书
基于Redis实现分布式锁的方法(lua脚本版)
2021/05/12 Redis
在Python中如何使用yield
2021/06/07 Python
Python实现Hash算法
2022/03/18 Python