简单易懂Pytorch实战实例VGG深度网络


Posted in Python onAugust 27, 2019

模型VGG,数据集cifar。对照这份代码走一遍,大概就知道整个pytorch的运行机制。

来源

定义模型:

'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn
from torch.autograd import Variable


cfg = {
  'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
  'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

# 模型需继承nn.Module
class VGG(nn.Module):
# 初始化参数:
  def __init__(self, vgg_name):
    super(VGG, self).__init__()
    self.features = self._make_layers(cfg[vgg_name])
    self.classifier = nn.Linear(512, 10)

# 模型计算时的前向过程,也就是按照这个过程进行计算
  def forward(self, x):
    out = self.features(x)
    out = out.view(out.size(0), -1)
    out = self.classifier(out)
    return out

  def _make_layers(self, cfg):
    layers = []
    in_channels = 3
    for x in cfg:
      if x == 'M':
        layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
      else:
        layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
              nn.BatchNorm2d(x),
              nn.ReLU(inplace=True)]
        in_channels = x
    layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
    return nn.Sequential(*layers)

# net = VGG('VGG11')
# x = torch.randn(2,3,32,32)
# print(net(Variable(x)).size())

定义训练过程:

'''Train CIFAR10 with PyTorch.'''
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models import *
from utils import progress_bar
from torch.autograd import Variable

# 获取参数
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()

use_cuda = torch.cuda.is_available()
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch

# 获取数据集,并先进行预处理
print('==> Preparing data..')
# 图像预处理和增强
transform_train = transforms.Compose([
  transforms.RandomCrop(32, padding=4),
  transforms.RandomHorizontalFlip(),
  transforms.ToTensor(),
  transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 继续训练模型或新建一个模型
if args.resume:
  # Load checkpoint.
  print('==> Resuming from checkpoint..')
  assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
  checkpoint = torch.load('./checkpoint/ckpt.t7')
  net = checkpoint['net']
  best_acc = checkpoint['acc']
  start_epoch = checkpoint['epoch']
else:
  print('==> Building model..')
  net = VGG('VGG16')
  # net = ResNet18()
  # net = PreActResNet18()
  # net = GoogLeNet()
  # net = DenseNet121()
  # net = ResNeXt29_2x64d()
  # net = MobileNet()
  # net = MobileNetV2()
  # net = DPN92()
  # net = ShuffleNetG2()
  # net = SENet18()

# 如果GPU可用,使用GPU
if use_cuda:
  # move param and buffer to GPU
  net.cuda()
  # parallel use GPU
  net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()-1))
  # speed up slightly
  cudnn.benchmark = True


# 定义度量和优化
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

# 训练阶段
def train(epoch):
  print('\nEpoch: %d' % epoch)
  # switch to train mode
  net.train()
  train_loss = 0
  correct = 0
  total = 0
  # batch 数据
  for batch_idx, (inputs, targets) in enumerate(trainloader):
    # 将数据移到GPU上
    if use_cuda:
      inputs, targets = inputs.cuda(), targets.cuda()
    # 先将optimizer梯度先置为0
    optimizer.zero_grad()
    # Variable表示该变量属于计算图的一部分,此处是图计算的开始处。图的leaf variable
    inputs, targets = Variable(inputs), Variable(targets)
    # 模型输出
    outputs = net(inputs)
    # 计算loss,图的终点处
    loss = criterion(outputs, targets)
    # 反向传播,计算梯度
    loss.backward()
    # 更新参数
    optimizer.step()
    # 注意如果你想统计loss,切勿直接使用loss相加,而是使用loss.data[0]。因为loss是计算图的一部分,如果你直接加loss,代表total loss同样属于模型一部分,那么图就越来越大
    train_loss += loss.data[0]
    # 数据统计
    _, predicted = torch.max(outputs.data, 1)
    total += targets.size(0)
    correct += predicted.eq(targets.data).cpu().sum()

    progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

# 测试阶段
def test(epoch):
  global best_acc
  # 先切到测试模型
  net.eval()
  test_loss = 0
  correct = 0
  total = 0
  for batch_idx, (inputs, targets) in enumerate(testloader):
    if use_cuda:
      inputs, targets = inputs.cuda(), targets.cuda()
    inputs, targets = Variable(inputs, volatile=True), Variable(targets)
    outputs = net(inputs)
    loss = criterion(outputs, targets)
    # loss is variable , if add it(+=loss) directly, there will be a bigger ang bigger graph.
    test_loss += loss.data[0]
    _, predicted = torch.max(outputs.data, 1)
    total += targets.size(0)
    correct += predicted.eq(targets.data).cpu().sum()

    progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

  # Save checkpoint.
  # 保存模型
  acc = 100.*correct/total
  if acc > best_acc:
    print('Saving..')
    state = {
      'net': net.module if use_cuda else net,
      'acc': acc,
      'epoch': epoch,
    }
    if not os.path.isdir('checkpoint'):
      os.mkdir('checkpoint')
    torch.save(state, './checkpoint/ckpt.t7')
    best_acc = acc

# 运行模型
for epoch in range(start_epoch, start_epoch+200):
  train(epoch)
  test(epoch)
  # 清除部分无用变量 
  torch.cuda.empty_cache()

运行:

新模型:
python main.py --lr=0.01
旧模型继续训练:
python main.py --resume --lr=0.01

一些utility:

'''Some helper functions for PyTorch, including:
  - get_mean_and_std: calculate the mean and std value of dataset.
  - msr_init: net parameter initialization.
  - progress_bar: progress bar mimic xlua.progress.
'''
import os
import sys
import time
import math

import torch.nn as nn
import torch.nn.init as init


def get_mean_and_std(dataset):
  '''Compute the mean and std value of dataset.'''
  dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
  mean = torch.zeros(3)
  std = torch.zeros(3)
  print('==> Computing mean and std..')
  for inputs, targets in dataloader:
    for i in range(3):
      mean[i] += inputs[:,i,:,:].mean()
      std[i] += inputs[:,i,:,:].std()
  mean.div_(len(dataset))
  std.div_(len(dataset))
  return mean, std

def init_params(net):
  '''Init layer parameters.'''
  for m in net.modules():
    if isinstance(m, nn.Conv2d):
      init.kaiming_normal(m.weight, mode='fan_out')
      if m.bias:
        init.constant(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
      init.constant(m.weight, 1)
      init.constant(m.bias, 0)
    elif isinstance(m, nn.Linear):
      init.normal(m.weight, std=1e-3)
      if m.bias:
        init.constant(m.bias, 0)


_, term_width = os.popen('stty size', 'r').read().split()
term_width = int(term_width)

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
  global last_time, begin_time
  if current == 0:
    begin_time = time.time() # Reset for new bar.

  cur_len = int(TOTAL_BAR_LENGTH*current/total)
  rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

  sys.stdout.write(' [')
  for i in range(cur_len):
    sys.stdout.write('=')
  sys.stdout.write('>')
  for i in range(rest_len):
    sys.stdout.write('.')
  sys.stdout.write(']')

  cur_time = time.time()
  step_time = cur_time - last_time
  last_time = cur_time
  tot_time = cur_time - begin_time

  L = []
  L.append(' Step: %s' % format_time(step_time))
  L.append(' | Tot: %s' % format_time(tot_time))
  if msg:
    L.append(' | ' + msg)

  msg = ''.join(L)
  sys.stdout.write(msg)
  for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
    sys.stdout.write(' ')

  # Go back to the center of the bar.
  for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
    sys.stdout.write('\b')
  sys.stdout.write(' %d/%d ' % (current+1, total))

  if current < total-1:
    sys.stdout.write('\r')
  else:
    sys.stdout.write('\n')
  sys.stdout.flush()

def format_time(seconds):
  days = int(seconds / 3600/24)
  seconds = seconds - days*3600*24
  hours = int(seconds / 3600)
  seconds = seconds - hours*3600
  minutes = int(seconds / 60)
  seconds = seconds - minutes*60
  secondsf = int(seconds)
  seconds = seconds - secondsf
  millis = int(seconds*1000)

  f = ''
  i = 1
  if days > 0:
    f += str(days) + 'D'
    i += 1
  if hours > 0 and i <= 2:
    f += str(hours) + 'h'
    i += 1
  if minutes > 0 and i <= 2:
    f += str(minutes) + 'm'
    i += 1
  if secondsf > 0 and i <= 2:
    f += str(secondsf) + 's'
    i += 1
  if millis > 0 and i <= 2:
    f += str(millis) + 'ms'
    i += 1
  if f == '':
    f = '0ms'
  return f

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现方便使用的级联进度信息实例
May 05 Python
python获取多线程及子线程的返回值
Nov 15 Python
Python程序运行原理图文解析
Feb 10 Python
PyQt5每天必学之工具提示功能
Apr 19 Python
Ubuntu下使用python读取doc和docx文档的内容方法
May 08 Python
Python2和Python3之间的str处理方式导致乱码的讲解
Jan 03 Python
python程序控制NAO机器人行走
Apr 29 Python
浅谈Python编程中3个常用的数据结构和算法
Apr 30 Python
Numpy中对向量、矩阵的使用详解
Oct 29 Python
利用python实现AR教程
Nov 20 Python
Python输出指定字符串的方法
Feb 06 Python
如何通过python计算圆周率PI
Nov 11 Python
selenium+PhantomJS爬取豆瓣读书
Aug 26 #Python
python多任务之协程的使用详解
Aug 26 #Python
python数组循环处理方法
Aug 26 #Python
python中利用numpy.array()实现俩个数值列表的对应相加方法
Aug 26 #Python
关于Python中的向量相加和numpy中的向量相加效率对比
Aug 26 #Python
python中sort和sorted排序的实例方法
Aug 26 #Python
对Python 中矩阵或者数组相减的法则详解
Aug 26 #Python
You might like
PHP及Zend Engine的线程安全模型分析
2011/11/10 PHP
php求正负数数组中连续元素最大值示例
2014/04/11 PHP
浅谈PHP面向对象之访问者模式+组合模式
2017/05/22 PHP
PHP以json或xml格式返回请求数据的方法
2018/05/31 PHP
JavaScript 学习 - 提高篇
2007/02/02 Javascript
juqery 学习之四 筛选过滤
2010/11/30 Javascript
JavaScript改变HTML元素的样式改变CSS及元素属性
2013/11/12 Javascript
JavaScript的this关键字的理解
2016/06/18 Javascript
jQuery实现简单的滑动导航代码(移动端)
2017/05/22 jQuery
vue实现app页面切换动画效果实例
2017/05/23 Javascript
十大 Node.js 的 Web 框架(快速提升工作效率)
2017/06/30 Javascript
解决Vue页面固定滚动位置的处理办法
2017/07/13 Javascript
浅谈如何通过node.js对数据进行MD5加密
2018/05/16 Javascript
详解在Node.js中发起HTTP请求的5种方法
2019/01/10 Javascript
浅谈layui 数据表格前后台传值的问题
2019/09/12 Javascript
如何在Vue.JS中使用图标组件
2020/08/04 Javascript
[01:32]DOTA2次级联赛——首支职业女子战队选拔赛全记录
2014/10/23 DOTA
[01:06:59]完美世界DOTA2联赛PWL S2 Magma vs FTD 第一场 11.29
2020/12/02 DOTA
Python struct模块解析
2014/06/12 Python
在windows下Python打印彩色字体的方法
2018/05/15 Python
python 在某.py文件中调用其他.py内的函数的方法
2019/06/25 Python
python基于socket模拟实现ssh远程执行命令
2020/12/05 Python
澳大利亚领先的在线药房:Pharmacy Online(有中文站)
2020/02/22 全球购物
介绍一下常见的木马种类
2014/11/15 面试题
EJB与JAVA BEAN的区别
2016/08/29 面试题
求职自荐书范文
2013/12/04 职场文书
上课玩手机检讨书
2014/02/08 职场文书
法人授权委托书
2014/04/03 职场文书
2014年团工作总结
2014/11/27 职场文书
介绍信范文
2015/01/31 职场文书
2015年学雷锋活动总结
2015/02/06 职场文书
2015年社会实践个人总结
2015/03/06 职场文书
员工手册董事长致辞
2015/07/29 职场文书
高中物理教学反思
2016/02/19 职场文书
缓存替换策略及应用(以Redis、InnoDB为例)
2021/07/25 Redis
vue使用wavesurfer.js解决音频可视化播放问题
2022/04/04 Vue.js