pytorch实现对输入超过三通道的数据进行训练


Posted in Python onJanuary 15, 2020

案例背景:视频识别

假设每次输入是8s的灰度视频,视频帧率为25fps,则视频由200帧图像序列构成.每帧是一副单通道的灰度图像,通过pythonb里面的np.stack(深度拼接)可将200帧拼接成200通道的深度数据.进而送到网络里面去训练.

如果输入图像200通道觉得多,可以对视频进行抽帧,针对具体场景可以随机抽帧或等间隔抽帧.比如这里等间隔抽取40帧.则最后输入视频相当于输入一个40通道的图像数据了.

pytorch对超过三通道数据的加载:

读取视频每一帧,转为array格式,然后依次将每一帧进行深度拼接,最后得到一个40通道的array格式的深度数据,保存到pickle里.

对每个视频都进行上述操作,保存到pickle里.

我这里将火的视频深度数据保存在一个.pkl文件中,一共2504个火的视频,即2504个火的深度数据.

将非火的视频深度数据保存在一个.pkl文件中,一共3985个非火的视频,即3985个非火的深度数据.

数据加载

import torch 
from torch.utils import data
import os
from PIL import Image
import numpy as np
import pickle
 
class Fire_Unfire(data.Dataset):
  def __init__(self,fire_path,unfire_path):
    self.pickle_fire = open(fire_path,'rb')
    self.pickle_unfire = open(unfire_path,'rb')
    
  def __getitem__(self,index):
    if index <2504:
      fire = pickle.load(self.pickle_fire)#高*宽*通道
      fire = fire.transpose(2,0,1)#通道*高*宽
      data = torch.from_numpy(fire)
      label = 1
      return data,label
    elif index>=2504 and index<6489:
      unfire = pickle.load(self.pickle_unfire)
      unfire = unfire.transpose(2,0,1)
      data = torch.from_numpy(unfire)
      label = 0
      return data,label
    
  def __len__(self):
    return 6489
root_path = './datasets/train'
dataset = Fire_Unfire(root_path +'/fire_train.pkl',root_path +'/unfire_train.pkl')
 
#转换成pytorch网络输入的格式(批量大小,通道数,高,宽)
from torch.utils.data import DataLoader
fire_dataloader = DataLoader(dataset,batch_size=4,shuffle=True,drop_last = True)

模型训练

import torch
from torch.utils import data
from nets.mobilenet import mobilenet
from config.config import default_config
from torch.autograd import Variable as V
import numpy as np
import sys
import time
 
opt = default_config()
def train():
  #模型定义
  model = mobilenet().cuda()
  if opt.pretrain_model:
    model.load_state_dict(torch.load(opt.pretrain_model))
  
  #损失函数
  criterion = torch.nn.CrossEntropyLoss().cuda()
  
  #学习率
  lr = opt.lr
  
  #优化器
  optimizer = torch.optim.SGD(model.parameters(),lr = lr,weight_decay=opt.weight_decay)
  
  
  pre_loss = 0.0
  #训练
  for epoch in range(opt.max_epoch):
     #训练数据
    train_data = Fire_Unfire(opt.root_path +'/fire_train.pkl',opt.root_path +'/unfire_train.pkl')
    train_dataloader = data.DataLoader(train_data,batch_size=opt.batch_size,shuffle=True,drop_last = True)
    loss_sum = 0.0
    for i,(datas,labels) in enumerate(train_dataloader):
      #print(i,datas.size(),labels)
      #梯度清零
      optimizer.zero_grad()
      #输入
      input = V(datas.cuda()).float()
      #目标
      target = V(labels.cuda()).long()
      #输出
      score = model(input).cuda()
      #损失
      loss = criterion(score,target)
      loss_sum += loss
      #反向传播
      loss.backward()
      #梯度更新
      optimizer.step()      
    print('{}{}{}{}{}'.format('epoch:',epoch,',','loss:',loss))
    torch.save(model.state_dict(),'models/mobilenet_%d.pth'%(epoch+370))

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target'

解决方案:target = target.long()

以上这篇pytorch实现对输入超过三通道的数据进行训练就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python常见文件操作的函数示例代码
Nov 15 Python
Python中的两个内置模块介绍
Apr 05 Python
python django 实现验证码的功能实例代码
May 18 Python
Python开发的HTTP库requests详解
Aug 29 Python
Python+pandas计算数据相关系数的实例
Jul 03 Python
python对矩阵进行转置的2种处理方法
Jul 17 Python
python爬虫之爬取百度音乐的实现方法
Aug 24 Python
关于Python形参打包与解包小技巧分享
Aug 24 Python
python实现实时视频流播放代码实例
Jan 11 Python
使用python turtle画高达
Jan 19 Python
python中count函数简单的实例讲解
Feb 06 Python
Python 的 sum() Pythonic 的求和方法详细
Oct 16 Python
Pytorch 定义MyDatasets实现多通道分别输入不同数据方式
Jan 15 #Python
pytorch构建多模型实例
Jan 15 #Python
利用Pytorch实现简单的线性回归算法
Jan 15 #Python
pytorch实现线性拟合方式
Jan 15 #Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
You might like
学习使用curl采集curl使用方法
2012/01/11 PHP
php实现的支持断点续传的文件下载类
2014/09/23 PHP
discuz图片顺序混乱解决方案
2015/07/29 PHP
php生成随机数/生成随机字符串的方法小结【5种方法】
2020/05/27 PHP
让IE6支持min-width和max-width的方法
2010/06/25 Javascript
js数组的基本用法及数组根据下标(数值或字符)移除元素
2013/10/20 Javascript
深入理解javascript中return的作用
2013/12/30 Javascript
使用jQuery重置(reset)表单的方法
2014/05/05 Javascript
使表格的标题列可左右拉伸jquery插件封装
2014/11/24 Javascript
javascript实现checkBox的全选,反选与赋值
2015/03/12 Javascript
JQuery中Bind()事件用法分析
2015/05/05 Javascript
详解如何使用Vue2做服务端渲染
2017/03/29 Javascript
自定义类似于jQuery UI Selectable 的Vue指令v-selectable
2017/08/23 jQuery
JS设计模式之状态模式概念与用法分析
2018/02/05 Javascript
JS实现的简单分页功能示例
2018/08/23 Javascript
10分钟彻底搞懂Http的强制缓存和协商缓存(小结)
2018/08/30 Javascript
Spring boot 和Vue开发中CORS跨域问题解决
2018/09/05 Javascript
vue debug 二种方法
2018/09/16 Javascript
微信小程序实现的canvas合成图片功能示例
2019/05/03 Javascript
详解Vue调用手机相机和相册以及上传
2019/05/05 Javascript
微信小程序下拉加载和上拉刷新两种实现方法详解
2019/09/05 Javascript
Python3中类、模块、错误与异常、文件的简易教程
2017/11/20 Python
python根据url地址下载小文件的实例
2018/12/18 Python
python使用PIL模块获取图片像素点的方法
2019/01/08 Python
python按照多个条件排序的方法
2019/02/08 Python
TensorFlow2.1.0安装过程中setuptools、wrapt等相关错误指南
2020/04/08 Python
用python给csv里的数据排序的具体代码
2020/07/17 Python
pycharm 实现调试窗口恢复
2021/02/05 Python
HTML5 Web 存储详解
2016/09/16 HTML / CSS
使用HTML5 Canvas API中的clip()方法裁剪区域图像
2016/03/25 HTML / CSS
美国领先的家居装饰和礼品商店:Kirkland’s
2017/01/30 全球购物
霸气押韵的班级口号
2014/06/09 职场文书
授权委托书
2014/07/31 职场文书
浅谈JS的原型和原型链
2021/06/04 Javascript
WCG2010 星际争霸决赛 Flash vs Goojila 1 星际经典比赛回顾
2022/04/01 星际争霸
win10更新失败无限重启解决方法
2022/04/19 数码科技