pytorch实现对输入超过三通道的数据进行训练


Posted in Python onJanuary 15, 2020

案例背景:视频识别

假设每次输入是8s的灰度视频,视频帧率为25fps,则视频由200帧图像序列构成.每帧是一副单通道的灰度图像,通过pythonb里面的np.stack(深度拼接)可将200帧拼接成200通道的深度数据.进而送到网络里面去训练.

如果输入图像200通道觉得多,可以对视频进行抽帧,针对具体场景可以随机抽帧或等间隔抽帧.比如这里等间隔抽取40帧.则最后输入视频相当于输入一个40通道的图像数据了.

pytorch对超过三通道数据的加载:

读取视频每一帧,转为array格式,然后依次将每一帧进行深度拼接,最后得到一个40通道的array格式的深度数据,保存到pickle里.

对每个视频都进行上述操作,保存到pickle里.

我这里将火的视频深度数据保存在一个.pkl文件中,一共2504个火的视频,即2504个火的深度数据.

将非火的视频深度数据保存在一个.pkl文件中,一共3985个非火的视频,即3985个非火的深度数据.

数据加载

import torch 
from torch.utils import data
import os
from PIL import Image
import numpy as np
import pickle
 
class Fire_Unfire(data.Dataset):
  def __init__(self,fire_path,unfire_path):
    self.pickle_fire = open(fire_path,'rb')
    self.pickle_unfire = open(unfire_path,'rb')
    
  def __getitem__(self,index):
    if index <2504:
      fire = pickle.load(self.pickle_fire)#高*宽*通道
      fire = fire.transpose(2,0,1)#通道*高*宽
      data = torch.from_numpy(fire)
      label = 1
      return data,label
    elif index>=2504 and index<6489:
      unfire = pickle.load(self.pickle_unfire)
      unfire = unfire.transpose(2,0,1)
      data = torch.from_numpy(unfire)
      label = 0
      return data,label
    
  def __len__(self):
    return 6489
root_path = './datasets/train'
dataset = Fire_Unfire(root_path +'/fire_train.pkl',root_path +'/unfire_train.pkl')
 
#转换成pytorch网络输入的格式(批量大小,通道数,高,宽)
from torch.utils.data import DataLoader
fire_dataloader = DataLoader(dataset,batch_size=4,shuffle=True,drop_last = True)

模型训练

import torch
from torch.utils import data
from nets.mobilenet import mobilenet
from config.config import default_config
from torch.autograd import Variable as V
import numpy as np
import sys
import time
 
opt = default_config()
def train():
  #模型定义
  model = mobilenet().cuda()
  if opt.pretrain_model:
    model.load_state_dict(torch.load(opt.pretrain_model))
  
  #损失函数
  criterion = torch.nn.CrossEntropyLoss().cuda()
  
  #学习率
  lr = opt.lr
  
  #优化器
  optimizer = torch.optim.SGD(model.parameters(),lr = lr,weight_decay=opt.weight_decay)
  
  
  pre_loss = 0.0
  #训练
  for epoch in range(opt.max_epoch):
     #训练数据
    train_data = Fire_Unfire(opt.root_path +'/fire_train.pkl',opt.root_path +'/unfire_train.pkl')
    train_dataloader = data.DataLoader(train_data,batch_size=opt.batch_size,shuffle=True,drop_last = True)
    loss_sum = 0.0
    for i,(datas,labels) in enumerate(train_dataloader):
      #print(i,datas.size(),labels)
      #梯度清零
      optimizer.zero_grad()
      #输入
      input = V(datas.cuda()).float()
      #目标
      target = V(labels.cuda()).long()
      #输出
      score = model(input).cuda()
      #损失
      loss = criterion(score,target)
      loss_sum += loss
      #反向传播
      loss.backward()
      #梯度更新
      optimizer.step()      
    print('{}{}{}{}{}'.format('epoch:',epoch,',','loss:',loss))
    torch.save(model.state_dict(),'models/mobilenet_%d.pth'%(epoch+370))

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target'

解决方案:target = target.long()

以上这篇pytorch实现对输入超过三通道的数据进行训练就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python下载图片实现方法(超简单)
Jul 21 Python
利用python批量修改word文件名的方法示例
Oct 17 Python
itchat接口使用示例
Oct 23 Python
Scrapy框架CrawlSpiders的介绍以及使用详解
Nov 29 Python
Python 从一个文件中调用另一个文件的类方法
Jan 10 Python
python装饰器简介---这一篇也许就够了(推荐)
Apr 01 Python
python3.7 使用pymssql往sqlserver插入数据的方法
Jul 08 Python
python小程序实现刷票功能详解
Jul 17 Python
Pytest mark使用实例及原理解析
Feb 22 Python
pyinstaller打包找不到文件的问题解决
Apr 15 Python
完美处理python与anaconda环境变量的冲突问题
Apr 07 Python
详解Python描述符的工作原理
Jun 11 Python
Pytorch 定义MyDatasets实现多通道分别输入不同数据方式
Jan 15 #Python
pytorch构建多模型实例
Jan 15 #Python
利用Pytorch实现简单的线性回归算法
Jan 15 #Python
pytorch实现线性拟合方式
Jan 15 #Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
You might like
模拟flock实现文件锁定
2007/02/14 PHP
php中判断字符串是否全是中文或含有中文的实现代码
2011/09/16 PHP
PHP两种快速排序算法实例
2015/02/15 PHP
PHP单例模式模拟Java Bean实现方法示例
2018/12/07 PHP
tp5.1 框架数据库-数据集操作实例分析
2020/05/26 PHP
DOM下的节点属性和操作小结
2009/05/14 Javascript
JavaScript初学者需要了解10个小技巧
2010/08/25 Javascript
javascript 函数声明与函数表达式的区别介绍
2013/10/05 Javascript
不到30行JS代码实现Excel表格的方法
2014/11/15 Javascript
JavaScript实现强制重定向至HTTPS页面
2015/06/10 Javascript
微信小程序 wx.request(接口调用方式)详解及实例
2016/11/23 Javascript
基于jPlayer三分屏的制作方法
2016/12/21 Javascript
老生常谈jquery id选择器和class选择器的区别
2017/02/12 Javascript
layui弹出层效果实现代码
2017/05/19 Javascript
vue Render中slots的使用的实例代码
2017/07/19 Javascript
webpack将js打包后的map文件详解
2018/02/22 Javascript
详解webpack模块加载器兼打包工具
2018/09/11 Javascript
Javascript如何递归遍历本地文件夹
2020/08/06 Javascript
Python编程中的反模式实例分析
2014/12/08 Python
由Python运算π的值深入Python中科学计算的实现
2015/04/17 Python
python获取外网ip地址的方法总结
2015/07/02 Python
python 循环遍历字典元素的简单方法
2016/09/11 Python
Python 中Pickle库的使用详解
2018/02/24 Python
浅谈flask源码之请求过程
2018/07/26 Python
六行python代码的爱心曲线详解
2019/05/17 Python
Django中使用 Closure Table 储存无限分级数据
2019/06/06 Python
Python中py文件转换成exe可执行文件的方法
2019/06/14 Python
基于python实现自动化办公学习笔记(CSV、word、Excel、PPT)
2019/08/06 Python
python3.7 openpyxl 删除指定一列或者一行的代码
2019/10/08 Python
python标准库os库的函数介绍
2020/02/12 Python
使用TFRecord存取多个数据案例
2020/02/17 Python
python序列类型种类详解
2020/02/26 Python
纯css3实现照片墙效果
2014/12/26 HTML / CSS
Hanky Panky官方网站:内衣和睡衣
2019/07/25 全球购物
应届中专生自荐书范文
2014/02/13 职场文书
python数字图像处理之对比度与亮度调整示例
2022/06/28 Python