使用Pytorch训练two-head网络的操作


Posted in Python onMay 28, 2021

以前的训练方法:

之前是把两个head分开进行训练的,因此每一轮训练先要对一个batch的数据进行划分,然后再分别训练两个头。代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

但是在实际操作中想到那这样的话岂不是每次都先使用t=0的数据训练公共的表示层,再使用t=1的数据去训练。这样会不会使表示层产生bias呢?且这样两步训练也很麻烦。

修改后的方法

使用之前训练方法其实还是对神经网络的训练的机理不清楚。事实上,在计算loss的时候每个数据点的梯度都是单独计算的。

因此完全可以把网络前向传播得到结果按之前的顺序拼接起来后再进行梯度的反向传播,这样就可以只进行一步训练,且不会出现训练先后的偏差。

代码如下:

f_out_y0, cf_out_y0 = net(x0)
            cf_out_y1, f_out_y1 = net(x1)
            #按照t=0和t=1的索引拼接向量
            y_pred = torch.zeros([len(x), 1])
            y_pred[index0] = f_out_y0
            y_pred[index1] = f_out_y1
            
   criterion = Loss()
            loss = criterion(f_y, y_pred, w) + 0.01 * (l2_regularization0 + l2_regularization1)
            #print(loss.item())
            viz.line([float(loss)], [epoch], win='train_loss', update='append')
            optimizer.zero_grad()
            loss.backward()
            #对网络的参数进行更新
            optimizer.step()

总结

two-head网络前向传播得到结果的时候是分开得到的,训练的时候通过拼接预测结果可以实现一次训练。

补充:Pytorch训练网络的一般步骤

如下所示:

import torch 
print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数

1、构建模型

2、定义一个损失函数

3、定义一个优化器

4、将训练数据带入模型得到预测值

5、将梯度清零

6、获得损失

7、进行优化

import torch
from torch.autograd import Variable
 
#初步认识构建Tensor数据
def one():
    print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
    print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数
    print(torch.zeros((2,3)))#生成一个2X3的全零矩阵
    print(torch.ones((2,3)))#生成一个2X3的全一矩阵
    a = torch.randn((2,3))
    b = a.numpy()#将一个torch.Tensor转换为numpy
    c = torch.from_numpy(b)#将numpy转换为Tensor
    print(a)
    print(b)
    print(c)
 
#使用Variable自动求导
def two():
    # 构建Variable
    x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
    w = Variable(torch.Tensor([4, 5, 6]), requires_grad=True)
    b = Variable(torch.Tensor([7, 8, 9]), requires_grad=True)
    # 函数等式
    y = w * x ** 2 + b
    # 使用梯度下降计算各变量的偏导数
    y.backward(torch.Tensor([1, 1, 1]))
    print(x.grad)
    print(w.grad)
    print(b.grad)

线性回归例子:

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
 
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = 3*x+10+torch.rand(x.size())
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.Linear = nn.Linear(1,1)
    def forward(self,x):
        return self.Linear(x)
model = LinearRegression()
Loss = nn.MSELoss()
Opt = torch.optim.SGD(model.parameters(),lr=0.01)
for i in range(1000):
    inputs = Variable(x)
    targets = Variable(y)
    outputs = model(inputs)
    loss = Loss(outputs,targets)
    Opt.zero_grad()
    loss.backward()
    Opt.step()
model.eval()
predict = model(Variable(x))
plt.plot(x.numpy(),y.numpy(),'ro')
plt.plot(x.numpy(),predict.data.numpy())
plt.show()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现绘制树枝简单示例
Jul 24 Python
对python numpy数组中冒号的使用方法详解
Apr 17 Python
使用PyInstaller将python转成可执行文件exe笔记
May 26 Python
Linux下python3.7.0安装教程
Jul 30 Python
django 将model转换为字典的方法示例
Oct 16 Python
python如何实现代码检查
Jun 28 Python
python 实现手机自动拨打电话的方法(通话压力测试)
Aug 08 Python
python判断自身是否正在运行的方法
Aug 08 Python
PHP统计代码行数的小代码
Sep 19 Python
Pycharm创建文件时自动生成文件头注释(自定义设置作者日期)
Nov 24 Python
Jmeter调用Python脚本实现参数互相传递的实现
Jan 22 Python
Python面向对象之内置函数相关知识总结
Jun 24 Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
You might like
Apache设置虚拟WEB
2006/10/09 PHP
php中的数组操作函数整理
2008/08/18 PHP
ThinkPHP 连接Oracle数据库的详细教程[全]
2012/07/16 PHP
PHP中上传多个文件的表单设计例子
2014/11/19 PHP
PHP中的命名空间详细介绍
2015/07/02 PHP
不错的asp中显示新闻的功能
2006/10/13 Javascript
爱恋千雪-US-AscII加密解密工具(网页加密)下载
2007/06/06 Javascript
用JTrackBar实现的模拟苹果风格的滚动条
2007/08/06 Javascript
jQuery select操作控制方法小结
2010/05/26 Javascript
JavaScript加强之自定义event事件
2013/09/21 Javascript
JavaScript将数组转换成CSV格式的方法
2015/03/19 Javascript
JavaScript常用判断写法大全(推荐)
2016/05/30 Javascript
js 实现数值的千分位及保存小数方法(推荐)
2016/08/01 Javascript
JS中IP地址与整数相互转换的实现代码
2017/04/10 Javascript
vue mint-ui学习笔记之picker的使用
2017/10/11 Javascript
JavaScript实现二叉树的先序、中序及后序遍历方法详解
2017/10/26 Javascript
从零到一详聊创建Vue工程及遇到的常见问题
2019/04/25 Javascript
[05:45]Ti4观战指南(下)
2014/07/07 DOTA
Python入门_学会创建并调用函数的方法
2017/05/16 Python
Python(TensorFlow框架)实现手写数字识别系统的方法
2018/05/29 Python
Python实现按逗号分隔列表的方法
2018/10/23 Python
Python 中的参数传递、返回值、浅拷贝、深拷贝
2019/06/25 Python
Python3之字节串bytes与字节数组bytearray的使用详解
2019/08/27 Python
python opencv将表格图片按照表格框线分割和识别
2019/10/30 Python
Python3.9又更新了:dict内置新功能
2020/02/28 Python
Python函数递归调用实现原理实例解析
2020/08/11 Python
python list的index()和find()的实现
2020/11/16 Python
CSS3 实现童年的纸飞机
2019/05/05 HTML / CSS
x-ua-compatible content=”IE=7, IE=9″意思理解
2013/07/22 HTML / CSS
ProBikeKit新西兰:自行车套件,跑步和铁人三项装备
2017/04/05 全球购物
领先的荷兰线上超市:荷兰之家Holland at Home(支持中文)
2021/01/21 全球购物
诉讼财产保全担保书
2014/05/20 职场文书
乡村教师党员四风问题对照检查材料思想汇报
2014/10/08 职场文书
无房产证房屋转让协议书合同样本
2014/10/18 职场文书
Golang表示枚举类型的详细讲解
2021/09/04 Golang
js前端面试常见浏览器缓存强缓存及协商缓存实例
2022/06/21 Javascript