pytorch实现对输入超过三通道的数据进行训练


Posted in Python onJanuary 15, 2020

案例背景:视频识别

假设每次输入是8s的灰度视频,视频帧率为25fps,则视频由200帧图像序列构成.每帧是一副单通道的灰度图像,通过pythonb里面的np.stack(深度拼接)可将200帧拼接成200通道的深度数据.进而送到网络里面去训练.

如果输入图像200通道觉得多,可以对视频进行抽帧,针对具体场景可以随机抽帧或等间隔抽帧.比如这里等间隔抽取40帧.则最后输入视频相当于输入一个40通道的图像数据了.

pytorch对超过三通道数据的加载:

读取视频每一帧,转为array格式,然后依次将每一帧进行深度拼接,最后得到一个40通道的array格式的深度数据,保存到pickle里.

对每个视频都进行上述操作,保存到pickle里.

我这里将火的视频深度数据保存在一个.pkl文件中,一共2504个火的视频,即2504个火的深度数据.

将非火的视频深度数据保存在一个.pkl文件中,一共3985个非火的视频,即3985个非火的深度数据.

数据加载

import torch 
from torch.utils import data
import os
from PIL import Image
import numpy as np
import pickle
 
class Fire_Unfire(data.Dataset):
  def __init__(self,fire_path,unfire_path):
    self.pickle_fire = open(fire_path,'rb')
    self.pickle_unfire = open(unfire_path,'rb')
    
  def __getitem__(self,index):
    if index <2504:
      fire = pickle.load(self.pickle_fire)#高*宽*通道
      fire = fire.transpose(2,0,1)#通道*高*宽
      data = torch.from_numpy(fire)
      label = 1
      return data,label
    elif index>=2504 and index<6489:
      unfire = pickle.load(self.pickle_unfire)
      unfire = unfire.transpose(2,0,1)
      data = torch.from_numpy(unfire)
      label = 0
      return data,label
    
  def __len__(self):
    return 6489
root_path = './datasets/train'
dataset = Fire_Unfire(root_path +'/fire_train.pkl',root_path +'/unfire_train.pkl')
 
#转换成pytorch网络输入的格式(批量大小,通道数,高,宽)
from torch.utils.data import DataLoader
fire_dataloader = DataLoader(dataset,batch_size=4,shuffle=True,drop_last = True)

模型训练

import torch
from torch.utils import data
from nets.mobilenet import mobilenet
from config.config import default_config
from torch.autograd import Variable as V
import numpy as np
import sys
import time
 
opt = default_config()
def train():
  #模型定义
  model = mobilenet().cuda()
  if opt.pretrain_model:
    model.load_state_dict(torch.load(opt.pretrain_model))
  
  #损失函数
  criterion = torch.nn.CrossEntropyLoss().cuda()
  
  #学习率
  lr = opt.lr
  
  #优化器
  optimizer = torch.optim.SGD(model.parameters(),lr = lr,weight_decay=opt.weight_decay)
  
  
  pre_loss = 0.0
  #训练
  for epoch in range(opt.max_epoch):
     #训练数据
    train_data = Fire_Unfire(opt.root_path +'/fire_train.pkl',opt.root_path +'/unfire_train.pkl')
    train_dataloader = data.DataLoader(train_data,batch_size=opt.batch_size,shuffle=True,drop_last = True)
    loss_sum = 0.0
    for i,(datas,labels) in enumerate(train_dataloader):
      #print(i,datas.size(),labels)
      #梯度清零
      optimizer.zero_grad()
      #输入
      input = V(datas.cuda()).float()
      #目标
      target = V(labels.cuda()).long()
      #输出
      score = model(input).cuda()
      #损失
      loss = criterion(score,target)
      loss_sum += loss
      #反向传播
      loss.backward()
      #梯度更新
      optimizer.step()      
    print('{}{}{}{}{}'.format('epoch:',epoch,',','loss:',loss))
    torch.save(model.state_dict(),'models/mobilenet_%d.pth'%(epoch+370))

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target'

解决方案:target = target.long()

以上这篇pytorch实现对输入超过三通道的数据进行训练就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python标准库之随机数 (math包、random包)介绍
Nov 25 Python
在Python的框架中为MySQL实现restful接口的教程
Apr 08 Python
Python自动生产表情包
Mar 17 Python
详解python中executemany和序列的使用方法
Aug 12 Python
解决python3 安装完Pycurl在import pycurl时报错的问题
Oct 15 Python
python实现归并排序算法
Nov 22 Python
Python 实现交换矩阵的行示例
Jun 26 Python
Python实现代码统计工具
Sep 19 Python
Python 实现打印单词的菱形字符图案
Apr 12 Python
PyCharm配置anaconda环境的步骤详解
Jul 31 Python
使用python向MongoDB插入时间字段的操作
May 18 Python
深入理解Pytorch微调torchvision模型
Nov 11 Python
Pytorch 定义MyDatasets实现多通道分别输入不同数据方式
Jan 15 #Python
pytorch构建多模型实例
Jan 15 #Python
利用Pytorch实现简单的线性回归算法
Jan 15 #Python
pytorch实现线性拟合方式
Jan 15 #Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
You might like
回首过去10年中最搞笑的10部动漫,哪一部让你节操尽碎?
2020/03/03 日漫
php 数组排序 array_multisort与uasort的区别
2011/03/24 PHP
PHP导入Excel到MySQL的方法
2011/04/23 PHP
PHP表单提交表单名称含有点号(.)则会被转化为下划线(_)
2011/12/14 PHP
PHP中mb_convert_encoding与iconv函数的深入解析
2013/06/21 PHP
PHP结合Ffmpeg快速搭建流媒体服务的实践记录
2018/10/31 PHP
PHP+Mysql分布式事务与解决方案深入理解
2021/02/27 PHP
jQuery 获取URL参数的插件
2010/03/04 Javascript
Jquery 在页面加载后执行的几种方式
2014/03/14 Javascript
JQuery 图片滚动轮播示例代码
2014/03/24 Javascript
js使用心得分享
2015/01/13 Javascript
JQuery的ON()方法支持的所有事件罗列
2015/02/28 Javascript
window.open()实现post传递参数
2015/03/12 Javascript
js控制div弹出层实现方法
2015/05/11 Javascript
jQuery计算文本框字数及限制文本框字数的方法
2016/03/01 Javascript
详细总结Javascript中的焦点管理
2016/09/17 Javascript
Angular2 PrimeNG分页模块学习
2017/01/14 Javascript
Javascript 一些需要注意的细节(必看篇)
2017/07/08 Javascript
vue单页应用中如何使用jquery的方法示例
2017/07/27 jQuery
纯JS实现可用于页码更换的飞页特效示例
2018/05/21 Javascript
Vue $emit $refs子父组件间方法的调用实例
2018/09/12 Javascript
详解Vue前端对axios的封装和使用
2019/04/01 Javascript
layer的prompt弹出框,点击回车,触发确定事件的方法
2019/09/06 Javascript
javascript的惯性运动实现代码实例
2019/09/07 Javascript
Python的Tornado框架异步编程入门实例
2015/04/24 Python
python requests 库请求带有文件参数的接口实例
2019/01/03 Python
python如何实现异步调用函数执行
2019/07/08 Python
Python socket 套接字实现通信详解
2019/08/27 Python
Python调用Windows命令打印文件
2020/02/07 Python
python 深度学习中的4种激活函数
2020/09/18 Python
优秀教师单行材料
2014/12/16 职场文书
2015年转正工作总结范文
2015/04/02 职场文书
违纪开除通知书
2015/04/25 职场文书
交通安全温馨提示语
2015/07/14 职场文书
MySQL优化及索引解析
2022/03/17 MySQL
你真的会用Mysql的explain吗
2022/03/31 MySQL