使用Pytorch训练two-head网络的操作


Posted in Python onMay 28, 2021

以前的训练方法:

之前是把两个head分开进行训练的,因此每一轮训练先要对一个batch的数据进行划分,然后再分别训练两个头。代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

但是在实际操作中想到那这样的话岂不是每次都先使用t=0的数据训练公共的表示层,再使用t=1的数据去训练。这样会不会使表示层产生bias呢?且这样两步训练也很麻烦。

修改后的方法

使用之前训练方法其实还是对神经网络的训练的机理不清楚。事实上,在计算loss的时候每个数据点的梯度都是单独计算的。

因此完全可以把网络前向传播得到结果按之前的顺序拼接起来后再进行梯度的反向传播,这样就可以只进行一步训练,且不会出现训练先后的偏差。

代码如下:

f_out_y0, cf_out_y0 = net(x0)
            cf_out_y1, f_out_y1 = net(x1)
            #按照t=0和t=1的索引拼接向量
            y_pred = torch.zeros([len(x), 1])
            y_pred[index0] = f_out_y0
            y_pred[index1] = f_out_y1
            
   criterion = Loss()
            loss = criterion(f_y, y_pred, w) + 0.01 * (l2_regularization0 + l2_regularization1)
            #print(loss.item())
            viz.line([float(loss)], [epoch], win='train_loss', update='append')
            optimizer.zero_grad()
            loss.backward()
            #对网络的参数进行更新
            optimizer.step()

总结

two-head网络前向传播得到结果的时候是分开得到的,训练的时候通过拼接预测结果可以实现一次训练。

补充:Pytorch训练网络的一般步骤

如下所示:

import torch 
print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数

1、构建模型

2、定义一个损失函数

3、定义一个优化器

4、将训练数据带入模型得到预测值

5、将梯度清零

6、获得损失

7、进行优化

import torch
from torch.autograd import Variable
 
#初步认识构建Tensor数据
def one():
    print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
    print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数
    print(torch.zeros((2,3)))#生成一个2X3的全零矩阵
    print(torch.ones((2,3)))#生成一个2X3的全一矩阵
    a = torch.randn((2,3))
    b = a.numpy()#将一个torch.Tensor转换为numpy
    c = torch.from_numpy(b)#将numpy转换为Tensor
    print(a)
    print(b)
    print(c)
 
#使用Variable自动求导
def two():
    # 构建Variable
    x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
    w = Variable(torch.Tensor([4, 5, 6]), requires_grad=True)
    b = Variable(torch.Tensor([7, 8, 9]), requires_grad=True)
    # 函数等式
    y = w * x ** 2 + b
    # 使用梯度下降计算各变量的偏导数
    y.backward(torch.Tensor([1, 1, 1]))
    print(x.grad)
    print(w.grad)
    print(b.grad)

线性回归例子:

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
 
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = 3*x+10+torch.rand(x.size())
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.Linear = nn.Linear(1,1)
    def forward(self,x):
        return self.Linear(x)
model = LinearRegression()
Loss = nn.MSELoss()
Opt = torch.optim.SGD(model.parameters(),lr=0.01)
for i in range(1000):
    inputs = Variable(x)
    targets = Variable(y)
    outputs = model(inputs)
    loss = Loss(outputs,targets)
    Opt.zero_grad()
    loss.backward()
    Opt.step()
model.eval()
predict = model(Variable(x))
plt.plot(x.numpy(),y.numpy(),'ro')
plt.plot(x.numpy(),predict.data.numpy())
plt.show()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的类实例属性访问规则探讨
Jan 30 Python
浅析Python多线程下的变量问题
Apr 28 Python
python开发之字符串string操作方法实例详解
Nov 12 Python
Python学习小技巧之列表项的推导式与过滤操作
May 20 Python
详谈Pandas中iloc和loc以及ix的区别
Jun 08 Python
Python模拟浏览器上传文件脚本的方法(Multipart/form-data格式)
Oct 22 Python
Python实现繁体中文与简体中文相互转换的方法示例
Dec 18 Python
django框架事务处理小结【ORM 事务及raw sql,customize sql 事务处理】
Jun 27 Python
python实现切割url得到域名、协议、主机名等各个字段的例子
Jul 25 Python
Python学习笔记之lambda表达式用法详解
Aug 08 Python
python图像处理模块Pillow的学习详解
Oct 09 Python
Pygame游戏开发之太空射击实战敌人精灵篇
Aug 05 Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
You might like
世界咖啡生产者论坛呼吁:需要立即就咖啡价格采取认真行动
2021/03/06 咖啡文化
如何使用php绘制在图片上的正余弦曲线
2013/06/08 PHP
php读取本地json文件的实例
2018/03/07 PHP
smarty模板的使用方法实例分析
2019/09/18 PHP
PHP设计模式之 策略模式Strategy详解【对象行为型】
2020/05/01 PHP
js 禁用浏览器的后退功能的简单方法
2008/12/10 Javascript
jquery 关键字“拖曳搜索”之“拖曳”以及 图片“提示自适应放大”效果 的实现
2010/04/18 Javascript
JavaScript 浏览器验证代码(来自discuz)
2010/07/17 Javascript
js获取键盘按键响应事件(兼容各浏览器)
2013/05/16 Javascript
Jquery获得控件值的三种方法总结
2014/02/13 Javascript
IE中图片的onload事件无效问题和解决方法
2014/06/06 Javascript
JavaScript匿名函数与委托使用示例
2014/07/22 Javascript
JavaScript设计模式之代理模式介绍
2014/12/28 Javascript
浅谈jQuery中setInterval()方法
2015/07/07 Javascript
JS 事件绑定、事件监听、事件委托详细介绍
2016/09/28 Javascript
Bootstrap Table从零开始
2017/06/30 Javascript
捕获未处理的Promise错误方法
2017/10/13 Javascript
Vue结合后台导入导出Excel问题详解
2019/02/19 Javascript
微信小程序返回箭头跳转到指定页面实例解析
2019/10/08 Javascript
JS实现简易图片自动轮播
2020/10/16 Javascript
Python实现将HTML转成PDF的方法分析
2019/05/04 Python
numpy的Fancy Indexing和array比较详解
2020/06/11 Python
Python操作Excel的学习笔记
2021/02/18 Python
python 求两个向量的顺时针夹角操作
2021/03/04 Python
Expedia爱尔兰:酒店、机票、租车及廉价假期
2017/01/02 全球购物
英国受欢迎的运动鞋和街头服装商店:Footasylum
2018/06/12 全球购物
大学总结自我鉴定
2014/01/18 职场文书
迎元旦广播稿
2014/02/22 职场文书
导师工作推荐信范文
2014/05/17 职场文书
学前教育专业求职信
2014/09/02 职场文书
设备收款委托书范本
2014/10/02 职场文书
法定代表人证明书
2014/11/28 职场文书
如何做好工作总结!
2019/04/10 职场文书
如何用PHP实现分布算法之一致性哈希算法
2021/05/26 PHP
Python import模块的缓存问题解决方案
2021/06/02 Python
python中validators库的使用方法详解
2022/09/23 Python