使用Pytorch训练two-head网络的操作


Posted in Python onMay 28, 2021

以前的训练方法:

之前是把两个head分开进行训练的,因此每一轮训练先要对一个batch的数据进行划分,然后再分别训练两个头。代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

但是在实际操作中想到那这样的话岂不是每次都先使用t=0的数据训练公共的表示层,再使用t=1的数据去训练。这样会不会使表示层产生bias呢?且这样两步训练也很麻烦。

修改后的方法

使用之前训练方法其实还是对神经网络的训练的机理不清楚。事实上,在计算loss的时候每个数据点的梯度都是单独计算的。

因此完全可以把网络前向传播得到结果按之前的顺序拼接起来后再进行梯度的反向传播,这样就可以只进行一步训练,且不会出现训练先后的偏差。

代码如下:

f_out_y0, cf_out_y0 = net(x0)
            cf_out_y1, f_out_y1 = net(x1)
            #按照t=0和t=1的索引拼接向量
            y_pred = torch.zeros([len(x), 1])
            y_pred[index0] = f_out_y0
            y_pred[index1] = f_out_y1
            
   criterion = Loss()
            loss = criterion(f_y, y_pred, w) + 0.01 * (l2_regularization0 + l2_regularization1)
            #print(loss.item())
            viz.line([float(loss)], [epoch], win='train_loss', update='append')
            optimizer.zero_grad()
            loss.backward()
            #对网络的参数进行更新
            optimizer.step()

总结

two-head网络前向传播得到结果的时候是分开得到的,训练的时候通过拼接预测结果可以实现一次训练。

补充:Pytorch训练网络的一般步骤

如下所示:

import torch 
print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数

1、构建模型

2、定义一个损失函数

3、定义一个优化器

4、将训练数据带入模型得到预测值

5、将梯度清零

6、获得损失

7、进行优化

import torch
from torch.autograd import Variable
 
#初步认识构建Tensor数据
def one():
    print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
    print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数
    print(torch.zeros((2,3)))#生成一个2X3的全零矩阵
    print(torch.ones((2,3)))#生成一个2X3的全一矩阵
    a = torch.randn((2,3))
    b = a.numpy()#将一个torch.Tensor转换为numpy
    c = torch.from_numpy(b)#将numpy转换为Tensor
    print(a)
    print(b)
    print(c)
 
#使用Variable自动求导
def two():
    # 构建Variable
    x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
    w = Variable(torch.Tensor([4, 5, 6]), requires_grad=True)
    b = Variable(torch.Tensor([7, 8, 9]), requires_grad=True)
    # 函数等式
    y = w * x ** 2 + b
    # 使用梯度下降计算各变量的偏导数
    y.backward(torch.Tensor([1, 1, 1]))
    print(x.grad)
    print(w.grad)
    print(b.grad)

线性回归例子:

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
 
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = 3*x+10+torch.rand(x.size())
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.Linear = nn.Linear(1,1)
    def forward(self,x):
        return self.Linear(x)
model = LinearRegression()
Loss = nn.MSELoss()
Opt = torch.optim.SGD(model.parameters(),lr=0.01)
for i in range(1000):
    inputs = Variable(x)
    targets = Variable(y)
    outputs = model(inputs)
    loss = Loss(outputs,targets)
    Opt.zero_grad()
    loss.backward()
    Opt.step()
model.eval()
predict = model(Variable(x))
plt.plot(x.numpy(),y.numpy(),'ro')
plt.plot(x.numpy(),predict.data.numpy())
plt.show()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中遇到的小问题及解决方法汇总
Jan 11 Python
Python基于递归算法实现的汉诺塔与Fibonacci数列示例
Apr 18 Python
Python常用爬虫代码总结方便查询
Feb 25 Python
python 猴子补丁(monkey patch)
Jun 26 Python
python飞机大战pygame碰撞检测实现方法分析
Dec 17 Python
详解Django配置JWT认证方式
May 09 Python
Python新手学习装饰器
Jun 04 Python
Python内置方法和属性应用:反射和单例(推荐)
Jun 19 Python
python中有帮助函数吗
Jun 19 Python
python中通过pip安装库文件时出现“EnvironmentError: [WinError 5] 拒绝访问”的问题及解决方案
Aug 11 Python
python爬取豆瓣电影排行榜(requests)的示例代码
Feb 18 Python
深入理解pytorch库的dockerfile
Jun 10 Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
You might like
PHP中的traits简单使用实例
2015/05/13 PHP
调用WordPress函数统计文章访问量及PHP原生计数器的实现
2016/03/21 PHP
php获取文件后缀的9种方法
2016/03/22 PHP
PHP基于GD2函数库实现验证码功能示例
2019/01/27 PHP
Jquery升级新版本后选择器的语法问题
2010/06/02 Javascript
js事件监听机制(事件捕获)总结
2014/08/08 Javascript
js随机生成网页背景颜色的方法
2015/02/26 Javascript
jQuery获得document和window对象宽度和高度的方法
2015/03/25 Javascript
在JavaScript中如何解决用execCommand(
2015/10/19 Javascript
模仿password输入框的实现代码
2016/06/07 Javascript
JavaScript判断浏览器对CSS3属性是否支持的多种方法
2016/11/13 Javascript
浅谈javascript中执行环境(作用域)与作用域链
2016/12/08 Javascript
JS中对数组元素进行增删改移的方法总结
2016/12/15 Javascript
使用Bootrap和Vue实现仿百度搜索功能
2017/10/26 Javascript
Vue列表渲染的示例代码
2018/11/01 Javascript
详解vue数组遍历方法forEach和map的原理解析和实际应用
2018/11/15 Javascript
layui button 按钮弹出提示窗口,确定才进行的方法
2019/09/06 Javascript
vue 实现cli3.0中使用proxy进行代理转发
2019/10/30 Javascript
JavaScript面试中常考的字符串操作方法大全(包含ES6)
2020/05/10 Javascript
原生JavaScript实现刮刮乐
2020/09/29 Javascript
详尽讲述用Python的Django框架测试驱动开发的教程
2015/04/22 Python
研究Python的ORM框架中的SQLAlchemy库的映射关系
2015/04/25 Python
python 统计列表中不同元素的数量方法
2018/06/29 Python
python将.ppm格式图片转换成.jpg格式文件的方法
2018/10/27 Python
python3 打开外部程序及关闭的示例
2018/11/06 Python
台湾森森购物网:U-mall
2017/10/16 全球购物
美国在线眼镜店:GlassesShop
2018/11/15 全球购物
澳大利亚波希米亚风时尚品牌:Tree of Life
2019/09/15 全球购物
物业工作计划书
2014/01/10 职场文书
高中化学教学反思
2014/01/13 职场文书
酒店员工职业生涯规划
2014/02/25 职场文书
高中综合实践活动总结
2014/07/07 职场文书
户籍证明格式
2014/09/15 职场文书
太空授课观后感
2015/06/17 职场文书
环保建议书作文500字
2015/09/14 职场文书
Python3 多线程(连接池)操作MySQL插入数据
2021/06/09 Python