使用Pytorch训练two-head网络的操作


Posted in Python onMay 28, 2021

以前的训练方法:

之前是把两个head分开进行训练的,因此每一轮训练先要对一个batch的数据进行划分,然后再分别训练两个头。代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

但是在实际操作中想到那这样的话岂不是每次都先使用t=0的数据训练公共的表示层,再使用t=1的数据去训练。这样会不会使表示层产生bias呢?且这样两步训练也很麻烦。

修改后的方法

使用之前训练方法其实还是对神经网络的训练的机理不清楚。事实上,在计算loss的时候每个数据点的梯度都是单独计算的。

因此完全可以把网络前向传播得到结果按之前的顺序拼接起来后再进行梯度的反向传播,这样就可以只进行一步训练,且不会出现训练先后的偏差。

代码如下:

f_out_y0, cf_out_y0 = net(x0)
            cf_out_y1, f_out_y1 = net(x1)
            #按照t=0和t=1的索引拼接向量
            y_pred = torch.zeros([len(x), 1])
            y_pred[index0] = f_out_y0
            y_pred[index1] = f_out_y1
            
   criterion = Loss()
            loss = criterion(f_y, y_pred, w) + 0.01 * (l2_regularization0 + l2_regularization1)
            #print(loss.item())
            viz.line([float(loss)], [epoch], win='train_loss', update='append')
            optimizer.zero_grad()
            loss.backward()
            #对网络的参数进行更新
            optimizer.step()

总结

two-head网络前向传播得到结果的时候是分开得到的,训练的时候通过拼接预测结果可以实现一次训练。

补充:Pytorch训练网络的一般步骤

如下所示:

import torch 
print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数

1、构建模型

2、定义一个损失函数

3、定义一个优化器

4、将训练数据带入模型得到预测值

5、将梯度清零

6、获得损失

7、进行优化

import torch
from torch.autograd import Variable
 
#初步认识构建Tensor数据
def one():
    print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
    print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数
    print(torch.zeros((2,3)))#生成一个2X3的全零矩阵
    print(torch.ones((2,3)))#生成一个2X3的全一矩阵
    a = torch.randn((2,3))
    b = a.numpy()#将一个torch.Tensor转换为numpy
    c = torch.from_numpy(b)#将numpy转换为Tensor
    print(a)
    print(b)
    print(c)
 
#使用Variable自动求导
def two():
    # 构建Variable
    x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
    w = Variable(torch.Tensor([4, 5, 6]), requires_grad=True)
    b = Variable(torch.Tensor([7, 8, 9]), requires_grad=True)
    # 函数等式
    y = w * x ** 2 + b
    # 使用梯度下降计算各变量的偏导数
    y.backward(torch.Tensor([1, 1, 1]))
    print(x.grad)
    print(w.grad)
    print(b.grad)

线性回归例子:

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
 
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = 3*x+10+torch.rand(x.size())
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.Linear = nn.Linear(1,1)
    def forward(self,x):
        return self.Linear(x)
model = LinearRegression()
Loss = nn.MSELoss()
Opt = torch.optim.SGD(model.parameters(),lr=0.01)
for i in range(1000):
    inputs = Variable(x)
    targets = Variable(y)
    outputs = model(inputs)
    loss = Loss(outputs,targets)
    Opt.zero_grad()
    loss.backward()
    Opt.step()
model.eval()
predict = model(Variable(x))
plt.plot(x.numpy(),y.numpy(),'ro')
plt.plot(x.numpy(),predict.data.numpy())
plt.show()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python读取图片EXIF信息类库介绍和使用实例
Jul 10 Python
Python遍历指定文件及文件夹的方法
May 09 Python
python类的方法属性与方法属性的动态绑定代码详解
Dec 27 Python
Python实现的视频播放器功能完整示例
Feb 01 Python
Python 绘图库 Matplotlib 入门教程
Apr 19 Python
详谈Pandas中iloc和loc以及ix的区别
Jun 08 Python
python实现黑客字幕雨效果
Jun 21 Python
使用Python测试Ping主机IP和某端口是否开放的实例
Dec 17 Python
Python实现Word表格转成Excel表格的示例代码
Apr 16 Python
Python 按比例获取样本数据或执行任务的实现代码
Dec 03 Python
PyQt5中QSpinBox计数器的实现
Jan 18 Python
conda安装tensorflow和conda常用命令小结
Feb 20 Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
You might like
用js进行url编码后用php反解以及用php实现js的escape功能函数总结
2010/02/08 PHP
php实现的CSS更新类实例
2014/09/22 PHP
Javascript-Mozilla和IE中的一个函数直接量的问题
2007/01/09 Javascript
jQuery TextBox自动完成条
2009/07/22 Javascript
javascript+xml实现简单图片轮换(只支持IE)
2012/12/23 Javascript
jQuery淡入淡出元素让其效果更为生动
2014/09/01 Javascript
jQuery使用height()获取高度需要注意的地方
2014/12/13 Javascript
javascript判断css3动画结束 css3动画结束的回调函数
2015/03/10 Javascript
JavaScript通过setTimeout实时显示当前时间的方法
2015/04/16 Javascript
JavaScript知识点整理
2015/12/09 Javascript
Bootstrap学习笔记之js组件(4)
2016/06/12 Javascript
js表单登陆验证示例
2016/10/19 Javascript
JS表单数据验证的正则表达式(常用)
2017/02/18 Javascript
关于Stream和Buffer的相互转换详解
2017/07/26 Javascript
bootstrap-Treeview实现级联勾选
2017/11/23 Javascript
Bootstrap实现翻页效果
2017/11/27 Javascript
Bootstrap Table列宽拖动的方法
2018/08/15 Javascript
JavaScript跳出循环的三种方法(break, return, continue)
2019/07/30 Javascript
jQuery zTree如何改变指定节点文本样式
2020/10/16 jQuery
使用Python和OpenCV检测图像中的物体并将物体裁剪下来
2019/10/30 Python
详解pyinstaller生成exe的闪退问题解决方案
2020/06/19 Python
python 三种方法实现对Excel表格的读写
2020/11/19 Python
HTML5注册表单的自动聚焦与占位文本示例代码
2013/07/19 HTML / CSS
美国棒球装备和用品商店:Baseball Savings
2018/06/09 全球购物
什么是岗位职责
2013/11/12 职场文书
英文导游欢迎词
2014/01/11 职场文书
迎新晚会邀请函
2014/02/01 职场文书
职工运动会邀请函
2014/02/02 职场文书
军训考核自我鉴定
2014/02/13 职场文书
付款委托书范本
2014/04/04 职场文书
校园标语大全
2014/06/19 职场文书
乡镇务虚会发言材料
2014/10/20 职场文书
2016元旦文艺汇演主持词
2015/07/06 职场文书
python 下载文件的几种方式分享
2021/04/07 Python
Python中如何处理常见报错
2022/01/18 Python
Android中View.post和Handler.post的关系
2022/06/05 Java/Android