使用Pytorch训练two-head网络的操作


Posted in Python onMay 28, 2021

以前的训练方法:

之前是把两个head分开进行训练的,因此每一轮训练先要对一个batch的数据进行划分,然后再分别训练两个头。代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

但是在实际操作中想到那这样的话岂不是每次都先使用t=0的数据训练公共的表示层,再使用t=1的数据去训练。这样会不会使表示层产生bias呢?且这样两步训练也很麻烦。

修改后的方法

使用之前训练方法其实还是对神经网络的训练的机理不清楚。事实上,在计算loss的时候每个数据点的梯度都是单独计算的。

因此完全可以把网络前向传播得到结果按之前的顺序拼接起来后再进行梯度的反向传播,这样就可以只进行一步训练,且不会出现训练先后的偏差。

代码如下:

f_out_y0, cf_out_y0 = net(x0)
            cf_out_y1, f_out_y1 = net(x1)
            #按照t=0和t=1的索引拼接向量
            y_pred = torch.zeros([len(x), 1])
            y_pred[index0] = f_out_y0
            y_pred[index1] = f_out_y1
            
   criterion = Loss()
            loss = criterion(f_y, y_pred, w) + 0.01 * (l2_regularization0 + l2_regularization1)
            #print(loss.item())
            viz.line([float(loss)], [epoch], win='train_loss', update='append')
            optimizer.zero_grad()
            loss.backward()
            #对网络的参数进行更新
            optimizer.step()

总结

two-head网络前向传播得到结果的时候是分开得到的,训练的时候通过拼接预测结果可以实现一次训练。

补充:Pytorch训练网络的一般步骤

如下所示:

import torch 
print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数

1、构建模型

2、定义一个损失函数

3、定义一个优化器

4、将训练数据带入模型得到预测值

5、将梯度清零

6、获得损失

7、进行优化

import torch
from torch.autograd import Variable
 
#初步认识构建Tensor数据
def one():
    print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
    print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数
    print(torch.zeros((2,3)))#生成一个2X3的全零矩阵
    print(torch.ones((2,3)))#生成一个2X3的全一矩阵
    a = torch.randn((2,3))
    b = a.numpy()#将一个torch.Tensor转换为numpy
    c = torch.from_numpy(b)#将numpy转换为Tensor
    print(a)
    print(b)
    print(c)
 
#使用Variable自动求导
def two():
    # 构建Variable
    x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
    w = Variable(torch.Tensor([4, 5, 6]), requires_grad=True)
    b = Variable(torch.Tensor([7, 8, 9]), requires_grad=True)
    # 函数等式
    y = w * x ** 2 + b
    # 使用梯度下降计算各变量的偏导数
    y.backward(torch.Tensor([1, 1, 1]))
    print(x.grad)
    print(w.grad)
    print(b.grad)

线性回归例子:

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
 
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = 3*x+10+torch.rand(x.size())
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.Linear = nn.Linear(1,1)
    def forward(self,x):
        return self.Linear(x)
model = LinearRegression()
Loss = nn.MSELoss()
Opt = torch.optim.SGD(model.parameters(),lr=0.01)
for i in range(1000):
    inputs = Variable(x)
    targets = Variable(y)
    outputs = model(inputs)
    loss = Loss(outputs,targets)
    Opt.zero_grad()
    loss.backward()
    Opt.step()
model.eval()
predict = model(Variable(x))
plt.plot(x.numpy(),y.numpy(),'ro')
plt.plot(x.numpy(),predict.data.numpy())
plt.show()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现自动网页截图并裁剪图片
Jul 30 Python
对numpy中的transpose和swapaxes函数详解
Aug 02 Python
python 函数内部修改外部变量的方法
Dec 18 Python
python requests 库请求带有文件参数的接口实例
Jan 03 Python
django和vue实现数据交互的方法
Aug 21 Python
基于pytorch 预训练的词向量用法详解
Jan 06 Python
Pytorch数据拼接与拆分操作实现图解
Apr 30 Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 Python
Python 代码调试技巧示例代码
Aug 11 Python
python中判断数字是否为质数的实例讲解
Dec 06 Python
python中常用的数据结构介绍
Jan 12 Python
用python画城市轮播地图
May 28 Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
You might like
php UTF8 文件的签名问题
2009/10/30 PHP
php实现的用户查询类实例
2015/06/18 PHP
php生成图片缩略图功能示例
2017/02/22 PHP
SyntaxHighlighter代码加色使用方法
2008/09/07 Javascript
jquery中:input和input的区别分析
2011/07/13 Javascript
JavaScript中instanceof与typeof运算符的用法及区别详细解析
2013/11/19 Javascript
多个jquery.datatable共存,checkbox全选异常的快速解决方法
2013/12/10 Javascript
网页右侧悬浮滚动在线qq客服代码示例
2014/04/28 Javascript
JavaScript跨浏览器获取页面中相同class节点的方法
2015/03/03 Javascript
jQuery 遍历函数详解
2015/07/05 Javascript
js判断移动端是否安装某款app的多种方法
2015/12/18 Javascript
jQuery绑定事件监听bind和移除事件监听unbind用法实例详解
2016/01/19 Javascript
AngularJS 遇到的小坑与技巧小结
2016/06/07 Javascript
AngularJS基础 ng-mouseenter 指令示例代码
2016/08/02 Javascript
jQuery购物网页经典制作案例
2016/08/19 Javascript
荐书|您有一份JavaScript书单待签收
2017/07/21 Javascript
angularjs使用gulp-uglify压缩后执行报错的解决方法
2018/03/07 Javascript
react 创建单例组件的方法
2018/04/26 Javascript
微信小程序新闻网站详情页实例代码
2020/01/10 Javascript
JavaScript进制转换实现方法解析
2020/01/18 Javascript
封装 axios+promise通用请求函数操作
2020/08/11 Javascript
JavaScript 中的六种循环方法
2021/01/06 Javascript
Python正确重载运算符的方法示例详解
2017/08/27 Python
Python子进程subpocess原理及用法解析
2020/07/16 Python
华为俄罗斯官方网上商城:购买Huawei手机和平板
2017/04/21 全球购物
蛋白质世界:Protein World
2017/11/23 全球购物
学习委员自我鉴定
2014/01/13 职场文书
个性与发展自我评价
2014/02/11 职场文书
幼师辞职信怎么写
2015/02/27 职场文书
幼儿教师师德师风自我评价
2015/03/05 职场文书
2015年幼儿园中班工作总结
2015/04/25 职场文书
退休欢送会主持词
2015/07/01 职场文书
大一新生军训新闻稿
2015/07/17 职场文书
小学信息技术教学反思
2016/02/16 职场文书
Python编写可视化界面的全过程(Python+PyCharm+PyQt)
2021/05/17 Python
SpringBoot集成Redis,并自定义对象序列化操作
2021/06/22 Java/Android