使用Pytorch实现two-head(多输出)模型的操作


Posted in Python onMay 28, 2021

如何使用Pytorch实现two-head(多输出)模型

1. two-head模型定义

先放一张我要实现的模型结构图:

使用Pytorch实现two-head(多输出)模型的操作

如上图,就是一个two-head模型,也是一个但输入多输出模型。该模型的特点是输入一个x和一个t,h0和h1中只有一个会输出,所以可能这不算是一个典型的多输出模型。

2.实现所遇到的困难 一开始的想法:

这不是很简单嘛,做一个判断不就完了,t=0时模型为前半段加h0,t=1时模型为前半段加h1。但实现的时候傻眼了,发现在真正前向传播的时候t是一个tensor,有0有1,没法儿进行判断。

灵机一动,又生一法:把这个模型变为三个模型,前半段是一个模型(r),后面的h0和h1分别为另两个模型。把数据集按t=0和1分开,分别训练两个模型:r+h0和r+h1。

但是后来搜如何进行模型串联,发现极为麻烦。

3.解决方案

后来在pytorch的官方社区中看到一个极为简单的方法:

(1) 按照一般的多输出模型进行实现,代码如下:

def forward(self, x):
        #三层的表示层
        x = F.elu(self.fcR1(x))
        x = F.elu(self.fcR2(x))
        x = F.elu(self.fcR3(x))
		#two-head,两个head分别进行输出
        y0 = F.elu(self.fcH01(x))
        y0 = F.elu(self.fcH02(y0))
        y0 = F.elu(self.fcH03(y0))
        y1 = F.elu(self.fcH11(x))
        y1 = F.elu(self.fcH12(y1))
        y1 = F.elu(self.fcH13(y1))
        return y0, y1

这样就相当实现了一个多输出模型,一个x同时输出y0和y1.

训练的时候分别训练,也即分别建loss,代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

先把x按t=0和t=1分为x0和x1,然后分别送入进行训练。这样就实现了一个two-head模型。

4.后记

我自以为多输出模型可以分为以下两类:

多个输出不同时获得,如本文情况。

多个输出同时获得。

多输出不同时获得的解决方法上文已说明。多输出同时获得则可以通过把y0和y1拼接起来一起输出来实现。

补充:PyTorch 多输入多输出模型构建

本篇教程基于 PyTorch 1.5版本

直接上代码!

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.distributed as dist
import torch.utils.data as data_utils
class Net(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden1 = nn.Linear(n_input, n_hidden)
        self.hidden2 = nn.Linear(n_hidden, n_hidden)
        self.predict1 = nn.Linear(n_hidden*2, n_output)
        self.predict2 = nn.Linear(n_hidden*2, n_output)
    def forward(self, input1, input2): # 多输入!!!
        out01 = self.hidden1(input1)
        out02 = torch.relu(out01)
        out03 = self.hidden2(out02)
        out04 = torch.sigmoid(out03)
        out11 = self.hidden1(input2)
        out12 = torch.relu(out11)
        out13 = self.hidden2(out12)
        out14 = torch.sigmoid(out13)
        out = torch.cat((out04, out14), dim=1) # 模型层拼合!!!当然你的模型中可能不需要~
 
        out1 = self.predict1(out)
        out2 = self.predict2(out)
        return out1, out2 # 多输出!!!
net = Net(1, 20, 1)
x1 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 请不要关心这里,随便弄一个数据,为了说明问题而已
y1 = x1.pow(3)+0.1*torch.randn(x1.size())
x2 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y2 = x2.pow(3)+0.1*torch.randn(x2.size())
x1, y1 = (Variable(x1), Variable(y1))
x2, y2 = (Variable(x2), Variable(y2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
for t in range(5000):
    prediction1, prediction2 = net(x1, x2)
    loss1 = loss_func(prediction1, y1)
    loss2 = loss_func(prediction2, y2)
    loss = loss1 + loss2 # 重点!
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 100 == 0:
       print('Loss1 = %.4f' % loss1.data,'Loss2 = %.4f' % loss2.data,)

至此搞定!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中进程和线程的区别详解
Oct 29 Python
python如何将图片转换为字符图片
Aug 19 Python
python实现人民币大写转换
Jun 20 Python
python中ASCII码和字符的转换方法
Jul 09 Python
在Python中定义一个常量的方法
Nov 10 Python
Python判断一个文件夹内哪些文件是图片的实例
Dec 07 Python
python 二维矩阵转三维矩阵示例
Nov 30 Python
python 解决tqdm模块不能单行显示的问题
Feb 19 Python
Python 之 Json序列化嵌套类方式
Feb 27 Python
python入门学习关于for else的特殊特性讲解
Nov 20 Python
python 多态 协议 鸭子类型详解
Nov 27 Python
聊聊Python String型列表求最值的问题
Jan 18 Python
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
Python趣味爬虫之用Python实现智慧校园一键评教
Pytorch 如何加速Dataloader提升数据读取速度
You might like
提升PHP性能的21种方法介绍
2013/06/25 PHP
PHP中浮点数计算比较及取整不准确的解决方法
2015/01/09 PHP
ThinkPHP模板Volist标签嵌套循环输出多维数组的方法
2016/03/23 PHP
PHP简单预防sql注入的方法
2016/09/27 PHP
php array_map()函数实例用法
2021/03/03 PHP
JavaScript中valueOf函数与toString方法深入理解
2012/12/02 Javascript
js通过元素class名字获取元素集合的具体实现
2014/01/06 Javascript
node.js中的socket.io入门实例
2014/04/26 Javascript
js检测网络是否具体连接功能的代码
2014/05/23 Javascript
jQuery中filter()方法用法实例
2015/01/06 Javascript
jQuery鼠标经过方形图片切换成圆边效果代码分享
2015/08/20 Javascript
JQuery控制图片由中心点逐渐放大效果
2016/06/26 Javascript
AngularJS教程之简单应用程序示例
2016/08/16 Javascript
js图片延迟加载(Lazyload)三种实现方式
2017/03/01 Javascript
Bootstrap输入框组件简单实现代码
2017/03/06 Javascript
js实现文件上传功能 后台使用MultipartFile
2018/09/08 Javascript
JS与SQL方式随机生成高强度密码示例
2018/12/29 Javascript
基于element-ui封装可搜索的懒加载tree组件的实现
2020/05/22 Javascript
vue实现的多页面项目如何优化打包的步骤详解
2020/07/19 Javascript
python编程开发之textwrap文本样式处理技巧
2015/11/13 Python
Python爬虫包 BeautifulSoup  递归抓取实例详解
2017/01/28 Python
利用python计算windows全盘文件md5值的脚本
2019/07/27 Python
python点击鼠标获取坐标(Graphics)
2019/08/10 Python
Python Django实现layui风格+django分页功能的例子
2019/08/29 Python
关于Pytorch的MLP模块实现方式
2020/01/07 Python
Python3.7实现验证码登录方式代码实例
2020/02/14 Python
python virtualenv虚拟环境配置与使用教程详解
2020/07/13 Python
Python 字典一个键对应多个值的方法
2020/09/29 Python
联想德国官网:Lenovo Germany
2018/07/04 全球购物
澳大利亚墨尔本的在线时装店:LORETA
2018/09/14 全球购物
正宗的日本零食和糖果订阅盒:Bokksu
2019/11/21 全球购物
PHP面试题及答案二
2015/05/23 面试题
诺思信科技(南京)有限公司.NET笔试题答案
2013/07/06 面试题
入党心得体会
2019/06/20 职场文书
创业计划之特色精品店
2019/08/12 职场文书
python脚本框架webpy的url映射详解
2021/11/20 Python