使用Pytorch实现two-head(多输出)模型的操作


Posted in Python onMay 28, 2021

如何使用Pytorch实现two-head(多输出)模型

1. two-head模型定义

先放一张我要实现的模型结构图:

使用Pytorch实现two-head(多输出)模型的操作

如上图,就是一个two-head模型,也是一个但输入多输出模型。该模型的特点是输入一个x和一个t,h0和h1中只有一个会输出,所以可能这不算是一个典型的多输出模型。

2.实现所遇到的困难 一开始的想法:

这不是很简单嘛,做一个判断不就完了,t=0时模型为前半段加h0,t=1时模型为前半段加h1。但实现的时候傻眼了,发现在真正前向传播的时候t是一个tensor,有0有1,没法儿进行判断。

灵机一动,又生一法:把这个模型变为三个模型,前半段是一个模型(r),后面的h0和h1分别为另两个模型。把数据集按t=0和1分开,分别训练两个模型:r+h0和r+h1。

但是后来搜如何进行模型串联,发现极为麻烦。

3.解决方案

后来在pytorch的官方社区中看到一个极为简单的方法:

(1) 按照一般的多输出模型进行实现,代码如下:

def forward(self, x):
        #三层的表示层
        x = F.elu(self.fcR1(x))
        x = F.elu(self.fcR2(x))
        x = F.elu(self.fcR3(x))
		#two-head,两个head分别进行输出
        y0 = F.elu(self.fcH01(x))
        y0 = F.elu(self.fcH02(y0))
        y0 = F.elu(self.fcH03(y0))
        y1 = F.elu(self.fcH11(x))
        y1 = F.elu(self.fcH12(y1))
        y1 = F.elu(self.fcH13(y1))
        return y0, y1

这样就相当实现了一个多输出模型,一个x同时输出y0和y1.

训练的时候分别训练,也即分别建loss,代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

先把x按t=0和t=1分为x0和x1,然后分别送入进行训练。这样就实现了一个two-head模型。

4.后记

我自以为多输出模型可以分为以下两类:

多个输出不同时获得,如本文情况。

多个输出同时获得。

多输出不同时获得的解决方法上文已说明。多输出同时获得则可以通过把y0和y1拼接起来一起输出来实现。

补充:PyTorch 多输入多输出模型构建

本篇教程基于 PyTorch 1.5版本

直接上代码!

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.distributed as dist
import torch.utils.data as data_utils
class Net(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden1 = nn.Linear(n_input, n_hidden)
        self.hidden2 = nn.Linear(n_hidden, n_hidden)
        self.predict1 = nn.Linear(n_hidden*2, n_output)
        self.predict2 = nn.Linear(n_hidden*2, n_output)
    def forward(self, input1, input2): # 多输入!!!
        out01 = self.hidden1(input1)
        out02 = torch.relu(out01)
        out03 = self.hidden2(out02)
        out04 = torch.sigmoid(out03)
        out11 = self.hidden1(input2)
        out12 = torch.relu(out11)
        out13 = self.hidden2(out12)
        out14 = torch.sigmoid(out13)
        out = torch.cat((out04, out14), dim=1) # 模型层拼合!!!当然你的模型中可能不需要~
 
        out1 = self.predict1(out)
        out2 = self.predict2(out)
        return out1, out2 # 多输出!!!
net = Net(1, 20, 1)
x1 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 请不要关心这里,随便弄一个数据,为了说明问题而已
y1 = x1.pow(3)+0.1*torch.randn(x1.size())
x2 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y2 = x2.pow(3)+0.1*torch.randn(x2.size())
x1, y1 = (Variable(x1), Variable(y1))
x2, y2 = (Variable(x2), Variable(y2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
for t in range(5000):
    prediction1, prediction2 = net(x1, x2)
    loss1 = loss_func(prediction1, y1)
    loss2 = loss_func(prediction2, y2)
    loss = loss1 + loss2 # 重点!
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 100 == 0:
       print('Loss1 = %.4f' % loss1.data,'Loss2 = %.4f' % loss2.data,)

至此搞定!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
极简的Python入门指引
Apr 01 Python
在Python的Django框架中获取单个对象数据的简单方法
Jul 17 Python
Python中的super()方法使用简介
Aug 14 Python
基于python的字节编译详解
Sep 20 Python
Python实现在某个数组中查找一个值的算法示例
Jun 27 Python
Python数据类型之Tuple元组实例详解
May 08 Python
python 命令行传入参数实现解析
Aug 30 Python
PyCharm第一次安装及使用教程
Jan 08 Python
使用pytorch完成kaggle猫狗图像识别方式
Jan 10 Python
opencv python在视屏上截图功能的实现
Mar 05 Python
Python numpy多维数组实现原理详解
Mar 10 Python
Python Socket TCP双端聊天功能实现过程详解
Jun 15 Python
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
Python趣味爬虫之用Python实现智慧校园一键评教
Pytorch 如何加速Dataloader提升数据读取速度
You might like
php中函数的形参与实参的问题说明
2010/09/01 PHP
PHP修改session_id示例代码
2014/01/08 PHP
PHP之uniqid()函数用法
2014/11/03 PHP
PHP面向对象之工作单元(实例讲解)
2017/06/26 PHP
PHP实现简单登录界面
2019/10/23 PHP
简单的js分页脚本
2009/05/21 Javascript
JQuery UI DatePicker中z-index默认为1的解决办法
2010/09/28 Javascript
JQuery中如何传递参数如click(),change()等具体实现
2013/04/28 Javascript
js中的this关键字详解
2013/09/25 Javascript
JavaScript+html5 canvas绘制的小人效果
2016/01/27 Javascript
深入理解Ajax的get和post请求
2016/06/02 Javascript
BootStrap实现树形目录组件代码详解
2016/06/21 Javascript
JavaScript实现刷新不重记的倒计时
2016/08/10 Javascript
基于JS实现仿百度百家主页的轮播图效果
2017/03/06 Javascript
ZeroClipboard.js使用一个flash复制多个文本框
2017/06/19 Javascript
JS图片轮播与索引变色功能实例详解
2017/07/06 Javascript
基于Bootstrap下拉框插件bootstrap-select使用方法详解
2018/08/07 Javascript
nuxt踩坑之Vuex状态树的模块方式使用详解
2019/09/06 Javascript
JavaScript观察者模式原理与用法实例详解
2020/03/10 Javascript
解决vue字符串换行问题(绝对管用)
2020/08/06 Javascript
[46:48]DOTA2上海特级锦标赛A组小组赛#2 Secret VS CDEC第三局
2016/02/25 DOTA
快速了解Python开发中的cookie及简单代码示例
2018/01/17 Python
Python cookbook(数据结构与算法)实现查找两个字典相同点的方法
2018/02/18 Python
python实现简单日志记录库glog的使用
2019/12/13 Python
Python 数据的累加与统计的示例代码
2020/08/03 Python
CSS3实现3D翻书效果
2016/06/20 HTML / CSS
伦敦香水公司:The London Perfume Company
2019/11/13 全球购物
同步和异步有何异同,在什么情况下分别使用他们?举例说明
2014/02/27 面试题
《盘古开天地》教学反思
2014/02/28 职场文书
小学生常见病防治方案
2014/06/06 职场文书
材料物理专业求职信
2014/09/01 职场文书
群众路线自我剖析及整改措施
2014/11/04 职场文书
雨中的树观后感
2015/06/03 职场文书
导游词之广西漓江
2019/11/02 职场文书
如何在CSS中绘制曲线图形及展示动画
2021/05/24 HTML / CSS
Redis 彻底禁用RDB持久化操作
2021/07/09 Redis