pytorch实现对输入超过三通道的数据进行训练


Posted in Python onJanuary 15, 2020

案例背景:视频识别

假设每次输入是8s的灰度视频,视频帧率为25fps,则视频由200帧图像序列构成.每帧是一副单通道的灰度图像,通过pythonb里面的np.stack(深度拼接)可将200帧拼接成200通道的深度数据.进而送到网络里面去训练.

如果输入图像200通道觉得多,可以对视频进行抽帧,针对具体场景可以随机抽帧或等间隔抽帧.比如这里等间隔抽取40帧.则最后输入视频相当于输入一个40通道的图像数据了.

pytorch对超过三通道数据的加载:

读取视频每一帧,转为array格式,然后依次将每一帧进行深度拼接,最后得到一个40通道的array格式的深度数据,保存到pickle里.

对每个视频都进行上述操作,保存到pickle里.

我这里将火的视频深度数据保存在一个.pkl文件中,一共2504个火的视频,即2504个火的深度数据.

将非火的视频深度数据保存在一个.pkl文件中,一共3985个非火的视频,即3985个非火的深度数据.

数据加载

import torch 
from torch.utils import data
import os
from PIL import Image
import numpy as np
import pickle
 
class Fire_Unfire(data.Dataset):
  def __init__(self,fire_path,unfire_path):
    self.pickle_fire = open(fire_path,'rb')
    self.pickle_unfire = open(unfire_path,'rb')
    
  def __getitem__(self,index):
    if index <2504:
      fire = pickle.load(self.pickle_fire)#高*宽*通道
      fire = fire.transpose(2,0,1)#通道*高*宽
      data = torch.from_numpy(fire)
      label = 1
      return data,label
    elif index>=2504 and index<6489:
      unfire = pickle.load(self.pickle_unfire)
      unfire = unfire.transpose(2,0,1)
      data = torch.from_numpy(unfire)
      label = 0
      return data,label
    
  def __len__(self):
    return 6489
root_path = './datasets/train'
dataset = Fire_Unfire(root_path +'/fire_train.pkl',root_path +'/unfire_train.pkl')
 
#转换成pytorch网络输入的格式(批量大小,通道数,高,宽)
from torch.utils.data import DataLoader
fire_dataloader = DataLoader(dataset,batch_size=4,shuffle=True,drop_last = True)

模型训练

import torch
from torch.utils import data
from nets.mobilenet import mobilenet
from config.config import default_config
from torch.autograd import Variable as V
import numpy as np
import sys
import time
 
opt = default_config()
def train():
  #模型定义
  model = mobilenet().cuda()
  if opt.pretrain_model:
    model.load_state_dict(torch.load(opt.pretrain_model))
  
  #损失函数
  criterion = torch.nn.CrossEntropyLoss().cuda()
  
  #学习率
  lr = opt.lr
  
  #优化器
  optimizer = torch.optim.SGD(model.parameters(),lr = lr,weight_decay=opt.weight_decay)
  
  
  pre_loss = 0.0
  #训练
  for epoch in range(opt.max_epoch):
     #训练数据
    train_data = Fire_Unfire(opt.root_path +'/fire_train.pkl',opt.root_path +'/unfire_train.pkl')
    train_dataloader = data.DataLoader(train_data,batch_size=opt.batch_size,shuffle=True,drop_last = True)
    loss_sum = 0.0
    for i,(datas,labels) in enumerate(train_dataloader):
      #print(i,datas.size(),labels)
      #梯度清零
      optimizer.zero_grad()
      #输入
      input = V(datas.cuda()).float()
      #目标
      target = V(labels.cuda()).long()
      #输出
      score = model(input).cuda()
      #损失
      loss = criterion(score,target)
      loss_sum += loss
      #反向传播
      loss.backward()
      #梯度更新
      optimizer.step()      
    print('{}{}{}{}{}'.format('epoch:',epoch,',','loss:',loss))
    torch.save(model.state_dict(),'models/mobilenet_%d.pth'%(epoch+370))

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target'

解决方案:target = target.long()

以上这篇pytorch实现对输入超过三通道的数据进行训练就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
tornado框架blog模块分析与使用
Nov 21 Python
在Linux系统上安装Python的Scrapy框架的教程
Jun 11 Python
解决Python2.7读写文件中的中文乱码问题
Apr 12 Python
Python 字符串换行的多种方式
Sep 06 Python
python实现飞机大战
Sep 11 Python
Python 中的lambda函数介绍
Oct 10 Python
python pyheatmap包绘制热力图
Nov 09 Python
值得收藏的10道python 面试题
Apr 15 Python
使用pandas实现连续数据的离散化处理方式(分箱操作)
Nov 22 Python
python实现输入三角形边长自动作图求面积案例
Apr 12 Python
使用TensorBoard进行超参数优化的实现
Jul 06 Python
Python脚本实现Zabbix多行日志监控过程解析
Aug 26 Python
Pytorch 定义MyDatasets实现多通道分别输入不同数据方式
Jan 15 #Python
pytorch构建多模型实例
Jan 15 #Python
利用Pytorch实现简单的线性回归算法
Jan 15 #Python
pytorch实现线性拟合方式
Jan 15 #Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
You might like
用PHP伪造referer突破网盘禁止外连的代码
2008/06/15 PHP
在PHP中检查PHP文件是否有语法错误的方法
2009/12/23 PHP
php获取表单中多个同名input元素的值
2014/03/20 PHP
ThinkPHP使用心得分享-ThinkPHP + Ajax 实现2级联动下拉菜单
2014/05/15 PHP
ECshop 迁移到 PHP7版本时遇到的兼容性问题
2016/02/15 PHP
PHPExcel导出2003和2007的excel文档功能示例
2017/01/04 PHP
摘自启点的main.js
2008/04/20 Javascript
19个很有用的 JavaScript库推荐
2011/06/27 Javascript
javascript之bind使用介绍
2011/10/09 Javascript
JS+CSS制作DIV层可(最小化/拖拽/排序)功能实现代码
2013/02/25 Javascript
JavaScript实现计算字符串中出现次数最多的字符和出现的次数
2015/03/12 Javascript
jQuery使用after()方法在元素后面添加多项内容的方法
2015/03/26 Javascript
jQuery simpleModal插件的使用介绍
2016/08/30 Javascript
在vue中使用image-webpack-loader实例
2020/11/12 Javascript
[02:26]2016国际邀请赛8月3日开战 中国军团出征西雅图
2016/08/02 DOTA
python实现给数组按片赋值的方法
2015/07/28 Python
Python的Flask框架标配模板引擎Jinja2的使用教程
2016/07/12 Python
NetworkX之Prim算法(实例讲解)
2017/12/22 Python
Python实现二叉搜索树BST的方法示例
2019/07/30 Python
浅谈Pycharm最有必要改的几个默认设置项
2020/02/14 Python
Python线程threading模块用法详解
2020/02/26 Python
python GUI库图形界面开发之PyQt5信号与槽的高级使用技巧装饰器信号与槽详细使用方法与实例
2020/03/06 Python
Python定义一个Actor任务
2020/07/29 Python
CHARLES & KEITH台湾官网:新加坡时尚品牌
2019/07/30 全球购物
Steiff台湾官网:德国金耳釦泰迪熊
2019/12/26 全球购物
Oracle中delete,truncate和drop的区别
2016/05/05 面试题
自我介绍演讲稿
2014/01/15 职场文书
寒假家长评语大全
2014/04/16 职场文书
预防艾滋病宣传标语
2014/06/25 职场文书
计划生育工作汇报
2014/10/28 职场文书
租车协议书
2015/01/27 职场文书
2015年街道办事处工作总结
2015/05/22 职场文书
公司员工宿舍管理制度
2015/08/03 职场文书
党员反四风学习心得体会
2016/01/22 职场文书
react使用antd的上传组件实现文件表单一起提交功能(完整代码)
2021/06/29 Javascript
win11怎么用快捷键锁屏? windows11锁屏的几种方法
2021/11/21 数码科技