使用Pytorch实现two-head(多输出)模型的操作


Posted in Python onMay 28, 2021

如何使用Pytorch实现two-head(多输出)模型

1. two-head模型定义

先放一张我要实现的模型结构图:

使用Pytorch实现two-head(多输出)模型的操作

如上图,就是一个two-head模型,也是一个但输入多输出模型。该模型的特点是输入一个x和一个t,h0和h1中只有一个会输出,所以可能这不算是一个典型的多输出模型。

2.实现所遇到的困难 一开始的想法:

这不是很简单嘛,做一个判断不就完了,t=0时模型为前半段加h0,t=1时模型为前半段加h1。但实现的时候傻眼了,发现在真正前向传播的时候t是一个tensor,有0有1,没法儿进行判断。

灵机一动,又生一法:把这个模型变为三个模型,前半段是一个模型(r),后面的h0和h1分别为另两个模型。把数据集按t=0和1分开,分别训练两个模型:r+h0和r+h1。

但是后来搜如何进行模型串联,发现极为麻烦。

3.解决方案

后来在pytorch的官方社区中看到一个极为简单的方法:

(1) 按照一般的多输出模型进行实现,代码如下:

def forward(self, x):
        #三层的表示层
        x = F.elu(self.fcR1(x))
        x = F.elu(self.fcR2(x))
        x = F.elu(self.fcR3(x))
		#two-head,两个head分别进行输出
        y0 = F.elu(self.fcH01(x))
        y0 = F.elu(self.fcH02(y0))
        y0 = F.elu(self.fcH03(y0))
        y1 = F.elu(self.fcH11(x))
        y1 = F.elu(self.fcH12(y1))
        y1 = F.elu(self.fcH13(y1))
        return y0, y1

这样就相当实现了一个多输出模型,一个x同时输出y0和y1.

训练的时候分别训练,也即分别建loss,代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

先把x按t=0和t=1分为x0和x1,然后分别送入进行训练。这样就实现了一个two-head模型。

4.后记

我自以为多输出模型可以分为以下两类:

多个输出不同时获得,如本文情况。

多个输出同时获得。

多输出不同时获得的解决方法上文已说明。多输出同时获得则可以通过把y0和y1拼接起来一起输出来实现。

补充:PyTorch 多输入多输出模型构建

本篇教程基于 PyTorch 1.5版本

直接上代码!

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.distributed as dist
import torch.utils.data as data_utils
class Net(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden1 = nn.Linear(n_input, n_hidden)
        self.hidden2 = nn.Linear(n_hidden, n_hidden)
        self.predict1 = nn.Linear(n_hidden*2, n_output)
        self.predict2 = nn.Linear(n_hidden*2, n_output)
    def forward(self, input1, input2): # 多输入!!!
        out01 = self.hidden1(input1)
        out02 = torch.relu(out01)
        out03 = self.hidden2(out02)
        out04 = torch.sigmoid(out03)
        out11 = self.hidden1(input2)
        out12 = torch.relu(out11)
        out13 = self.hidden2(out12)
        out14 = torch.sigmoid(out13)
        out = torch.cat((out04, out14), dim=1) # 模型层拼合!!!当然你的模型中可能不需要~
 
        out1 = self.predict1(out)
        out2 = self.predict2(out)
        return out1, out2 # 多输出!!!
net = Net(1, 20, 1)
x1 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 请不要关心这里,随便弄一个数据,为了说明问题而已
y1 = x1.pow(3)+0.1*torch.randn(x1.size())
x2 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y2 = x2.pow(3)+0.1*torch.randn(x2.size())
x1, y1 = (Variable(x1), Variable(y1))
x2, y2 = (Variable(x2), Variable(y2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
for t in range(5000):
    prediction1, prediction2 = net(x1, x2)
    loss1 = loss_func(prediction1, y1)
    loss2 = loss_func(prediction2, y2)
    loss = loss1 + loss2 # 重点!
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 100 == 0:
       print('Loss1 = %.4f' % loss1.data,'Loss2 = %.4f' % loss2.data,)

至此搞定!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用poplib模块和smtplib模块收发电子邮件的教程
Jul 02 Python
django 开发忘记密码通过邮箱找回功能示例
Apr 17 Python
python bmp转换为jpg 并删除原图的方法
Oct 25 Python
python re正则匹配网页中图片url地址的方法
Dec 20 Python
Python实战购物车项目的实现参考
Feb 20 Python
Python使用微信itchat接口实现查看自己微信的信息功能详解
Aug 22 Python
Django继承自带user表并重写的例子
Nov 18 Python
opencv3/Python 稠密光流calcOpticalFlowFarneback详解
Dec 11 Python
python实现ssh及sftp功能(实例代码)
Mar 16 Python
python利用while求100内的整数和方式
Nov 07 Python
Python  lambda匿名函数和三元运算符
Apr 19 Python
Django数据库(SQlite)基本入门使用教程
Jul 07 Python
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
Python趣味爬虫之用Python实现智慧校园一键评教
Pytorch 如何加速Dataloader提升数据读取速度
You might like
PHP源代码数组统计count分析
2011/08/02 PHP
PHP 面向对象程序设计(oop)学习笔记 (五) - PHP 命名空间
2014/06/12 PHP
php+resumablejs实现的分块上传 断点续传功能示例
2017/04/18 PHP
PHP命令空间namespace及use的用法小结
2017/11/27 PHP
PHP判断函数是否被定义的方法
2019/06/21 PHP
extJs 下拉框联动实现代码
2010/04/09 Javascript
document.getElementById的简写方式(获取id对象的简略写法)
2010/09/10 Javascript
验证码按回车不变解决方法
2013/03/29 Javascript
JavaScript中Number.NEGATIVE_INFINITY值的使用详解
2015/06/05 Javascript
jQuery Validate初步体验(一)
2015/12/12 Javascript
jQuery点击其他地方时菜单消失的实现方法
2016/04/22 Javascript
jquery+ajax实现直接提交表单实例分析
2016/06/17 Javascript
微信小程序 封装http请求实例详解
2017/01/16 Javascript
老生常谈ES6中的类
2017/07/31 Javascript
基于React+Redux的SSR实现方法
2018/07/03 Javascript
nodejs 使用nodejs-websocket模块实现点对点实时通讯
2018/11/28 NodeJs
微信小程序实现元素渐入渐出动画效果封装方法
2019/05/18 Javascript
JAVA面试题 static关键字详解
2019/07/16 Javascript
jquery分页优化操作实例分析
2019/08/23 jQuery
使用p5.js临摹动态图片
2019/11/04 Javascript
vue中使用WX-JSSDK的两种方法(推荐)
2020/01/18 Javascript
Python中利用sorted()函数排序的简单教程
2015/04/27 Python
python如何将图片转换为字符图片
2020/08/19 Python
Python3实现的简单工资管理系统示例
2019/03/12 Python
Python调用shell命令常用方法(4种)
2020/05/11 Python
使用html5+css3来实现slider切换效果告别javascript+css
2013/01/08 HTML / CSS
如何在发生故障的节点上重新安装 SQL Server
2013/03/14 面试题
初三学生评语大全
2014/04/24 职场文书
2014工程部年度工作总结
2014/12/17 职场文书
圆明园观后感
2015/06/03 职场文书
尼克胡哲观后感
2015/06/08 职场文书
考研经验交流会策划书
2015/11/02 职场文书
幽默口才训练经典句子(48句)
2019/08/19 职场文书
深入理解CSS 中 transform matrix矩阵变换问题
2021/08/30 HTML / CSS
Python 视频画质增强
2022/04/28 Python
springboot为异步任务规划自定义线程池的实现
2022/06/14 Java/Android