使用Pytorch训练two-head网络的操作


Posted in Python onMay 28, 2021

以前的训练方法:

之前是把两个head分开进行训练的,因此每一轮训练先要对一个batch的数据进行划分,然后再分别训练两个头。代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

但是在实际操作中想到那这样的话岂不是每次都先使用t=0的数据训练公共的表示层,再使用t=1的数据去训练。这样会不会使表示层产生bias呢?且这样两步训练也很麻烦。

修改后的方法

使用之前训练方法其实还是对神经网络的训练的机理不清楚。事实上,在计算loss的时候每个数据点的梯度都是单独计算的。

因此完全可以把网络前向传播得到结果按之前的顺序拼接起来后再进行梯度的反向传播,这样就可以只进行一步训练,且不会出现训练先后的偏差。

代码如下:

f_out_y0, cf_out_y0 = net(x0)
            cf_out_y1, f_out_y1 = net(x1)
            #按照t=0和t=1的索引拼接向量
            y_pred = torch.zeros([len(x), 1])
            y_pred[index0] = f_out_y0
            y_pred[index1] = f_out_y1
            
   criterion = Loss()
            loss = criterion(f_y, y_pred, w) + 0.01 * (l2_regularization0 + l2_regularization1)
            #print(loss.item())
            viz.line([float(loss)], [epoch], win='train_loss', update='append')
            optimizer.zero_grad()
            loss.backward()
            #对网络的参数进行更新
            optimizer.step()

总结

two-head网络前向传播得到结果的时候是分开得到的,训练的时候通过拼接预测结果可以实现一次训练。

补充:Pytorch训练网络的一般步骤

如下所示:

import torch 
print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数

1、构建模型

2、定义一个损失函数

3、定义一个优化器

4、将训练数据带入模型得到预测值

5、将梯度清零

6、获得损失

7、进行优化

import torch
from torch.autograd import Variable
 
#初步认识构建Tensor数据
def one():
    print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
    print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数
    print(torch.zeros((2,3)))#生成一个2X3的全零矩阵
    print(torch.ones((2,3)))#生成一个2X3的全一矩阵
    a = torch.randn((2,3))
    b = a.numpy()#将一个torch.Tensor转换为numpy
    c = torch.from_numpy(b)#将numpy转换为Tensor
    print(a)
    print(b)
    print(c)
 
#使用Variable自动求导
def two():
    # 构建Variable
    x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
    w = Variable(torch.Tensor([4, 5, 6]), requires_grad=True)
    b = Variable(torch.Tensor([7, 8, 9]), requires_grad=True)
    # 函数等式
    y = w * x ** 2 + b
    # 使用梯度下降计算各变量的偏导数
    y.backward(torch.Tensor([1, 1, 1]))
    print(x.grad)
    print(w.grad)
    print(b.grad)

线性回归例子:

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
 
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = 3*x+10+torch.rand(x.size())
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.Linear = nn.Linear(1,1)
    def forward(self,x):
        return self.Linear(x)
model = LinearRegression()
Loss = nn.MSELoss()
Opt = torch.optim.SGD(model.parameters(),lr=0.01)
for i in range(1000):
    inputs = Variable(x)
    targets = Variable(y)
    outputs = model(inputs)
    loss = Loss(outputs,targets)
    Opt.zero_grad()
    loss.backward()
    Opt.step()
model.eval()
predict = model(Variable(x))
plt.plot(x.numpy(),y.numpy(),'ro')
plt.plot(x.numpy(),predict.data.numpy())
plt.show()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python多线程抓取天涯帖子内容示例
Apr 03 Python
一个超级简单的python web程序
Sep 11 Python
Python中asyncore异步模块的用法及实现httpclient的实例
Jun 28 Python
Python创建对称矩阵的方法示例【基于numpy模块】
Oct 12 Python
Python实现PS图像调整黑白效果示例
Jan 25 Python
python指定写入文件时的编码格式方法
Jun 07 Python
python实现反转部分单向链表
Sep 27 Python
python实现连连看辅助之图像识别延伸
Jul 17 Python
PYTHON发送邮件YAGMAIL的简单实现解析
Oct 28 Python
python 浅谈serial与stm32通信的编码问题
Dec 18 Python
python GUI库图形界面开发之PyQt5线程类QThread详细使用方法
Feb 26 Python
Python多线程多进程实例对比解析
Mar 12 Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
You might like
用PHP连接MySQL代码的参数说明
2008/06/07 PHP
php存储过程调用实例代码
2013/02/03 PHP
深入理解ob_flush和flush的区别(ob_flush()与flush()使用方法)
2013/02/06 PHP
php实现通过ftp上传文件
2015/06/19 PHP
PHP输出XML格式数据的方法总结
2017/02/08 PHP
PHP基于正则批量替换Img中src内容实现获取缩略图的功能示例
2017/06/07 PHP
thinkphp5框架扩展redis类方法示例
2019/05/06 PHP
javascript实现TreeView 无刷新展开的实例代码
2013/07/13 Javascript
JavaScript移除数组元素减少长度的方法
2013/09/05 Javascript
JQuery表单验证插件EasyValidator用法分析
2014/11/15 Javascript
jQuery实现简洁的轮播图效果实例
2016/09/07 Javascript
AngularJS创建自定义指令的方法详解
2016/11/03 Javascript
原生JS实现跑马灯效果
2017/02/20 Javascript
protractor的安装与基本使用教程
2017/07/07 Javascript
layui使用form表单实现post请求页面跳转的方法
2019/09/14 Javascript
JavaScript This指向问题详解
2019/11/25 Javascript
JS面向对象之单选框实现
2020/01/17 Javascript
vue动态渲染svg、添加点击事件的实现
2020/03/13 Javascript
jQuery实现简单飞机大战
2020/07/05 jQuery
Python正则表达式教程之一:基础篇
2017/03/02 Python
对python多线程中Lock()与RLock()锁详解
2019/01/11 Python
python使用KNN算法识别手写数字
2019/04/25 Python
使用python实现男神女神颜值打分系统(推荐)
2019/10/31 Python
几个解决兼容IE6\7\8不支持html5标签的几个方法
2013/01/07 HTML / CSS
模范家庭事迹材料
2014/02/10 职场文书
春节超市活动方案
2014/08/14 职场文书
股东授权委托书范本
2014/09/13 职场文书
县政府领导班子四风问题对照检查材料思想汇报
2014/09/26 职场文书
上课迟到检讨书300字
2014/10/15 职场文书
个人先进事迹总结
2015/02/26 职场文书
2015年学校安全工作总结
2015/04/22 职场文书
评奖评优个人先进事迹材料
2015/11/04 职场文书
多属性、多分类MySQL模式设计
2021/04/05 MySQL
90行Python代码开发个人云盘应用
2021/04/20 Python
python脚本框架webpy模板控制结构
2021/11/20 Python
Java字符缓冲流BufferedWriter
2022/04/09 Java/Android