pytorch实现对输入超过三通道的数据进行训练


Posted in Python onJanuary 15, 2020

案例背景:视频识别

假设每次输入是8s的灰度视频,视频帧率为25fps,则视频由200帧图像序列构成.每帧是一副单通道的灰度图像,通过pythonb里面的np.stack(深度拼接)可将200帧拼接成200通道的深度数据.进而送到网络里面去训练.

如果输入图像200通道觉得多,可以对视频进行抽帧,针对具体场景可以随机抽帧或等间隔抽帧.比如这里等间隔抽取40帧.则最后输入视频相当于输入一个40通道的图像数据了.

pytorch对超过三通道数据的加载:

读取视频每一帧,转为array格式,然后依次将每一帧进行深度拼接,最后得到一个40通道的array格式的深度数据,保存到pickle里.

对每个视频都进行上述操作,保存到pickle里.

我这里将火的视频深度数据保存在一个.pkl文件中,一共2504个火的视频,即2504个火的深度数据.

将非火的视频深度数据保存在一个.pkl文件中,一共3985个非火的视频,即3985个非火的深度数据.

数据加载

import torch 
from torch.utils import data
import os
from PIL import Image
import numpy as np
import pickle
 
class Fire_Unfire(data.Dataset):
  def __init__(self,fire_path,unfire_path):
    self.pickle_fire = open(fire_path,'rb')
    self.pickle_unfire = open(unfire_path,'rb')
    
  def __getitem__(self,index):
    if index <2504:
      fire = pickle.load(self.pickle_fire)#高*宽*通道
      fire = fire.transpose(2,0,1)#通道*高*宽
      data = torch.from_numpy(fire)
      label = 1
      return data,label
    elif index>=2504 and index<6489:
      unfire = pickle.load(self.pickle_unfire)
      unfire = unfire.transpose(2,0,1)
      data = torch.from_numpy(unfire)
      label = 0
      return data,label
    
  def __len__(self):
    return 6489
root_path = './datasets/train'
dataset = Fire_Unfire(root_path +'/fire_train.pkl',root_path +'/unfire_train.pkl')
 
#转换成pytorch网络输入的格式(批量大小,通道数,高,宽)
from torch.utils.data import DataLoader
fire_dataloader = DataLoader(dataset,batch_size=4,shuffle=True,drop_last = True)

模型训练

import torch
from torch.utils import data
from nets.mobilenet import mobilenet
from config.config import default_config
from torch.autograd import Variable as V
import numpy as np
import sys
import time
 
opt = default_config()
def train():
  #模型定义
  model = mobilenet().cuda()
  if opt.pretrain_model:
    model.load_state_dict(torch.load(opt.pretrain_model))
  
  #损失函数
  criterion = torch.nn.CrossEntropyLoss().cuda()
  
  #学习率
  lr = opt.lr
  
  #优化器
  optimizer = torch.optim.SGD(model.parameters(),lr = lr,weight_decay=opt.weight_decay)
  
  
  pre_loss = 0.0
  #训练
  for epoch in range(opt.max_epoch):
     #训练数据
    train_data = Fire_Unfire(opt.root_path +'/fire_train.pkl',opt.root_path +'/unfire_train.pkl')
    train_dataloader = data.DataLoader(train_data,batch_size=opt.batch_size,shuffle=True,drop_last = True)
    loss_sum = 0.0
    for i,(datas,labels) in enumerate(train_dataloader):
      #print(i,datas.size(),labels)
      #梯度清零
      optimizer.zero_grad()
      #输入
      input = V(datas.cuda()).float()
      #目标
      target = V(labels.cuda()).long()
      #输出
      score = model(input).cuda()
      #损失
      loss = criterion(score,target)
      loss_sum += loss
      #反向传播
      loss.backward()
      #梯度更新
      optimizer.step()      
    print('{}{}{}{}{}'.format('epoch:',epoch,',','loss:',loss))
    torch.save(model.state_dict(),'models/mobilenet_%d.pth'%(epoch+370))

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target'

解决方案:target = target.long()

以上这篇pytorch实现对输入超过三通道的数据进行训练就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python记录详细调用堆栈日志的方法
May 05 Python
分享Python文本生成二维码实例
Jan 06 Python
python3.5 tkinter实现页面跳转
Jan 30 Python
Python使用itertools模块实现排列组合功能示例
Jul 02 Python
Pycharm+Scrapy安装并且初始化项目的方法
Jan 15 Python
python实现微信定时每天和女友发送消息
Apr 29 Python
python队列Queue的详解
May 10 Python
python连接PostgreSQL过程解析
Feb 09 Python
探秘TensorFlow 和 NumPy 的 Broadcasting 机制
Mar 13 Python
Python xlwt模块使用代码实例
Jun 10 Python
python中通过pip安装库文件时出现“EnvironmentError: [WinError 5] 拒绝访问”的问题及解决方案
Aug 11 Python
python3获取控制台输入的数据的具体实例
Aug 16 Python
Pytorch 定义MyDatasets实现多通道分别输入不同数据方式
Jan 15 #Python
pytorch构建多模型实例
Jan 15 #Python
利用Pytorch实现简单的线性回归算法
Jan 15 #Python
pytorch实现线性拟合方式
Jan 15 #Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
You might like
php smarty的预保留变量总结
2008/12/04 PHP
php实现斐波那契数列的简单写法
2014/07/19 PHP
php实现的Timer页面运行时间监测类
2014/09/24 PHP
如何写php守护进程(Daemon)
2015/12/30 PHP
8款非常棒的响应式jQuery 幻灯片插件推荐
2012/02/02 Javascript
js jquery分别实现动态的文件上传操作按钮的添加和删除
2014/01/13 Javascript
JS判断字符串长度的5个方法(区分中文和英文)
2014/03/18 Javascript
原生js编写设为首页兼容ie、火狐和谷歌
2014/06/05 Javascript
connect中间件session、cookie的使用方法分享
2014/06/17 Javascript
js中数组排序sort方法的原理分析
2014/11/20 Javascript
GitHub上一些实用的JavaScript的文件压缩解压缩库推荐
2016/03/13 Javascript
nodejs入门教程一:概念与用法简介
2017/04/24 NodeJs
JavaScript实现的可变动态数字键盘控件方式实例代码
2017/07/15 Javascript
微信小程序多列选择器range-key使用详解
2020/03/30 Javascript
微信小程序实现打卡日历功能
2020/09/21 Javascript
Three.js中矩阵和向量的使用教程
2019/03/19 Javascript
JS数组方法reduce的用法实例分析
2020/03/03 Javascript
js实现简单选项卡制作
2020/08/05 Javascript
JavaScript实现滑块验证解锁
2021/01/07 Javascript
[38:41]2014 DOTA2国际邀请赛中国区预选赛 LGD VS CNB
2014/05/22 DOTA
Python基于QRCode实现生成二维码的方法【下载,安装,调用等】
2017/07/11 Python
Python程序运行原理图文解析
2018/02/10 Python
TensorFlow实现Softmax回归模型
2018/03/09 Python
tensorflow实现tensor中满足某一条件的数值取出组成新的tensor
2020/01/04 Python
python tkiner实现 一个小小的图片翻页功能的示例代码
2020/06/24 Python
python图片合成的示例
2020/11/09 Python
Python3压缩和解压缩实现代码
2021/03/01 Python
adidas澳大利亚官方网站:adidas Australia
2018/04/15 全球购物
香港莎莎官网Sasa.com:亚洲著名国际化妆品商城
2019/11/10 全球购物
大学生职业生涯规划范文
2014/01/22 职场文书
骨干教师考核方案
2014/05/09 职场文书
美术第二课堂活动总结
2014/07/08 职场文书
科级干部群众路线教育实践活动对照检查材料思想汇报
2014/09/20 职场文书
2015年机关党委工作总结
2015/05/23 职场文书
繁星春水读书笔记
2015/06/30 职场文书
开业庆典嘉宾致辞
2015/08/01 职场文书