使用Pytorch实现two-head(多输出)模型的操作


Posted in Python onMay 28, 2021

如何使用Pytorch实现two-head(多输出)模型

1. two-head模型定义

先放一张我要实现的模型结构图:

使用Pytorch实现two-head(多输出)模型的操作

如上图,就是一个two-head模型,也是一个但输入多输出模型。该模型的特点是输入一个x和一个t,h0和h1中只有一个会输出,所以可能这不算是一个典型的多输出模型。

2.实现所遇到的困难 一开始的想法:

这不是很简单嘛,做一个判断不就完了,t=0时模型为前半段加h0,t=1时模型为前半段加h1。但实现的时候傻眼了,发现在真正前向传播的时候t是一个tensor,有0有1,没法儿进行判断。

灵机一动,又生一法:把这个模型变为三个模型,前半段是一个模型(r),后面的h0和h1分别为另两个模型。把数据集按t=0和1分开,分别训练两个模型:r+h0和r+h1。

但是后来搜如何进行模型串联,发现极为麻烦。

3.解决方案

后来在pytorch的官方社区中看到一个极为简单的方法:

(1) 按照一般的多输出模型进行实现,代码如下:

def forward(self, x):
        #三层的表示层
        x = F.elu(self.fcR1(x))
        x = F.elu(self.fcR2(x))
        x = F.elu(self.fcR3(x))
		#two-head,两个head分别进行输出
        y0 = F.elu(self.fcH01(x))
        y0 = F.elu(self.fcH02(y0))
        y0 = F.elu(self.fcH03(y0))
        y1 = F.elu(self.fcH11(x))
        y1 = F.elu(self.fcH12(y1))
        y1 = F.elu(self.fcH13(y1))
        return y0, y1

这样就相当实现了一个多输出模型,一个x同时输出y0和y1.

训练的时候分别训练,也即分别建loss,代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

先把x按t=0和t=1分为x0和x1,然后分别送入进行训练。这样就实现了一个two-head模型。

4.后记

我自以为多输出模型可以分为以下两类:

多个输出不同时获得,如本文情况。

多个输出同时获得。

多输出不同时获得的解决方法上文已说明。多输出同时获得则可以通过把y0和y1拼接起来一起输出来实现。

补充:PyTorch 多输入多输出模型构建

本篇教程基于 PyTorch 1.5版本

直接上代码!

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.distributed as dist
import torch.utils.data as data_utils
class Net(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden1 = nn.Linear(n_input, n_hidden)
        self.hidden2 = nn.Linear(n_hidden, n_hidden)
        self.predict1 = nn.Linear(n_hidden*2, n_output)
        self.predict2 = nn.Linear(n_hidden*2, n_output)
    def forward(self, input1, input2): # 多输入!!!
        out01 = self.hidden1(input1)
        out02 = torch.relu(out01)
        out03 = self.hidden2(out02)
        out04 = torch.sigmoid(out03)
        out11 = self.hidden1(input2)
        out12 = torch.relu(out11)
        out13 = self.hidden2(out12)
        out14 = torch.sigmoid(out13)
        out = torch.cat((out04, out14), dim=1) # 模型层拼合!!!当然你的模型中可能不需要~
 
        out1 = self.predict1(out)
        out2 = self.predict2(out)
        return out1, out2 # 多输出!!!
net = Net(1, 20, 1)
x1 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 请不要关心这里,随便弄一个数据,为了说明问题而已
y1 = x1.pow(3)+0.1*torch.randn(x1.size())
x2 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y2 = x2.pow(3)+0.1*torch.randn(x2.size())
x1, y1 = (Variable(x1), Variable(y1))
x2, y2 = (Variable(x2), Variable(y2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
for t in range(5000):
    prediction1, prediction2 = net(x1, x2)
    loss1 = loss_func(prediction1, y1)
    loss2 = loss_func(prediction2, y2)
    loss = loss1 + loss2 # 重点!
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 100 == 0:
       print('Loss1 = %.4f' % loss1.data,'Loss2 = %.4f' % loss2.data,)

至此搞定!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中itertools模块用法详解
Sep 25 Python
Linux下使用python调用top命令获得CPU利用率
Mar 10 Python
Python实现批量读取word中表格信息的方法
Jul 30 Python
ubuntu系统下 python链接mysql数据库的方法
Jan 09 Python
通过Python实现自动填写调查问卷
Sep 06 Python
Python3使用正则表达式爬取内涵段子示例
Apr 22 Python
Python使用re模块实现信息筛选的方法
Apr 29 Python
一百多行python代码实现抢票助手
Sep 25 Python
Python爬取智联招聘数据分析师岗位相关信息的方法
Aug 13 Python
pycharm内无法import已安装的模块问题解决
Feb 12 Python
Python网页解析器使用实例详解
May 30 Python
python元组打包和解包过程详解
Aug 02 Python
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
Python趣味爬虫之用Python实现智慧校园一键评教
Pytorch 如何加速Dataloader提升数据读取速度
You might like
我的论坛源代码(二)
2006/10/09 PHP
PHP session会话的安全性分析
2011/09/08 PHP
DOM 脚本编程中的兄弟节点
2009/10/31 Javascript
extjs 初始化checkboxgroup值的代码
2011/09/21 Javascript
Javascript浅谈之引用类型
2013/12/18 Javascript
JavaScript操作DOM元素的childNodes和children区别
2015/04/01 Javascript
jQuery实现鼠标经过图片变亮其他变暗效果
2015/05/08 Javascript
z-blog SyntaxHighlighter 长代码无法换行解决办法(基于jquery)
2015/11/18 Javascript
Angular2内置指令NgFor和NgIf详解
2016/08/03 Javascript
基于vue.js实现侧边菜单栏
2017/03/20 Javascript
基于vue+ bootstrap实现图片上传图片展示功能
2017/05/17 Javascript
chorme 浏览器记住密码后input黄色背景处理方法(两种)
2017/11/22 Javascript
vue树形结构获取键值的方法示例
2018/06/21 Javascript
使用jQuery给Table动态增加行、清空table的方法
2018/09/05 jQuery
Vue中点击active并第一个默认选中功能的实现
2020/02/24 Javascript
Postman内建变量常用方法实例解析
2020/07/28 Javascript
[34:41]夜魇凡尔赛茶话会 第二期02:你画我猜
2021/03/11 DOTA
phpsir 开发 一个检测百度关键字网站排名的python 程序
2009/09/17 Python
使用python实现rsa算法代码
2016/02/17 Python
简单谈谈python的反射机制
2016/06/28 Python
Python基础教程之tcp socket编程详解及简单实例
2017/02/23 Python
Jupyter安装nbextensions,启动提示没有nbextensions库
2020/04/23 Python
Python判断telnet通不通的实例
2019/01/26 Python
Python实现的列表排序、反转操作示例
2019/03/13 Python
python多线程同步实例教程
2019/08/11 Python
在python中计算ssim的方法(与Matlab结果一致)
2019/12/19 Python
Python参数传递机制传值和传引用原理详解
2020/05/22 Python
python3中数组逆序输出方法
2020/12/01 Python
完美解决Pycharm中matplotlib画图中文乱码问题
2021/01/11 Python
Abe’s of Maine:自1979以来销售相机和电子产品
2016/11/21 全球购物
物流创业计划书
2014/02/01 职场文书
2015年求职自荐信范文
2015/03/04 职场文书
学校财务管理制度
2015/08/04 职场文书
学习焦裕禄先进事迹心得体会
2016/01/23 职场文书
七年级话题作文之执着
2019/11/19 职场文书
微信小程序实现轮播图指示器
2022/06/25 Javascript