使用Pytorch训练two-head网络的操作


Posted in Python onMay 28, 2021

以前的训练方法:

之前是把两个head分开进行训练的,因此每一轮训练先要对一个batch的数据进行划分,然后再分别训练两个头。代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

但是在实际操作中想到那这样的话岂不是每次都先使用t=0的数据训练公共的表示层,再使用t=1的数据去训练。这样会不会使表示层产生bias呢?且这样两步训练也很麻烦。

修改后的方法

使用之前训练方法其实还是对神经网络的训练的机理不清楚。事实上,在计算loss的时候每个数据点的梯度都是单独计算的。

因此完全可以把网络前向传播得到结果按之前的顺序拼接起来后再进行梯度的反向传播,这样就可以只进行一步训练,且不会出现训练先后的偏差。

代码如下:

f_out_y0, cf_out_y0 = net(x0)
            cf_out_y1, f_out_y1 = net(x1)
            #按照t=0和t=1的索引拼接向量
            y_pred = torch.zeros([len(x), 1])
            y_pred[index0] = f_out_y0
            y_pred[index1] = f_out_y1
            
   criterion = Loss()
            loss = criterion(f_y, y_pred, w) + 0.01 * (l2_regularization0 + l2_regularization1)
            #print(loss.item())
            viz.line([float(loss)], [epoch], win='train_loss', update='append')
            optimizer.zero_grad()
            loss.backward()
            #对网络的参数进行更新
            optimizer.step()

总结

two-head网络前向传播得到结果的时候是分开得到的,训练的时候通过拼接预测结果可以实现一次训练。

补充:Pytorch训练网络的一般步骤

如下所示:

import torch 
print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数

1、构建模型

2、定义一个损失函数

3、定义一个优化器

4、将训练数据带入模型得到预测值

5、将梯度清零

6、获得损失

7、进行优化

import torch
from torch.autograd import Variable
 
#初步认识构建Tensor数据
def one():
    print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
    print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数
    print(torch.zeros((2,3)))#生成一个2X3的全零矩阵
    print(torch.ones((2,3)))#生成一个2X3的全一矩阵
    a = torch.randn((2,3))
    b = a.numpy()#将一个torch.Tensor转换为numpy
    c = torch.from_numpy(b)#将numpy转换为Tensor
    print(a)
    print(b)
    print(c)
 
#使用Variable自动求导
def two():
    # 构建Variable
    x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
    w = Variable(torch.Tensor([4, 5, 6]), requires_grad=True)
    b = Variable(torch.Tensor([7, 8, 9]), requires_grad=True)
    # 函数等式
    y = w * x ** 2 + b
    # 使用梯度下降计算各变量的偏导数
    y.backward(torch.Tensor([1, 1, 1]))
    print(x.grad)
    print(w.grad)
    print(b.grad)

线性回归例子:

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
 
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = 3*x+10+torch.rand(x.size())
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.Linear = nn.Linear(1,1)
    def forward(self,x):
        return self.Linear(x)
model = LinearRegression()
Loss = nn.MSELoss()
Opt = torch.optim.SGD(model.parameters(),lr=0.01)
for i in range(1000):
    inputs = Variable(x)
    targets = Variable(y)
    outputs = model(inputs)
    loss = Loss(outputs,targets)
    Opt.zero_grad()
    loss.backward()
    Opt.step()
model.eval()
predict = model(Variable(x))
plt.plot(x.numpy(),y.numpy(),'ro')
plt.plot(x.numpy(),predict.data.numpy())
plt.show()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
构建Python包的五个简单准则简介
Jun 15 Python
CentOS中升级Python版本的方法详解
Jul 10 Python
Python实现曲线拟合操作示例【基于numpy,scipy,matplotlib库】
Jul 12 Python
selenium3+python3环境搭建教程图解
Dec 07 Python
Python实现的对一个数进行因式分解操作示例
Jun 27 Python
python的debug实用工具 pdb详解
Jul 12 Python
Python中print函数简单使用总结
Aug 05 Python
python模块常用用法实例详解
Oct 17 Python
Python基于yield遍历多个可迭代对象
Mar 12 Python
keras训练曲线,混淆矩阵,CNN层输出可视化实例
Jun 15 Python
django restframework serializer 增加自定义字段操作
Jul 15 Python
python数字转对应中文的方法总结
Aug 02 Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
You might like
一步一步学习PHP(5) 类和对象
2010/02/16 PHP
PHP获取文件的MD5值并判断是否被修改的例子
2014/06/19 PHP
PHP中执行cmd命令的方法
2014/10/11 PHP
PHP将数据导出Excel表中的实例(投机型)
2017/07/31 PHP
PHP数组式访问接口ArrayAccess用法分析
2017/12/28 PHP
PHP date()格式MySQL中插入datetime方法
2019/01/29 PHP
IE6图片加载的一个BUG解决方法
2010/07/13 Javascript
网易JS面试题与Javascript词法作用域说明
2010/11/09 Javascript
打印json对象的内容及JSON.stringify函数应用
2013/03/29 Javascript
javascript获取网页中指定节点的父节点、子节点的方法小结
2013/04/24 Javascript
JS Replace()的高级使用方法介绍
2013/06/29 Javascript
jquery iframe操作详细解析
2013/11/20 Javascript
JS简单编号生成器实现方法(附demo源码下载)
2016/04/05 Javascript
Jqprint实现页面打印
2017/01/06 Javascript
Ajax验证用户名或昵称是否已被注册
2017/04/05 Javascript
JavaScript学习笔记之函数记忆
2017/09/06 Javascript
使用bootstrap实现下拉框搜索功能的实例讲解
2018/08/10 Javascript
jQuery事件绑定和解绑、事件冒泡与阻止事件冒泡及弹出应用示例
2019/05/13 jQuery
最简单的vue消息提示全局组件的方法
2019/06/16 Javascript
Python实现的彩票机选器实例
2015/06/17 Python
Pycharm学习教程(7)虚拟机VM的配置教程
2017/05/04 Python
python实现多线程行情抓取工具的方法
2018/02/28 Python
PyCharm代码整体缩进,反向缩进的方法
2018/06/25 Python
Django框架模板注入操作示例【变量传递到模板】
2018/12/19 Python
Python英文文章词频统计(14份剑桥真题词频统计)
2019/10/13 Python
python如何实现复制目录到指定目录
2020/02/13 Python
python中字典增加和删除使用方法
2020/09/30 Python
python 实现网易邮箱邮件阅读和删除的辅助小脚本
2021/03/01 Python
英国航空官网:British Airways
2016/09/11 全球购物
香港现代设计家具品牌:Ziinlife Furniture
2018/11/13 全球购物
《玩具柜台前的孩子》教学反思
2014/02/13 职场文书
党的群众路线教育实践活动对照检查材料思想汇报
2014/09/19 职场文书
2014年女职工工作总结
2014/11/27 职场文书
心理健康教育培训研修感言
2015/11/18 职场文书
Python 详解通过Scrapy框架实现爬取百度新冠疫情数据流程
2021/11/11 Python
详解Python+OpenCV绘制灰度直方图
2022/03/22 Python