pytorch实现对输入超过三通道的数据进行训练


Posted in Python onJanuary 15, 2020

案例背景:视频识别

假设每次输入是8s的灰度视频,视频帧率为25fps,则视频由200帧图像序列构成.每帧是一副单通道的灰度图像,通过pythonb里面的np.stack(深度拼接)可将200帧拼接成200通道的深度数据.进而送到网络里面去训练.

如果输入图像200通道觉得多,可以对视频进行抽帧,针对具体场景可以随机抽帧或等间隔抽帧.比如这里等间隔抽取40帧.则最后输入视频相当于输入一个40通道的图像数据了.

pytorch对超过三通道数据的加载:

读取视频每一帧,转为array格式,然后依次将每一帧进行深度拼接,最后得到一个40通道的array格式的深度数据,保存到pickle里.

对每个视频都进行上述操作,保存到pickle里.

我这里将火的视频深度数据保存在一个.pkl文件中,一共2504个火的视频,即2504个火的深度数据.

将非火的视频深度数据保存在一个.pkl文件中,一共3985个非火的视频,即3985个非火的深度数据.

数据加载

import torch 
from torch.utils import data
import os
from PIL import Image
import numpy as np
import pickle
 
class Fire_Unfire(data.Dataset):
  def __init__(self,fire_path,unfire_path):
    self.pickle_fire = open(fire_path,'rb')
    self.pickle_unfire = open(unfire_path,'rb')
    
  def __getitem__(self,index):
    if index <2504:
      fire = pickle.load(self.pickle_fire)#高*宽*通道
      fire = fire.transpose(2,0,1)#通道*高*宽
      data = torch.from_numpy(fire)
      label = 1
      return data,label
    elif index>=2504 and index<6489:
      unfire = pickle.load(self.pickle_unfire)
      unfire = unfire.transpose(2,0,1)
      data = torch.from_numpy(unfire)
      label = 0
      return data,label
    
  def __len__(self):
    return 6489
root_path = './datasets/train'
dataset = Fire_Unfire(root_path +'/fire_train.pkl',root_path +'/unfire_train.pkl')
 
#转换成pytorch网络输入的格式(批量大小,通道数,高,宽)
from torch.utils.data import DataLoader
fire_dataloader = DataLoader(dataset,batch_size=4,shuffle=True,drop_last = True)

模型训练

import torch
from torch.utils import data
from nets.mobilenet import mobilenet
from config.config import default_config
from torch.autograd import Variable as V
import numpy as np
import sys
import time
 
opt = default_config()
def train():
  #模型定义
  model = mobilenet().cuda()
  if opt.pretrain_model:
    model.load_state_dict(torch.load(opt.pretrain_model))
  
  #损失函数
  criterion = torch.nn.CrossEntropyLoss().cuda()
  
  #学习率
  lr = opt.lr
  
  #优化器
  optimizer = torch.optim.SGD(model.parameters(),lr = lr,weight_decay=opt.weight_decay)
  
  
  pre_loss = 0.0
  #训练
  for epoch in range(opt.max_epoch):
     #训练数据
    train_data = Fire_Unfire(opt.root_path +'/fire_train.pkl',opt.root_path +'/unfire_train.pkl')
    train_dataloader = data.DataLoader(train_data,batch_size=opt.batch_size,shuffle=True,drop_last = True)
    loss_sum = 0.0
    for i,(datas,labels) in enumerate(train_dataloader):
      #print(i,datas.size(),labels)
      #梯度清零
      optimizer.zero_grad()
      #输入
      input = V(datas.cuda()).float()
      #目标
      target = V(labels.cuda()).long()
      #输出
      score = model(input).cuda()
      #损失
      loss = criterion(score,target)
      loss_sum += loss
      #反向传播
      loss.backward()
      #梯度更新
      optimizer.step()      
    print('{}{}{}{}{}'.format('epoch:',epoch,',','loss:',loss))
    torch.save(model.state_dict(),'models/mobilenet_%d.pth'%(epoch+370))

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target'

解决方案:target = target.long()

以上这篇pytorch实现对输入超过三通道的数据进行训练就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中查找excel某一列的重复数据 剔除之后打印
Feb 10 Python
TensorFlow深度学习之卷积神经网络CNN
Mar 09 Python
django加载本地html的方法
May 27 Python
python 自定义异常和异常捕捉的方法
Oct 18 Python
python支付宝支付示例详解
Aug 22 Python
Python实现微信中找回好友、群聊用户撤回的消息功能示例
Aug 23 Python
Python 生成一个从0到n个数字的列表4种方法小结
Nov 28 Python
Python 元组拆包示例(Tuple Unpacking)
Dec 24 Python
Pycharm中import torch报错的快速解决方法
Mar 05 Python
pycharm如何使用anaconda中的各种包(操作步骤)
Jul 31 Python
Python Selenium库的基本使用教程
Jan 04 Python
详解Python牛顿插值法
May 11 Python
Pytorch 定义MyDatasets实现多通道分别输入不同数据方式
Jan 15 #Python
pytorch构建多模型实例
Jan 15 #Python
利用Pytorch实现简单的线性回归算法
Jan 15 #Python
pytorch实现线性拟合方式
Jan 15 #Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
You might like
如何实现给定日期的若干天以后的日期
2006/10/09 PHP
php与php MySQL 之间的关系
2009/07/17 PHP
ThinkPHP的URL重写问题
2014/06/22 PHP
CakePHP框架Session设置方法分析
2017/02/23 PHP
php 处理png图片白色背景色改为透明色的实例代码
2018/12/10 PHP
laravel中数据显示方法(默认值和下拉option默认选中)
2019/10/11 PHP
Thinkphp5框架中引入Markdown编辑器操作示例
2020/06/03 PHP
IE与Firefox下javascript getyear年份的兼容性写法
2007/12/20 Javascript
有趣的javascript数组定义方法
2010/09/10 Javascript
Jquery easyui 实现动态树
2015/11/17 Javascript
js判断输入字符串是否为空、空格、null的方法总结
2016/06/14 Javascript
基于jQuery选择器之表单对象属性筛选选择器的实例
2017/09/19 jQuery
nodejs中Express与Koa2对比分析
2018/02/06 NodeJs
JS实现可以用键盘方向键控制的动画
2020/12/11 Javascript
Python笔记(叁)继续学习
2012/10/24 Python
完美解决python遍历删除字典里值为空的元素报错问题
2016/09/11 Python
Python 递归函数详解及实例
2016/12/27 Python
python tkinter界面居中显示的方法
2018/10/11 Python
uwsgi+nginx部署Django项目操作示例
2018/12/04 Python
Python实现分段线性插值
2018/12/17 Python
对Python 两大环境管理神器 pyenv 和 virtualenv详解
2018/12/31 Python
python3中类的继承以及self和super的区别详解
2019/06/26 Python
Python 基于wxpy库实现微信添加好友功能(简洁)
2019/11/29 Python
Python3.8.2安装包及安装教程图文详解(附安装包)
2020/11/28 Python
1688平价精选商城:阿里集团旗下,工厂出厂价格直销
2017/04/24 全球购物
ECCO爱步官方旗舰店:丹麦鞋履品牌
2018/01/02 全球购物
Moss Bros官网:英国排名第一的西装店
2020/02/26 全球购物
编写一子程序,将一链表倒序,即使链表表尾变表头,表头变表尾
2016/02/10 面试题
过程装备与控制工程专业求职信
2014/07/02 职场文书
关于有小孩的离婚协议书
2014/10/26 职场文书
文明班级申报材料
2014/12/24 职场文书
春季运动会加油词
2015/07/18 职场文书
推普标语口号大全
2015/12/26 职场文书
如何撰写创业策划书
2019/06/27 职场文书
如何用python插入独创性声明
2021/03/31 Python
Html5调用企业微信的实现
2021/04/16 HTML / CSS