使用Pytorch实现two-head(多输出)模型的操作


Posted in Python onMay 28, 2021

如何使用Pytorch实现two-head(多输出)模型

1. two-head模型定义

先放一张我要实现的模型结构图:

使用Pytorch实现two-head(多输出)模型的操作

如上图,就是一个two-head模型,也是一个但输入多输出模型。该模型的特点是输入一个x和一个t,h0和h1中只有一个会输出,所以可能这不算是一个典型的多输出模型。

2.实现所遇到的困难 一开始的想法:

这不是很简单嘛,做一个判断不就完了,t=0时模型为前半段加h0,t=1时模型为前半段加h1。但实现的时候傻眼了,发现在真正前向传播的时候t是一个tensor,有0有1,没法儿进行判断。

灵机一动,又生一法:把这个模型变为三个模型,前半段是一个模型(r),后面的h0和h1分别为另两个模型。把数据集按t=0和1分开,分别训练两个模型:r+h0和r+h1。

但是后来搜如何进行模型串联,发现极为麻烦。

3.解决方案

后来在pytorch的官方社区中看到一个极为简单的方法:

(1) 按照一般的多输出模型进行实现,代码如下:

def forward(self, x):
        #三层的表示层
        x = F.elu(self.fcR1(x))
        x = F.elu(self.fcR2(x))
        x = F.elu(self.fcR3(x))
		#two-head,两个head分别进行输出
        y0 = F.elu(self.fcH01(x))
        y0 = F.elu(self.fcH02(y0))
        y0 = F.elu(self.fcH03(y0))
        y1 = F.elu(self.fcH11(x))
        y1 = F.elu(self.fcH12(y1))
        y1 = F.elu(self.fcH13(y1))
        return y0, y1

这样就相当实现了一个多输出模型,一个x同时输出y0和y1.

训练的时候分别训练,也即分别建loss,代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

先把x按t=0和t=1分为x0和x1,然后分别送入进行训练。这样就实现了一个two-head模型。

4.后记

我自以为多输出模型可以分为以下两类:

多个输出不同时获得,如本文情况。

多个输出同时获得。

多输出不同时获得的解决方法上文已说明。多输出同时获得则可以通过把y0和y1拼接起来一起输出来实现。

补充:PyTorch 多输入多输出模型构建

本篇教程基于 PyTorch 1.5版本

直接上代码!

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.distributed as dist
import torch.utils.data as data_utils
class Net(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden1 = nn.Linear(n_input, n_hidden)
        self.hidden2 = nn.Linear(n_hidden, n_hidden)
        self.predict1 = nn.Linear(n_hidden*2, n_output)
        self.predict2 = nn.Linear(n_hidden*2, n_output)
    def forward(self, input1, input2): # 多输入!!!
        out01 = self.hidden1(input1)
        out02 = torch.relu(out01)
        out03 = self.hidden2(out02)
        out04 = torch.sigmoid(out03)
        out11 = self.hidden1(input2)
        out12 = torch.relu(out11)
        out13 = self.hidden2(out12)
        out14 = torch.sigmoid(out13)
        out = torch.cat((out04, out14), dim=1) # 模型层拼合!!!当然你的模型中可能不需要~
 
        out1 = self.predict1(out)
        out2 = self.predict2(out)
        return out1, out2 # 多输出!!!
net = Net(1, 20, 1)
x1 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 请不要关心这里,随便弄一个数据,为了说明问题而已
y1 = x1.pow(3)+0.1*torch.randn(x1.size())
x2 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y2 = x2.pow(3)+0.1*torch.randn(x2.size())
x1, y1 = (Variable(x1), Variable(y1))
x2, y2 = (Variable(x2), Variable(y2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
for t in range(5000):
    prediction1, prediction2 = net(x1, x2)
    loss1 = loss_func(prediction1, y1)
    loss2 = loss_func(prediction2, y2)
    loss = loss1 + loss2 # 重点!
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 100 == 0:
       print('Loss1 = %.4f' % loss1.data,'Loss2 = %.4f' % loss2.data,)

至此搞定!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python批量导出导入MySQL用户的方法
Nov 15 Python
详解Python中映射类型的内建函数和工厂函数
Aug 19 Python
举例讲解Python中的list列表数据结构用法
Mar 12 Python
搭建Python的Django框架环境并建立和运行第一个App的教程
Jul 02 Python
Python图片裁剪实例代码(如头像裁剪)
Jun 21 Python
简单了解什么是神经网络
Dec 23 Python
pandas 数据索引与选取的实现方法
Jun 21 Python
Kmeans均值聚类算法原理以及Python如何实现
Sep 26 Python
Python字典取键、值对的方法步骤
Sep 30 Python
用python对oracle进行简单性能测试
Dec 05 Python
用Python 执行cmd命令
Dec 18 Python
Python创建简单的神经网络实例讲解
Jan 04 Python
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
Python趣味爬虫之用Python实现智慧校园一键评教
Pytorch 如何加速Dataloader提升数据读取速度
You might like
解决PHP mysql_query执行超时(Fatal error: Maximum execution time …)
2013/07/03 PHP
php计算数组不为空元素个数的方法
2014/01/27 PHP
ThinkPHP5.1的权限控制怎么写?分享一个AUTH权限控制
2021/03/09 PHP
用Javascript读取中文COOKIE的解决办法
2007/02/15 Javascript
Ext.MessageBox工具类简介
2009/12/10 Javascript
javascript 四则运算精度修正函数代码
2010/05/31 Javascript
JS 实现完美include载入实现代码
2010/08/05 Javascript
基于jQuery的仿flash的广告轮播代码
2010/11/04 Javascript
JS中把字符转成ASCII值的函数示例代码
2013/11/21 Javascript
关闭ie窗口清除Session的解决方法
2014/01/10 Javascript
JavaScript匿名函数与委托使用示例
2014/07/22 Javascript
jQuery实现仿Google首页拖动效果的方法
2015/05/04 Javascript
jquery实现Slide Out Navigation滑出式菜单效果代码
2015/09/07 Javascript
使用jQuery获取data-的自定义属性
2015/11/10 Javascript
详解js图片轮播效果实现原理
2015/12/17 Javascript
浅谈jquery之on()绑定事件和off()解除绑定事件
2016/10/26 Javascript
如何利用JQuery实现从底部回到顶部的功能
2016/12/27 Javascript
Javascript Function.prototype.bind详细分析
2016/12/29 Javascript
让你彻底掌握es6 Promise的八段代码
2017/07/26 Javascript
聊聊Vue.js的template编译的问题
2017/10/09 Javascript
解决angularjs service中依赖注入$scope报错的问题
2018/10/02 Javascript
微信小程序使用wx.request请求服务器json数据并渲染到页面操作示例
2019/03/30 Javascript
详解vuex数据传输的两种方式及this.$store undefined的解决办法
2019/08/26 Javascript
element-ui中按需引入的实现
2019/12/25 Javascript
python 字典(dict)按键和值排序
2016/06/28 Python
用python写个自动SSH登录远程服务器的小工具(实例)
2017/06/17 Python
Python基于pyCUDA实现GPU加速并行计算功能入门教程
2018/06/19 Python
Python中使用Counter进行字典创建以及key数量统计的方法
2018/07/06 Python
Python rstrip()方法实例详解
2018/11/11 Python
Python中logging.NullHandler 的使用教程
2018/11/29 Python
详解Python字符串切片
2019/05/20 Python
Python接口测试数据库封装实现原理
2020/05/09 Python
升职自荐信范文
2013/10/05 职场文书
会计出纳员的自我评价
2014/01/15 职场文书
安全事故隐患排查治理制度
2015/08/05 职场文书
Nginx反向代理至go-fastdfs案例讲解
2021/08/02 Servers