使用Pytorch实现two-head(多输出)模型的操作


Posted in Python onMay 28, 2021

如何使用Pytorch实现two-head(多输出)模型

1. two-head模型定义

先放一张我要实现的模型结构图:

使用Pytorch实现two-head(多输出)模型的操作

如上图,就是一个two-head模型,也是一个但输入多输出模型。该模型的特点是输入一个x和一个t,h0和h1中只有一个会输出,所以可能这不算是一个典型的多输出模型。

2.实现所遇到的困难 一开始的想法:

这不是很简单嘛,做一个判断不就完了,t=0时模型为前半段加h0,t=1时模型为前半段加h1。但实现的时候傻眼了,发现在真正前向传播的时候t是一个tensor,有0有1,没法儿进行判断。

灵机一动,又生一法:把这个模型变为三个模型,前半段是一个模型(r),后面的h0和h1分别为另两个模型。把数据集按t=0和1分开,分别训练两个模型:r+h0和r+h1。

但是后来搜如何进行模型串联,发现极为麻烦。

3.解决方案

后来在pytorch的官方社区中看到一个极为简单的方法:

(1) 按照一般的多输出模型进行实现,代码如下:

def forward(self, x):
        #三层的表示层
        x = F.elu(self.fcR1(x))
        x = F.elu(self.fcR2(x))
        x = F.elu(self.fcR3(x))
		#two-head,两个head分别进行输出
        y0 = F.elu(self.fcH01(x))
        y0 = F.elu(self.fcH02(y0))
        y0 = F.elu(self.fcH03(y0))
        y1 = F.elu(self.fcH11(x))
        y1 = F.elu(self.fcH12(y1))
        y1 = F.elu(self.fcH13(y1))
        return y0, y1

这样就相当实现了一个多输出模型,一个x同时输出y0和y1.

训练的时候分别训练,也即分别建loss,代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

先把x按t=0和t=1分为x0和x1,然后分别送入进行训练。这样就实现了一个two-head模型。

4.后记

我自以为多输出模型可以分为以下两类:

多个输出不同时获得,如本文情况。

多个输出同时获得。

多输出不同时获得的解决方法上文已说明。多输出同时获得则可以通过把y0和y1拼接起来一起输出来实现。

补充:PyTorch 多输入多输出模型构建

本篇教程基于 PyTorch 1.5版本

直接上代码!

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.distributed as dist
import torch.utils.data as data_utils
class Net(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden1 = nn.Linear(n_input, n_hidden)
        self.hidden2 = nn.Linear(n_hidden, n_hidden)
        self.predict1 = nn.Linear(n_hidden*2, n_output)
        self.predict2 = nn.Linear(n_hidden*2, n_output)
    def forward(self, input1, input2): # 多输入!!!
        out01 = self.hidden1(input1)
        out02 = torch.relu(out01)
        out03 = self.hidden2(out02)
        out04 = torch.sigmoid(out03)
        out11 = self.hidden1(input2)
        out12 = torch.relu(out11)
        out13 = self.hidden2(out12)
        out14 = torch.sigmoid(out13)
        out = torch.cat((out04, out14), dim=1) # 模型层拼合!!!当然你的模型中可能不需要~
 
        out1 = self.predict1(out)
        out2 = self.predict2(out)
        return out1, out2 # 多输出!!!
net = Net(1, 20, 1)
x1 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 请不要关心这里,随便弄一个数据,为了说明问题而已
y1 = x1.pow(3)+0.1*torch.randn(x1.size())
x2 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y2 = x2.pow(3)+0.1*torch.randn(x2.size())
x1, y1 = (Variable(x1), Variable(y1))
x2, y2 = (Variable(x2), Variable(y2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
for t in range(5000):
    prediction1, prediction2 = net(x1, x2)
    loss1 = loss_func(prediction1, y1)
    loss2 = loss_func(prediction2, y2)
    loss = loss1 + loss2 # 重点!
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 100 == 0:
       print('Loss1 = %.4f' % loss1.data,'Loss2 = %.4f' % loss2.data,)

至此搞定!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
c++生成dll使用python调用dll的方法
Jan 20 Python
Python 中迭代器与生成器实例详解
Mar 29 Python
K-近邻算法的python实现代码分享
Dec 09 Python
Python实现感知器模型、两层神经网络
Dec 19 Python
python实现对指定字符串补足固定长度倍数截断输出的方法
Nov 15 Python
python 读取鼠标点击坐标的实例
Dec 29 Python
Python自动化之数据驱动让你的脚本简洁10倍【推荐】
Jun 04 Python
在Django下测试与调试REST API的方法详解
Aug 29 Python
python 实现多线程下载m3u8格式视频并使用fmmpeg合并
Nov 15 Python
PyQt5-QDateEdit的简单使用操作
Jul 12 Python
Django模型验证器介绍与源码分析
Sep 08 Python
Pandas直接读取sql脚本的方法
Jan 21 Python
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
Python趣味爬虫之用Python实现智慧校园一键评教
Pytorch 如何加速Dataloader提升数据读取速度
You might like
PHP中图片等比缩放的实例
2013/03/24 PHP
PHP网页游戏学习之Xnova(ogame)源码解读(一)
2014/06/23 PHP
跟我学Laravel之视图 & Response
2014/10/15 PHP
php实现微信公众平台账号自定义菜单类
2014/12/02 PHP
Apache启动报错No space left on device: AH00023该怎么解决
2015/10/16 PHP
PHP实现简单ajax Loading加载功能示例
2016/12/28 PHP
PHP自定义错误处理的方法分析
2018/12/19 PHP
firefox下对ajax的onreadystatechange的支持情况分析
2009/12/14 Javascript
jquery插件tytabs.jquery.min.js实现渐变TAB选项卡效果
2015/08/25 Javascript
详解JavaScript基本类型和引用类型
2015/12/09 Javascript
Jquery+ajax+JAVA(servlet)实现下拉菜单异步取值
2016/03/23 Javascript
AngularJS 支付倒计时功能实现思路
2017/06/05 Javascript
Vue实现美团app的影院推荐选座功能【推荐】
2018/08/29 Javascript
总结4个方面优化Vue项目
2019/02/11 Javascript
详解JavaScript函数callee、call、apply的区别
2019/03/08 Javascript
Python内置函数—vars的具体使用方法
2017/12/04 Python
python3爬取各类天气信息
2018/02/24 Python
异步任务队列Celery在Django中的使用方法
2018/06/07 Python
Python TestCase中的断言方法介绍
2019/05/02 Python
python matplotlib imshow热图坐标替换/映射实例
2020/03/14 Python
Python greenlet和gevent使用代码示例解析
2020/04/01 Python
Python读取配置文件(config.ini)以及写入配置文件
2020/04/08 Python
HTML5进度条特效
2014/12/18 HTML / CSS
欧洲最大的美妆零售网站:Feelunique
2017/01/14 全球购物
新西兰杂志订阅:isubscribe
2019/08/26 全球购物
座谈会主持词
2014/03/20 职场文书
绿色环保标语
2014/06/12 职场文书
生产助理岗位职责
2014/06/18 职场文书
纪检干部先进事迹材料
2014/08/23 职场文书
2014年财务经理工作总结
2014/12/08 职场文书
罚款通知怎么写
2015/04/22 职场文书
2015年工程师工作总结
2015/04/30 职场文书
钢铁是怎样炼成的读书笔记
2015/06/29 职场文书
圣诞晚会主持词
2015/07/01 职场文书
“爱眼护眼,提前预防近视”倡议书3篇
2019/10/30 职场文书
Javascript中Microtask和Macrotask鲜为人知的知识点
2022/04/02 Javascript