使用Pytorch实现two-head(多输出)模型的操作


Posted in Python onMay 28, 2021

如何使用Pytorch实现two-head(多输出)模型

1. two-head模型定义

先放一张我要实现的模型结构图:

使用Pytorch实现two-head(多输出)模型的操作

如上图,就是一个two-head模型,也是一个但输入多输出模型。该模型的特点是输入一个x和一个t,h0和h1中只有一个会输出,所以可能这不算是一个典型的多输出模型。

2.实现所遇到的困难 一开始的想法:

这不是很简单嘛,做一个判断不就完了,t=0时模型为前半段加h0,t=1时模型为前半段加h1。但实现的时候傻眼了,发现在真正前向传播的时候t是一个tensor,有0有1,没法儿进行判断。

灵机一动,又生一法:把这个模型变为三个模型,前半段是一个模型(r),后面的h0和h1分别为另两个模型。把数据集按t=0和1分开,分别训练两个模型:r+h0和r+h1。

但是后来搜如何进行模型串联,发现极为麻烦。

3.解决方案

后来在pytorch的官方社区中看到一个极为简单的方法:

(1) 按照一般的多输出模型进行实现,代码如下:

def forward(self, x):
        #三层的表示层
        x = F.elu(self.fcR1(x))
        x = F.elu(self.fcR2(x))
        x = F.elu(self.fcR3(x))
		#two-head,两个head分别进行输出
        y0 = F.elu(self.fcH01(x))
        y0 = F.elu(self.fcH02(y0))
        y0 = F.elu(self.fcH03(y0))
        y1 = F.elu(self.fcH11(x))
        y1 = F.elu(self.fcH12(y1))
        y1 = F.elu(self.fcH13(y1))
        return y0, y1

这样就相当实现了一个多输出模型,一个x同时输出y0和y1.

训练的时候分别训练,也即分别建loss,代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

先把x按t=0和t=1分为x0和x1,然后分别送入进行训练。这样就实现了一个two-head模型。

4.后记

我自以为多输出模型可以分为以下两类:

多个输出不同时获得,如本文情况。

多个输出同时获得。

多输出不同时获得的解决方法上文已说明。多输出同时获得则可以通过把y0和y1拼接起来一起输出来实现。

补充:PyTorch 多输入多输出模型构建

本篇教程基于 PyTorch 1.5版本

直接上代码!

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.distributed as dist
import torch.utils.data as data_utils
class Net(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden1 = nn.Linear(n_input, n_hidden)
        self.hidden2 = nn.Linear(n_hidden, n_hidden)
        self.predict1 = nn.Linear(n_hidden*2, n_output)
        self.predict2 = nn.Linear(n_hidden*2, n_output)
    def forward(self, input1, input2): # 多输入!!!
        out01 = self.hidden1(input1)
        out02 = torch.relu(out01)
        out03 = self.hidden2(out02)
        out04 = torch.sigmoid(out03)
        out11 = self.hidden1(input2)
        out12 = torch.relu(out11)
        out13 = self.hidden2(out12)
        out14 = torch.sigmoid(out13)
        out = torch.cat((out04, out14), dim=1) # 模型层拼合!!!当然你的模型中可能不需要~
 
        out1 = self.predict1(out)
        out2 = self.predict2(out)
        return out1, out2 # 多输出!!!
net = Net(1, 20, 1)
x1 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 请不要关心这里,随便弄一个数据,为了说明问题而已
y1 = x1.pow(3)+0.1*torch.randn(x1.size())
x2 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y2 = x2.pow(3)+0.1*torch.randn(x2.size())
x1, y1 = (Variable(x1), Variable(y1))
x2, y2 = (Variable(x2), Variable(y2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
for t in range(5000):
    prediction1, prediction2 = net(x1, x2)
    loss1 = loss_func(prediction1, y1)
    loss2 = loss_func(prediction2, y2)
    loss = loss1 + loss2 # 重点!
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 100 == 0:
       print('Loss1 = %.4f' % loss1.data,'Loss2 = %.4f' % loss2.data,)

至此搞定!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python列表的常用操作方法小结
May 21 Python
Python多进程multiprocessing.Pool类详解
Apr 27 Python
selenium+python自动化测试之使用webdriver操作浏览器的方法
Jan 23 Python
Python3.5 Pandas模块缺失值处理和层次索引实例详解
Apr 23 Python
pandas实现将dataframe满足某一条件的值选出
Jun 12 Python
Python从列表推导到zip()函数的5种技巧总结
Oct 23 Python
pyinstaller打包找不到文件的问题解决
Apr 15 Python
解决jupyter运行pyqt代码内核重启的问题
Apr 16 Python
Python实现画图软件功能方法详解
Jul 28 Python
python opencv图像处理(素描、怀旧、光照、流年、滤镜 原理及实现)
Dec 10 Python
python 定义函数 返回值只取其中一个的实现
May 21 Python
python实现对doc、txt、xls等文档的读写操作
Apr 02 Python
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
Python趣味爬虫之用Python实现智慧校园一键评教
Pytorch 如何加速Dataloader提升数据读取速度
You might like
WML,Apache,和 PHP 的介绍
2006/10/09 PHP
php调用mysql数据 dbclass类
2011/05/07 PHP
php模仿asp Application对象在线人数统计实现方法
2015/01/04 PHP
Yii+upload实现AJAX上传图片的方法
2016/07/13 PHP
解决Laravel5.x的php artisan migrate数据库迁移创建操作报错SQLSTATE[42000]
2020/04/06 PHP
提高网站信任度的技巧
2008/10/17 Javascript
Window.Open如何在同一个标签页打开
2014/06/20 Javascript
javascript中hasOwnProperty() 方法使用指南
2015/03/09 Javascript
javascript实现根据3原色制作颜色选择器的方法
2015/07/17 Javascript
JavaScript实现节点的删除与序号重建实例
2015/08/05 Javascript
js实现表单Radio切换效果的方法
2015/08/17 Javascript
easyui Draggable组件实现拖动效果
2015/08/19 Javascript
jQuery实现的网格线绘制方法
2016/06/20 Javascript
js 动态添加元素(div、li、img等)及设置属性的方法
2016/07/19 Javascript
判断数组的最佳方法(推荐)
2016/10/11 Javascript
js 提交form表单和设置form表单请求路径的实现方法
2016/10/25 Javascript
jQuery实现遍历复选框的方法示例
2017/03/06 Javascript
JS中使用gulp实现压缩文件及浏览器热加载功能
2017/07/12 Javascript
jquery 获取索引值在一定范围的列表方法
2018/01/25 jQuery
webpack中的热刷新与热加载的区别
2018/04/09 Javascript
在vue-cli的组件模板里使用font-awesome的两种方法
2018/09/28 Javascript
详解Vue中的基本语法和常用指令
2019/07/23 Javascript
使用vuex存储用户信息到localStorage的实例
2019/11/11 Javascript
Vue中通过vue-router实现命名视图的问题
2020/04/23 Javascript
python实现根据月份和日期得到星座的方法
2015/03/27 Python
Python进程通信之匿名管道实例讲解
2015/04/11 Python
Python3实现的字典、列表和json对象互转功能示例
2018/05/22 Python
python try except 捕获所有异常的实例
2018/10/18 Python
Python根据成绩分析系统浅析
2019/02/11 Python
大学活动策划书范文
2014/01/10 职场文书
户外婚礼策划方案
2014/02/08 职场文书
捐书活动总结
2014/05/04 职场文书
诉前财产保全担保书
2014/05/20 职场文书
国庆促销活动总结
2014/08/29 职场文书
优质护理心得体会
2016/01/22 职场文书
Java界面编程实现界面跳转
2022/06/16 Java/Android