使用Pytorch实现two-head(多输出)模型的操作


Posted in Python onMay 28, 2021

如何使用Pytorch实现two-head(多输出)模型

1. two-head模型定义

先放一张我要实现的模型结构图:

使用Pytorch实现two-head(多输出)模型的操作

如上图,就是一个two-head模型,也是一个但输入多输出模型。该模型的特点是输入一个x和一个t,h0和h1中只有一个会输出,所以可能这不算是一个典型的多输出模型。

2.实现所遇到的困难 一开始的想法:

这不是很简单嘛,做一个判断不就完了,t=0时模型为前半段加h0,t=1时模型为前半段加h1。但实现的时候傻眼了,发现在真正前向传播的时候t是一个tensor,有0有1,没法儿进行判断。

灵机一动,又生一法:把这个模型变为三个模型,前半段是一个模型(r),后面的h0和h1分别为另两个模型。把数据集按t=0和1分开,分别训练两个模型:r+h0和r+h1。

但是后来搜如何进行模型串联,发现极为麻烦。

3.解决方案

后来在pytorch的官方社区中看到一个极为简单的方法:

(1) 按照一般的多输出模型进行实现,代码如下:

def forward(self, x):
        #三层的表示层
        x = F.elu(self.fcR1(x))
        x = F.elu(self.fcR2(x))
        x = F.elu(self.fcR3(x))
		#two-head,两个head分别进行输出
        y0 = F.elu(self.fcH01(x))
        y0 = F.elu(self.fcH02(y0))
        y0 = F.elu(self.fcH03(y0))
        y1 = F.elu(self.fcH11(x))
        y1 = F.elu(self.fcH12(y1))
        y1 = F.elu(self.fcH13(y1))
        return y0, y1

这样就相当实现了一个多输出模型,一个x同时输出y0和y1.

训练的时候分别训练,也即分别建loss,代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

先把x按t=0和t=1分为x0和x1,然后分别送入进行训练。这样就实现了一个two-head模型。

4.后记

我自以为多输出模型可以分为以下两类:

多个输出不同时获得,如本文情况。

多个输出同时获得。

多输出不同时获得的解决方法上文已说明。多输出同时获得则可以通过把y0和y1拼接起来一起输出来实现。

补充:PyTorch 多输入多输出模型构建

本篇教程基于 PyTorch 1.5版本

直接上代码!

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.distributed as dist
import torch.utils.data as data_utils
class Net(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden1 = nn.Linear(n_input, n_hidden)
        self.hidden2 = nn.Linear(n_hidden, n_hidden)
        self.predict1 = nn.Linear(n_hidden*2, n_output)
        self.predict2 = nn.Linear(n_hidden*2, n_output)
    def forward(self, input1, input2): # 多输入!!!
        out01 = self.hidden1(input1)
        out02 = torch.relu(out01)
        out03 = self.hidden2(out02)
        out04 = torch.sigmoid(out03)
        out11 = self.hidden1(input2)
        out12 = torch.relu(out11)
        out13 = self.hidden2(out12)
        out14 = torch.sigmoid(out13)
        out = torch.cat((out04, out14), dim=1) # 模型层拼合!!!当然你的模型中可能不需要~
 
        out1 = self.predict1(out)
        out2 = self.predict2(out)
        return out1, out2 # 多输出!!!
net = Net(1, 20, 1)
x1 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 请不要关心这里,随便弄一个数据,为了说明问题而已
y1 = x1.pow(3)+0.1*torch.randn(x1.size())
x2 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y2 = x2.pow(3)+0.1*torch.randn(x2.size())
x1, y1 = (Variable(x1), Variable(y1))
x2, y2 = (Variable(x2), Variable(y2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
for t in range(5000):
    prediction1, prediction2 = net(x1, x2)
    loss1 = loss_func(prediction1, y1)
    loss2 = loss_func(prediction2, y2)
    loss = loss1 + loss2 # 重点!
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 100 == 0:
       print('Loss1 = %.4f' % loss1.data,'Loss2 = %.4f' % loss2.data,)

至此搞定!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中readline判断文件读取结束的方法
Nov 08 Python
python实现中文输出的两种方法
May 09 Python
Python制作豆瓣图片的爬虫
Dec 28 Python
django admin添加数据自动记录user到表中的实现方法
Jan 05 Python
详解如何在python中读写和存储matlab的数据文件(*.mat)
Feb 24 Python
pandas对指定列进行填充的方法
Apr 11 Python
对python中raw_input()和input()的用法详解
Apr 22 Python
解决Tensorflow使用pip安装后没有model目录的问题
Jun 13 Python
python 分离文件名和路径以及分离文件名和后缀的方法
Oct 21 Python
win10环境下配置vscode python开发环境的教程详解
Oct 16 Python
运行Python编写的程序方法实例
Oct 21 Python
matplotlib交互式数据光标mpldatacursor的实现
Feb 03 Python
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
Python趣味爬虫之用Python实现智慧校园一键评教
Pytorch 如何加速Dataloader提升数据读取速度
You might like
PHP 输出简单动态WAP页面
2009/06/09 PHP
PHP面向对象的进阶学习(抽像类、接口、final、类常量)
2012/05/07 PHP
PHP异常Parse error: syntax error, unexpected T_VAR错误解决方法
2014/05/06 PHP
php文件夹的创建与删除方法
2015/01/24 PHP
PHP实现删除字符串中任何字符的函数
2015/08/11 PHP
Zend Framework入门教程之Zend_Mail用法示例
2016/12/08 PHP
php mysql实现mysql_select_db选择数据库
2016/12/30 PHP
php从数据库中读取特定的行(实例)
2017/06/02 PHP
基于Laravel实现的用户动态模块开发
2017/09/21 PHP
PHP设计模式之PHP迭代器模式讲解
2019/03/22 PHP
javascript hashtable实现代码
2009/10/13 Javascript
JavaScript Eval 函数使用
2010/03/23 Javascript
在IE 浏览器中使用 jquery的fadeIn() 效果 英文字符字体加粗
2011/06/02 Javascript
通过location.replace禁止浏览器后退防止重复提交
2014/09/04 Javascript
Javascript基础教程之数据类型 (字符串 String)
2015/01/18 Javascript
基于JS2Image实现圣诞树代码
2015/12/24 Javascript
JavaScript第一篇之实现按钮全选、功能
2016/08/21 Javascript
flag和jq on 的绑定多个对象和方法(必看)
2017/02/27 Javascript
js实现简单的二级联动效果
2017/03/09 Javascript
Vue非父子组件通信详解
2017/06/12 Javascript
荐书|您有一份JavaScript书单待签收
2017/07/21 Javascript
深入浅析ES6 Class 中的 super 关键字
2017/10/20 Javascript
微信小程序如何播放腾讯视频的实现
2019/09/20 Javascript
Openlayers实现地图全屏显示
2020/09/28 Javascript
[02:09]2018DOTA2亚洲邀请赛TNC赛前采访
2018/04/04 DOTA
python在控制台输出进度条的方法
2015/06/20 Python
Python3操作SQL Server数据库(实例讲解)
2017/10/21 Python
python图形工具turtle绘制国际象棋棋盘
2019/05/23 Python
支持IE8的纯css3开发的响应式设计动画菜单教程
2014/11/05 HTML / CSS
CSS3实现瀑布流布局与无限加载图片相册的实例代码
2016/12/22 HTML / CSS
澳大利亚领先的运动鞋商店:Hype DC
2018/03/31 全球购物
澳大利亚家用电器在线商店:Billy Guyatts
2020/05/05 全球购物
法律专业求职信
2014/05/24 职场文书
党课心得体会范文
2014/09/09 职场文书
公司离职证明标准样本
2014/10/05 职场文书
MySql新手入门的基本操作汇总
2021/05/13 MySQL