使用Pytorch实现two-head(多输出)模型的操作


Posted in Python onMay 28, 2021

如何使用Pytorch实现two-head(多输出)模型

1. two-head模型定义

先放一张我要实现的模型结构图:

使用Pytorch实现two-head(多输出)模型的操作

如上图,就是一个two-head模型,也是一个但输入多输出模型。该模型的特点是输入一个x和一个t,h0和h1中只有一个会输出,所以可能这不算是一个典型的多输出模型。

2.实现所遇到的困难 一开始的想法:

这不是很简单嘛,做一个判断不就完了,t=0时模型为前半段加h0,t=1时模型为前半段加h1。但实现的时候傻眼了,发现在真正前向传播的时候t是一个tensor,有0有1,没法儿进行判断。

灵机一动,又生一法:把这个模型变为三个模型,前半段是一个模型(r),后面的h0和h1分别为另两个模型。把数据集按t=0和1分开,分别训练两个模型:r+h0和r+h1。

但是后来搜如何进行模型串联,发现极为麻烦。

3.解决方案

后来在pytorch的官方社区中看到一个极为简单的方法:

(1) 按照一般的多输出模型进行实现,代码如下:

def forward(self, x):
        #三层的表示层
        x = F.elu(self.fcR1(x))
        x = F.elu(self.fcR2(x))
        x = F.elu(self.fcR3(x))
		#two-head,两个head分别进行输出
        y0 = F.elu(self.fcH01(x))
        y0 = F.elu(self.fcH02(y0))
        y0 = F.elu(self.fcH03(y0))
        y1 = F.elu(self.fcH11(x))
        y1 = F.elu(self.fcH12(y1))
        y1 = F.elu(self.fcH13(y1))
        return y0, y1

这样就相当实现了一个多输出模型,一个x同时输出y0和y1.

训练的时候分别训练,也即分别建loss,代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

先把x按t=0和t=1分为x0和x1,然后分别送入进行训练。这样就实现了一个two-head模型。

4.后记

我自以为多输出模型可以分为以下两类:

多个输出不同时获得,如本文情况。

多个输出同时获得。

多输出不同时获得的解决方法上文已说明。多输出同时获得则可以通过把y0和y1拼接起来一起输出来实现。

补充:PyTorch 多输入多输出模型构建

本篇教程基于 PyTorch 1.5版本

直接上代码!

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.distributed as dist
import torch.utils.data as data_utils
class Net(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden1 = nn.Linear(n_input, n_hidden)
        self.hidden2 = nn.Linear(n_hidden, n_hidden)
        self.predict1 = nn.Linear(n_hidden*2, n_output)
        self.predict2 = nn.Linear(n_hidden*2, n_output)
    def forward(self, input1, input2): # 多输入!!!
        out01 = self.hidden1(input1)
        out02 = torch.relu(out01)
        out03 = self.hidden2(out02)
        out04 = torch.sigmoid(out03)
        out11 = self.hidden1(input2)
        out12 = torch.relu(out11)
        out13 = self.hidden2(out12)
        out14 = torch.sigmoid(out13)
        out = torch.cat((out04, out14), dim=1) # 模型层拼合!!!当然你的模型中可能不需要~
 
        out1 = self.predict1(out)
        out2 = self.predict2(out)
        return out1, out2 # 多输出!!!
net = Net(1, 20, 1)
x1 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 请不要关心这里,随便弄一个数据,为了说明问题而已
y1 = x1.pow(3)+0.1*torch.randn(x1.size())
x2 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y2 = x2.pow(3)+0.1*torch.randn(x2.size())
x1, y1 = (Variable(x1), Variable(y1))
x2, y2 = (Variable(x2), Variable(y2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
for t in range(5000):
    prediction1, prediction2 = net(x1, x2)
    loss1 = loss_func(prediction1, y1)
    loss2 = loss_func(prediction2, y2)
    loss = loss1 + loss2 # 重点!
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 100 == 0:
       print('Loss1 = %.4f' % loss1.data,'Loss2 = %.4f' % loss2.data,)

至此搞定!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
400多行Python代码实现了一个FTP服务器
May 10 Python
python计算圆周长、面积、球体体积并画出圆
Apr 08 Python
Python基于checksum计算文件是否相同的方法
Jul 09 Python
Python编程实现两个文件夹里文件的对比功能示例【包含内容的对比】
Jun 20 Python
Python书单 不将就
Jul 11 Python
python利用高阶函数实现剪枝函数
Mar 20 Python
Flask核心机制之上下文源码剖析
Dec 25 Python
pycharm配置pyqt5-tools开发环境的方法步骤
Feb 11 Python
浅析PyTorch中nn.Module的使用
Aug 18 Python
Python threading的使用方法解析
Aug 28 Python
pytorch torch.nn.AdaptiveAvgPool2d()自适应平均池化函数详解
Jan 03 Python
Tensorflow矩阵运算实例(矩阵相乘,点乘,行/列累加)
Feb 05 Python
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
Python趣味爬虫之用Python实现智慧校园一键评教
Pytorch 如何加速Dataloader提升数据读取速度
You might like
基于Windows下Apache PHP5.3.1安装教程
2010/01/08 PHP
《PHP编程最快明白》第七讲:php图片验证码与缩略图
2010/11/01 PHP
打造超酷的PHP数据饼图效果实现代码
2011/11/23 PHP
在PHP中设置、使用、删除Cookie的解决方法
2013/05/06 PHP
php多文件上传实现代码
2014/02/20 PHP
php添加数据到xml文件的简单例子
2016/09/08 PHP
JavaScript RegExp方法获取地址栏参数(面向对象)
2009/03/10 Javascript
javascript获取隐藏dom的宽高 具体实现
2013/07/14 Javascript
处理文本部分内容的TextRange对象应用实例
2014/07/29 Javascript
javascript转换日期字符串为Date日期对象的方法
2015/02/13 Javascript
Nodejs Stream 数据流使用手册
2016/04/17 NodeJs
利用javascript实现的三种图片放大镜效果实例(附源码)
2017/01/23 Javascript
JS 组件系列之Bootstrap Table 冻结列功能IE浏览器兼容性问题解决方案
2017/06/30 Javascript
vue使用axios跨域请求数据问题详解
2017/10/18 Javascript
微信小程序实现倒计时补零功能
2018/07/09 Javascript
vue模仿网易云音乐的单页面应用
2019/04/24 Javascript
vue实现全匹配搜索列表内容
2019/09/26 Javascript
vue 路由守卫(导航守卫)及其具体使用
2020/02/25 Javascript
详解Vue 数据更新了但页面没有更新的 7 种情况汇总及延伸总结
2020/05/28 Javascript
[37:45]2014 DOTA2国际邀请赛中国区预选赛5.21 DT VS Orenda
2014/05/22 DOTA
python写的一个文本编辑器
2014/01/23 Python
python实现根据月份和日期得到星座的方法
2015/03/27 Python
详解python中的线程
2018/02/10 Python
python 多线程将大文件分开下载后在合并的实例
2018/11/09 Python
python 提取tuple类型值中json格式的key值方法
2018/12/31 Python
Python同步遍历多个列表的示例
2019/02/19 Python
Python xpath表达式如何实现数据处理
2020/06/13 Python
python爬取”顶点小说网“《纯阳剑尊》的示例代码
2020/10/16 Python
基于HTML5 audio元素播放声音jQuery小插件
2011/05/11 HTML / CSS
巴西宠物店在线:Geração Pet
2017/05/31 全球购物
九州传奇上机题
2014/07/10 面试题
资产经营总监岗位职责
2013/12/04 职场文书
史学专业毕业生求职信
2014/05/09 职场文书
高中课程设置方案
2014/05/28 职场文书
2015年小班保育员工作总结
2015/05/27 职场文书
python中的plt.cm.Paired用法说明
2021/05/31 Python