简单易懂Pytorch实战实例VGG深度网络


Posted in Python onAugust 27, 2019

模型VGG,数据集cifar。对照这份代码走一遍,大概就知道整个pytorch的运行机制。

来源

定义模型:

'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn
from torch.autograd import Variable


cfg = {
  'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
  'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

# 模型需继承nn.Module
class VGG(nn.Module):
# 初始化参数:
  def __init__(self, vgg_name):
    super(VGG, self).__init__()
    self.features = self._make_layers(cfg[vgg_name])
    self.classifier = nn.Linear(512, 10)

# 模型计算时的前向过程,也就是按照这个过程进行计算
  def forward(self, x):
    out = self.features(x)
    out = out.view(out.size(0), -1)
    out = self.classifier(out)
    return out

  def _make_layers(self, cfg):
    layers = []
    in_channels = 3
    for x in cfg:
      if x == 'M':
        layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
      else:
        layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
              nn.BatchNorm2d(x),
              nn.ReLU(inplace=True)]
        in_channels = x
    layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
    return nn.Sequential(*layers)

# net = VGG('VGG11')
# x = torch.randn(2,3,32,32)
# print(net(Variable(x)).size())

定义训练过程:

'''Train CIFAR10 with PyTorch.'''
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models import *
from utils import progress_bar
from torch.autograd import Variable

# 获取参数
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()

use_cuda = torch.cuda.is_available()
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch

# 获取数据集,并先进行预处理
print('==> Preparing data..')
# 图像预处理和增强
transform_train = transforms.Compose([
  transforms.RandomCrop(32, padding=4),
  transforms.RandomHorizontalFlip(),
  transforms.ToTensor(),
  transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 继续训练模型或新建一个模型
if args.resume:
  # Load checkpoint.
  print('==> Resuming from checkpoint..')
  assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
  checkpoint = torch.load('./checkpoint/ckpt.t7')
  net = checkpoint['net']
  best_acc = checkpoint['acc']
  start_epoch = checkpoint['epoch']
else:
  print('==> Building model..')
  net = VGG('VGG16')
  # net = ResNet18()
  # net = PreActResNet18()
  # net = GoogLeNet()
  # net = DenseNet121()
  # net = ResNeXt29_2x64d()
  # net = MobileNet()
  # net = MobileNetV2()
  # net = DPN92()
  # net = ShuffleNetG2()
  # net = SENet18()

# 如果GPU可用,使用GPU
if use_cuda:
  # move param and buffer to GPU
  net.cuda()
  # parallel use GPU
  net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()-1))
  # speed up slightly
  cudnn.benchmark = True


# 定义度量和优化
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

# 训练阶段
def train(epoch):
  print('\nEpoch: %d' % epoch)
  # switch to train mode
  net.train()
  train_loss = 0
  correct = 0
  total = 0
  # batch 数据
  for batch_idx, (inputs, targets) in enumerate(trainloader):
    # 将数据移到GPU上
    if use_cuda:
      inputs, targets = inputs.cuda(), targets.cuda()
    # 先将optimizer梯度先置为0
    optimizer.zero_grad()
    # Variable表示该变量属于计算图的一部分,此处是图计算的开始处。图的leaf variable
    inputs, targets = Variable(inputs), Variable(targets)
    # 模型输出
    outputs = net(inputs)
    # 计算loss,图的终点处
    loss = criterion(outputs, targets)
    # 反向传播,计算梯度
    loss.backward()
    # 更新参数
    optimizer.step()
    # 注意如果你想统计loss,切勿直接使用loss相加,而是使用loss.data[0]。因为loss是计算图的一部分,如果你直接加loss,代表total loss同样属于模型一部分,那么图就越来越大
    train_loss += loss.data[0]
    # 数据统计
    _, predicted = torch.max(outputs.data, 1)
    total += targets.size(0)
    correct += predicted.eq(targets.data).cpu().sum()

    progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

# 测试阶段
def test(epoch):
  global best_acc
  # 先切到测试模型
  net.eval()
  test_loss = 0
  correct = 0
  total = 0
  for batch_idx, (inputs, targets) in enumerate(testloader):
    if use_cuda:
      inputs, targets = inputs.cuda(), targets.cuda()
    inputs, targets = Variable(inputs, volatile=True), Variable(targets)
    outputs = net(inputs)
    loss = criterion(outputs, targets)
    # loss is variable , if add it(+=loss) directly, there will be a bigger ang bigger graph.
    test_loss += loss.data[0]
    _, predicted = torch.max(outputs.data, 1)
    total += targets.size(0)
    correct += predicted.eq(targets.data).cpu().sum()

    progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

  # Save checkpoint.
  # 保存模型
  acc = 100.*correct/total
  if acc > best_acc:
    print('Saving..')
    state = {
      'net': net.module if use_cuda else net,
      'acc': acc,
      'epoch': epoch,
    }
    if not os.path.isdir('checkpoint'):
      os.mkdir('checkpoint')
    torch.save(state, './checkpoint/ckpt.t7')
    best_acc = acc

# 运行模型
for epoch in range(start_epoch, start_epoch+200):
  train(epoch)
  test(epoch)
  # 清除部分无用变量 
  torch.cuda.empty_cache()

运行:

新模型:
python main.py --lr=0.01
旧模型继续训练:
python main.py --resume --lr=0.01

一些utility:

'''Some helper functions for PyTorch, including:
  - get_mean_and_std: calculate the mean and std value of dataset.
  - msr_init: net parameter initialization.
  - progress_bar: progress bar mimic xlua.progress.
'''
import os
import sys
import time
import math

import torch.nn as nn
import torch.nn.init as init


def get_mean_and_std(dataset):
  '''Compute the mean and std value of dataset.'''
  dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
  mean = torch.zeros(3)
  std = torch.zeros(3)
  print('==> Computing mean and std..')
  for inputs, targets in dataloader:
    for i in range(3):
      mean[i] += inputs[:,i,:,:].mean()
      std[i] += inputs[:,i,:,:].std()
  mean.div_(len(dataset))
  std.div_(len(dataset))
  return mean, std

def init_params(net):
  '''Init layer parameters.'''
  for m in net.modules():
    if isinstance(m, nn.Conv2d):
      init.kaiming_normal(m.weight, mode='fan_out')
      if m.bias:
        init.constant(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
      init.constant(m.weight, 1)
      init.constant(m.bias, 0)
    elif isinstance(m, nn.Linear):
      init.normal(m.weight, std=1e-3)
      if m.bias:
        init.constant(m.bias, 0)


_, term_width = os.popen('stty size', 'r').read().split()
term_width = int(term_width)

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
  global last_time, begin_time
  if current == 0:
    begin_time = time.time() # Reset for new bar.

  cur_len = int(TOTAL_BAR_LENGTH*current/total)
  rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

  sys.stdout.write(' [')
  for i in range(cur_len):
    sys.stdout.write('=')
  sys.stdout.write('>')
  for i in range(rest_len):
    sys.stdout.write('.')
  sys.stdout.write(']')

  cur_time = time.time()
  step_time = cur_time - last_time
  last_time = cur_time
  tot_time = cur_time - begin_time

  L = []
  L.append(' Step: %s' % format_time(step_time))
  L.append(' | Tot: %s' % format_time(tot_time))
  if msg:
    L.append(' | ' + msg)

  msg = ''.join(L)
  sys.stdout.write(msg)
  for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
    sys.stdout.write(' ')

  # Go back to the center of the bar.
  for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
    sys.stdout.write('\b')
  sys.stdout.write(' %d/%d ' % (current+1, total))

  if current < total-1:
    sys.stdout.write('\r')
  else:
    sys.stdout.write('\n')
  sys.stdout.flush()

def format_time(seconds):
  days = int(seconds / 3600/24)
  seconds = seconds - days*3600*24
  hours = int(seconds / 3600)
  seconds = seconds - hours*3600
  minutes = int(seconds / 60)
  seconds = seconds - minutes*60
  secondsf = int(seconds)
  seconds = seconds - secondsf
  millis = int(seconds*1000)

  f = ''
  i = 1
  if days > 0:
    f += str(days) + 'D'
    i += 1
  if hours > 0 and i <= 2:
    f += str(hours) + 'h'
    i += 1
  if minutes > 0 and i <= 2:
    f += str(minutes) + 'm'
    i += 1
  if secondsf > 0 and i <= 2:
    f += str(secondsf) + 's'
    i += 1
  if millis > 0 and i <= 2:
    f += str(millis) + 'ms'
    i += 1
  if f == '':
    f = '0ms'
  return f

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 数据清洗之数据合并、转换、过滤、排序
Feb 12 Python
3个用于数据科学的顶级Python库
Sep 29 Python
使用PyQtGraph绘制精美的股票行情K线图的示例代码
Mar 14 Python
搞定这套Python爬虫面试题(面试会so easy)
Apr 03 Python
Python选择网卡发包及接收数据包
Apr 04 Python
pyqt实现.ui文件批量转换为对应.py文件脚本
Jun 19 Python
Django 实现图片上传和显示过程详解
Jul 18 Python
python中dict()的高级用法实现
Nov 13 Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 Python
tensorflow指定GPU与动态分配GPU memory设置
Feb 03 Python
使用Python pip怎么升级pip
Aug 11 Python
详解解决jupyter不能使用pytorch的问题
Feb 18 Python
selenium+PhantomJS爬取豆瓣读书
Aug 26 #Python
python多任务之协程的使用详解
Aug 26 #Python
python数组循环处理方法
Aug 26 #Python
python中利用numpy.array()实现俩个数值列表的对应相加方法
Aug 26 #Python
关于Python中的向量相加和numpy中的向量相加效率对比
Aug 26 #Python
python中sort和sorted排序的实例方法
Aug 26 #Python
对Python 中矩阵或者数组相减的法则详解
Aug 26 #Python
You might like
解决php中Cannot send session cache limiter 的问题的方法
2007/04/27 PHP
浅谈JavaScript中面向对象技术的模拟
2006/09/25 Javascript
一直复略了的一个问题,关于表单重复提交
2007/02/15 Javascript
用Div仿showModalDialog模式菜单的效果的代码
2007/03/05 Javascript
善用事件代理,警惕闭包的性能陷阱。
2011/01/20 Javascript
原来Jquery.load的方法可以一直load下去
2011/03/28 Javascript
javascipt匹配单行和多行注释的正则表达式
2013/11/20 Javascript
node.js入门教程迷你书、node.js入门web应用开发完全示例
2014/04/06 Javascript
javascript打开word文档的方法
2014/04/16 Javascript
JS实现5秒钟自动封锁div层的方法
2015/02/20 Javascript
jquery实现最简单的滑动菜单效果代码
2015/09/12 Javascript
基于jQuery 实现bootstrapValidator下的全局验证
2015/12/07 Javascript
浅谈js函数中的实例对象、类对象、局部变量(局部函数)
2016/11/20 Javascript
js实现文本上下来回滚动
2017/02/03 Javascript
vue左侧菜单,树形图递归实现代码
2018/08/24 Javascript
js实现可爱的气泡特效
2020/09/05 Javascript
微信小程序中target和currentTarget的区别小结
2020/11/06 Javascript
使用Vant完成DatetimePicker 日期的选择器操作
2020/11/12 Javascript
[14:50]2018DOTA2亚洲邀请赛开幕式
2018/04/03 DOTA
[01:32:50]DOTA2-DPC中国联赛 正赛 DLG vs XG BO3 第一场 1月25日
2021/03/11 DOTA
详解python使用Nginx和uWSGI来运行Python应用
2018/01/09 Python
python2.7到3.x迁移指南
2018/02/01 Python
Python socket实现的简单通信功能示例
2018/08/21 Python
简单了解Pandas缺失值处理方法
2019/11/16 Python
pyinstaller打包找不到文件的问题解决
2020/04/15 Python
Tensorflow全局设置可见GPU编号操作
2020/06/30 Python
移动通信专业自荐信范文
2013/11/12 职场文书
最新奶茶店创业计划书范文
2014/02/08 职场文书
薪酬专员岗位职责
2014/02/18 职场文书
婚前保证书
2014/04/29 职场文书
企业演讲稿范文大全
2014/05/20 职场文书
投标人廉洁自律承诺书
2014/05/26 职场文书
教师自我剖析材料
2014/09/29 职场文书
二审答辩状范文
2015/05/22 职场文书
装修公司管理制度
2015/08/05 职场文书
Java基于Dijkstra算法实现校园导游程序
2022/03/17 Java/Android