使用Pytorch训练two-head网络的操作


Posted in Python onMay 28, 2021

以前的训练方法:

之前是把两个head分开进行训练的,因此每一轮训练先要对一个batch的数据进行划分,然后再分别训练两个头。代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

但是在实际操作中想到那这样的话岂不是每次都先使用t=0的数据训练公共的表示层,再使用t=1的数据去训练。这样会不会使表示层产生bias呢?且这样两步训练也很麻烦。

修改后的方法

使用之前训练方法其实还是对神经网络的训练的机理不清楚。事实上,在计算loss的时候每个数据点的梯度都是单独计算的。

因此完全可以把网络前向传播得到结果按之前的顺序拼接起来后再进行梯度的反向传播,这样就可以只进行一步训练,且不会出现训练先后的偏差。

代码如下:

f_out_y0, cf_out_y0 = net(x0)
            cf_out_y1, f_out_y1 = net(x1)
            #按照t=0和t=1的索引拼接向量
            y_pred = torch.zeros([len(x), 1])
            y_pred[index0] = f_out_y0
            y_pred[index1] = f_out_y1
            
   criterion = Loss()
            loss = criterion(f_y, y_pred, w) + 0.01 * (l2_regularization0 + l2_regularization1)
            #print(loss.item())
            viz.line([float(loss)], [epoch], win='train_loss', update='append')
            optimizer.zero_grad()
            loss.backward()
            #对网络的参数进行更新
            optimizer.step()

总结

two-head网络前向传播得到结果的时候是分开得到的,训练的时候通过拼接预测结果可以实现一次训练。

补充:Pytorch训练网络的一般步骤

如下所示:

import torch 
print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数

1、构建模型

2、定义一个损失函数

3、定义一个优化器

4、将训练数据带入模型得到预测值

5、将梯度清零

6、获得损失

7、进行优化

import torch
from torch.autograd import Variable
 
#初步认识构建Tensor数据
def one():
    print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
    print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数
    print(torch.zeros((2,3)))#生成一个2X3的全零矩阵
    print(torch.ones((2,3)))#生成一个2X3的全一矩阵
    a = torch.randn((2,3))
    b = a.numpy()#将一个torch.Tensor转换为numpy
    c = torch.from_numpy(b)#将numpy转换为Tensor
    print(a)
    print(b)
    print(c)
 
#使用Variable自动求导
def two():
    # 构建Variable
    x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
    w = Variable(torch.Tensor([4, 5, 6]), requires_grad=True)
    b = Variable(torch.Tensor([7, 8, 9]), requires_grad=True)
    # 函数等式
    y = w * x ** 2 + b
    # 使用梯度下降计算各变量的偏导数
    y.backward(torch.Tensor([1, 1, 1]))
    print(x.grad)
    print(w.grad)
    print(b.grad)

线性回归例子:

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
 
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = 3*x+10+torch.rand(x.size())
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.Linear = nn.Linear(1,1)
    def forward(self,x):
        return self.Linear(x)
model = LinearRegression()
Loss = nn.MSELoss()
Opt = torch.optim.SGD(model.parameters(),lr=0.01)
for i in range(1000):
    inputs = Variable(x)
    targets = Variable(y)
    outputs = model(inputs)
    loss = Loss(outputs,targets)
    Opt.zero_grad()
    loss.backward()
    Opt.step()
model.eval()
predict = model(Variable(x))
plt.plot(x.numpy(),y.numpy(),'ro')
plt.plot(x.numpy(),predict.data.numpy())
plt.show()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python读写unicode文件的方法
Jul 10 Python
Python操作MySQL数据库9个实用实例
Dec 11 Python
Python脚本实现12306火车票查询系统
Sep 30 Python
Python闭包的两个注意事项(推荐)
Mar 20 Python
Python实现动态加载模块、类、函数的方法分析
Jul 18 Python
Django 生成登陆验证码代码分享
Dec 12 Python
Python爬虫之正则表达式基本用法实例分析
Aug 08 Python
django配置连接数据库及原生sql语句的使用方法
Mar 03 Python
django云端留言板实例详解
Jul 22 Python
python实现FTP文件传输的方法(服务器端和客户端)
Mar 20 Python
通过代码实例解析Pytest运行流程
Aug 20 Python
Python日志打印里logging.getLogger源码分析详解
Jan 17 Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
You might like
php访问查询mysql数据的三种方法
2006/10/09 PHP
php引用计数器进行垃圾收集机制介绍
2012/09/19 PHP
ThinkPHP CURD方法之table方法详解
2014/06/18 PHP
PHP中file_get_contents高?用法实例
2014/09/24 PHP
thinkphp实现发送邮件密码找回功能实例
2014/12/01 PHP
php解析字符串里所有URL地址的方法
2015/04/03 PHP
Laravel与CI框架中截取字符串函数
2016/05/08 PHP
PHP编程快速实现数组去重的方法详解
2017/07/22 PHP
js本身的局限性 别让javascript做太多事
2010/03/23 Javascript
jQuery 美元符冲突的解决方法
2010/03/28 Javascript
ExtJS4 Grid改变单元格背景颜色及Column render学习
2013/02/06 Javascript
jQuery.extend()、jQuery.fn.extend()扩展方法示例详解
2014/05/08 Javascript
jquery实现LED广告牌旋转系统图片切换效果代码分享
2015/08/26 Javascript
各式各样的导航条效果css3结合jquery代码实现
2016/09/17 Javascript
html判断当前页面是否在iframe中的实例
2016/11/30 Javascript
原生JavaScript实现AJAX、JSONP
2017/02/07 Javascript
使用Vue制作图片轮播组件思路详解
2018/03/21 Javascript
基于element-ui组件手动实现单选和上传功能
2018/12/06 Javascript
如何在Vue.js中实现标签页组件详解
2019/01/02 Javascript
Vue项目使用localStorage+Vuex保存用户登录信息
2019/05/27 Javascript
微信小程序如何使用canvas二维码保存至手机相册
2019/07/15 Javascript
解决layui-open关闭自身窗口的问题
2019/09/10 Javascript
JS如何操作DOM基于表格动态展示数据
2020/10/15 Javascript
[01:34]DAC2018主赛事第四日五佳镜头 Gh巨牙海民助Miracle-死里逃生
2018/04/07 DOTA
利用Python中的输入和输出功能进行读取和写入的教程
2015/04/14 Python
如何使用VSCode愉快的写Python于调试配置步骤
2018/04/06 Python
PyCharm配置mongo插件的方法
2018/11/30 Python
Python 3.8新特征之asyncio REPL
2019/05/28 Python
python删除列表元素的三种方法(remove,pop,del)
2019/07/22 Python
python求一个字符串的所有排列的实现方法
2020/02/04 Python
在pycharm中实现删除bookmark
2020/02/14 Python
茶叶生产计划书
2014/01/10 职场文书
高中军训广播稿
2014/01/14 职场文书
公司清洁工岗位职责
2015/04/15 职场文书
Python3 使用pip安装git并获取Yahoo金融数据的操作
2021/04/08 Python
如何解决goland,idea全局搜索快捷键失效问题
2022/04/03 Golang