使用Pytorch训练two-head网络的操作


Posted in Python onMay 28, 2021

以前的训练方法:

之前是把两个head分开进行训练的,因此每一轮训练先要对一个batch的数据进行划分,然后再分别训练两个头。代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

但是在实际操作中想到那这样的话岂不是每次都先使用t=0的数据训练公共的表示层,再使用t=1的数据去训练。这样会不会使表示层产生bias呢?且这样两步训练也很麻烦。

修改后的方法

使用之前训练方法其实还是对神经网络的训练的机理不清楚。事实上,在计算loss的时候每个数据点的梯度都是单独计算的。

因此完全可以把网络前向传播得到结果按之前的顺序拼接起来后再进行梯度的反向传播,这样就可以只进行一步训练,且不会出现训练先后的偏差。

代码如下:

f_out_y0, cf_out_y0 = net(x0)
            cf_out_y1, f_out_y1 = net(x1)
            #按照t=0和t=1的索引拼接向量
            y_pred = torch.zeros([len(x), 1])
            y_pred[index0] = f_out_y0
            y_pred[index1] = f_out_y1
            
   criterion = Loss()
            loss = criterion(f_y, y_pred, w) + 0.01 * (l2_regularization0 + l2_regularization1)
            #print(loss.item())
            viz.line([float(loss)], [epoch], win='train_loss', update='append')
            optimizer.zero_grad()
            loss.backward()
            #对网络的参数进行更新
            optimizer.step()

总结

two-head网络前向传播得到结果的时候是分开得到的,训练的时候通过拼接预测结果可以实现一次训练。

补充:Pytorch训练网络的一般步骤

如下所示:

import torch 
print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数

1、构建模型

2、定义一个损失函数

3、定义一个优化器

4、将训练数据带入模型得到预测值

5、将梯度清零

6、获得损失

7、进行优化

import torch
from torch.autograd import Variable
 
#初步认识构建Tensor数据
def one():
    print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
    print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数
    print(torch.zeros((2,3)))#生成一个2X3的全零矩阵
    print(torch.ones((2,3)))#生成一个2X3的全一矩阵
    a = torch.randn((2,3))
    b = a.numpy()#将一个torch.Tensor转换为numpy
    c = torch.from_numpy(b)#将numpy转换为Tensor
    print(a)
    print(b)
    print(c)
 
#使用Variable自动求导
def two():
    # 构建Variable
    x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
    w = Variable(torch.Tensor([4, 5, 6]), requires_grad=True)
    b = Variable(torch.Tensor([7, 8, 9]), requires_grad=True)
    # 函数等式
    y = w * x ** 2 + b
    # 使用梯度下降计算各变量的偏导数
    y.backward(torch.Tensor([1, 1, 1]))
    print(x.grad)
    print(w.grad)
    print(b.grad)

线性回归例子:

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
 
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = 3*x+10+torch.rand(x.size())
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.Linear = nn.Linear(1,1)
    def forward(self,x):
        return self.Linear(x)
model = LinearRegression()
Loss = nn.MSELoss()
Opt = torch.optim.SGD(model.parameters(),lr=0.01)
for i in range(1000):
    inputs = Variable(x)
    targets = Variable(y)
    outputs = model(inputs)
    loss = Loss(outputs,targets)
    Opt.zero_grad()
    loss.backward()
    Opt.step()
model.eval()
predict = model(Variable(x))
plt.plot(x.numpy(),y.numpy(),'ro')
plt.plot(x.numpy(),predict.data.numpy())
plt.show()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python dict remove数组删除(del,pop)
Mar 24 Python
python静态方法实例
Jan 14 Python
Python编程中装饰器的使用示例解析
Jun 20 Python
详解用Python处理HTML转义字符的5种方式
Dec 27 Python
python3+dlib实现人脸识别和情绪分析
Apr 21 Python
Python 使用PIL中的resize进行缩放的实例讲解
Aug 03 Python
python解析含有重复key的json方法
Jan 22 Python
查看python安装路径及pip安装的包列表及路径
Apr 03 Python
Django设置Postgresql的操作
May 14 Python
详解python中的lambda与sorted函数
Sep 04 Python
pycharm激活方法到2099年(激活流程)
Sep 22 Python
python opencv旋转图片的使用方法
Jun 04 Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
You might like
如何实现给定日期的若干天以后的日期
2006/10/09 PHP
php设计模式 Visitor 访问者模式
2011/06/28 PHP
php cURL和Rolling cURL并发方式比较
2013/10/30 PHP
ThinkPHP CURD方法之order方法详解
2014/06/18 PHP
PHP中常用的字符串格式化函数总结
2014/11/19 PHP
Symfony2中被遗弃的getRequest()方法分析
2016/03/17 PHP
php版微信自动获取收货地址api用法示例
2016/09/22 PHP
Yii2中添加全局函数的方法分析
2017/05/04 PHP
tp5.1 框架数据库-数据集操作实例分析
2020/05/26 PHP
jquery多浏览器捕捉回车事件代码
2010/06/22 Javascript
jQuery 验证插件 Web前端设计模式(asp.net)
2010/10/17 Javascript
为Javascript中的String对象添加去除左右空格的方法(示例代码)
2013/11/30 Javascript
使用js实现关闭js弹出层的窗口
2014/02/10 Javascript
JS 弹出层 定位至屏幕居中示例
2014/05/21 Javascript
基于jquery实现的文字向上跑动类似跑马灯的效果
2014/06/17 Javascript
JS显示下拉列表框内全部元素的方法
2015/03/31 Javascript
js脚本分页代码分享(7种样式)
2015/08/19 Javascript
jQuery绑定事件监听bind和移除事件监听unbind用法实例详解
2016/01/19 Javascript
js 中文汉字转Unicode、Unicode转中文汉字、ASCII转换Unicode、Unicode转换ASCII、中文转换
2016/12/06 Javascript
vue-cli之router基本使用方法详解
2017/10/17 Javascript
浅谈webpack下的AOP式无侵入注入
2017/11/12 Javascript
Vue利用Blob下载原生二进制数组文件
2019/09/25 Javascript
Node.js API详解之 dgram模块用法实例分析
2020/06/05 Javascript
JavaScript本地储存:localStorage、sessionStorage、cookie的使用
2020/10/13 Javascript
使用Python写一个贪吃蛇游戏实例代码
2017/08/21 Python
python使用json序列化datetime类型实例解析
2018/02/11 Python
使用python存储网页上的图片实例
2018/05/22 Python
对python:print打印时加u的含义详解
2018/12/15 Python
python海龟绘图之画国旗实例代码
2020/11/11 Python
关于青春的演讲稿
2014/05/05 职场文书
房屋租赁委托书范本
2014/10/04 职场文书
工厂标语大全
2014/10/06 职场文书
体育教师个人总结
2015/02/09 职场文书
法定代表人资格证明书
2015/06/18 职场文书
Go缓冲channel和非缓冲channel的区别说明
2021/04/25 Golang
自己搭建resnet18网络并加载torchvision自带权重的操作
2021/05/13 Python