PyTorch搭建多项式回归模型(三)


Posted in Python onMay 22, 2019

PyTorch基础入门三:PyTorch搭建多项式回归模型 

1)理论简介

对于一般的线性回归模型,由于该函数拟合出来的是一条直线,所以精度欠佳,我们可以考虑多项式回归来拟合更多的模型。所谓多项式回归,其本质也是线性回归。也就是说,我们采取的方法是,提高每个属性的次数来增加维度数。比如,请看下面这样的例子:

如果我们想要拟合方程:

PyTorch搭建多项式回归模型(三)

对于输入变量PyTorch搭建多项式回归模型(三)和输出值PyTorch搭建多项式回归模型(三),我们只需要增加其平方项、三次方项系数即可。所以,我们可以设置如下参数方程:

PyTorch搭建多项式回归模型(三)

可以看到,上述方程与线性回归方程并没有本质区别。所以我们可以采用线性回归的方式来进行多项式的拟合。下面请看代码部分。

2)代码实现

当然最先要做的就是导包了,下面需要说明的只有一个:itertools中的count,这个是用来记数用的,其可以记数到无穷,第一个参数是记数的起始值,第二个参数是步长。其内部实现相当于如下代码:

def count(firstval=0, step=1):
 x = firstval
 while 1:
 yield x
 x += step

下面是导包部分代码,这里定义了一个常量POLY_DEGREE = 3用来指定多项式最高次数。

from itertools import count
import torch
import torch.autograd
import torch.nn.functional as F
 
POLY_DEGREE = 3

然后我们需要将数据处理成矩阵的形式:

PyTorch搭建多项式回归模型(三)

在PyTorch里面使用torch.cat()函数来实现Tensor的拼接:

def make_features(x):
 """Builds features i.e. a matrix with columns [x, x^2, x^3, x^4]."""
 x = x.unsqueeze(1)
 return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)

对于输入的PyTorch搭建多项式回归模型(三)个数据,我们将其扩展成上面矩阵所示的样子。

然后定义出我们需要拟合的多项式,可以随机抽取一个多项式来作为我们的目标多项式。当然,系数PyTorch搭建多项式回归模型(三)和偏置PyTorch搭建多项式回归模型(三)确定了,多项式也就确定了:

W_target = torch.randn(POLY_DEGREE, 1)
b_target = torch.randn(1)
 
def f(x):
 """Approximated function."""
 return x.mm(W_target) + b_target.item()

这里的权重已经定义好了,x.mm(W_target)表示做矩阵乘法,PyTorch搭建多项式回归模型(三)就是每次输入一个PyTorch搭建多项式回归模型(三)得到一个PyTorch搭建多项式回归模型(三)的真实函数。

在训练的时候我们需要采样一些点,可以随机生成一批数据来得到训练集。下面的函数可以让我们每次取batch_size这么多个数据,然后将其转化为矩阵形式,再把这个值通过函数之后的结果也返回作为真实的输出值:

def get_batch(batch_size=32):
 """Builds a batch i.e. (x, f(x)) pair."""
 random = torch.randn(batch_size)
 x = make_features(random)
 y = f(x)
 return x, y

接下来我们需要定义模型,这里采用一种简写的方式定义模型,torch.nn.Linear()表示定义一个线性模型,这里定义了是输入值和目标参数的行数一致(和POLY_DEGREE一致,本次实验中为3),输出值为1的模型。

# Define model
fc = torch.nn.Linear(W_target.size(0), 1)

下面开始训练模型,训练的过程让其不断优化,直到随机取出的batch_size个点中计算出来的均方误差小于0.001为止。

for batch_idx in count(1):
 # Get data
 batch_x, batch_y = get_batch()
 
 # Reset gradients
 fc.zero_grad()
 
 # Forward pass
 output = F.smooth_l1_loss(fc(batch_x), batch_y)
 loss = output.item()
 
 # Backward pass
 output.backward()
 
 # Apply gradients
 for param in fc.parameters():
 param.data.add_(-0.1 * param.grad.data)
 
 # Stop criterion
 if loss < 1e-3:
 break

这样就已经训练出了我们的多项式回归模型,为了方便观察,定义了如下打印函数来打印出我们拟合的多项式表达式:

def poly_desc(W, b):
 """Creates a string description of a polynomial."""
 result = 'y = '
 for i, w in enumerate(W):
 result += '{:+.2f} x^{} '.format(w, len(W) - i)
 result += '{:+.2f}'.format(b[0])
 return result
 
print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))
print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias))
print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target))

程序运行结果如下图所示:

PyTorch搭建多项式回归模型(三)

可以看出,真实的多项式表达式和我们拟合的多项式十分接近。现实世界中很多问题都不是简单的线性回归,涉及到很多复杂的非线性模型。但是我们可以在其特征量上进行研究,改变或者增加其特征,从而将非线性问题转化为线性问题来解决,这种处理问题的思路是我们从多项式回归的算法中应该汲取到的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用MYSQLDB实现从数据库中导出XML文件的方法
May 11 Python
python 调用c语言函数的方法
Sep 29 Python
python之virtualenv的简单使用方法(必看篇)
Nov 25 Python
Python第三方库face_recognition在windows上的安装过程
May 03 Python
python3文件复制、延迟文件复制任务的实现方法
Sep 02 Python
对tensorflow中cifar-10文档的Read操作详解
Feb 10 Python
Python格式化输出--%s,%d,%f的代码解析
Apr 29 Python
python 实现PIL模块在图片画线写字
May 16 Python
python文件读取失败怎么处理
Jun 23 Python
最新pycharm安装教程
Nov 18 Python
pandas apply使用多列计算生成新的列实现示例
Feb 24 Python
opencv深入浅出了解机器学习和深度学习
Mar 17 Python
pytorch使用Variable实现线性回归
May 21 #Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
python多进程读图提取特征存npy
May 21 #Python
Python中使用pypdf2合并、分割、加密pdf文件的代码详解
May 21 #Python
You might like
PHP把网页保存为word文件的三种方法
2014/04/01 PHP
自动检查并替换文本框内的字符
2006/06/30 Javascript
jQuery 图像裁剪插件Jcrop的简单使用
2009/05/22 Javascript
jquery下为Event handler传递动态参数的代码
2011/01/06 Javascript
JavaScript对象和字串之间的转换实例探讨
2013/04/21 Javascript
图片Slider 带左右按钮的js示例
2013/08/30 Javascript
火狐下table中创建form导致两个table之间出现空白
2013/09/02 Javascript
js实现带搜索功能的下拉框实时搜索实时匹配
2013/11/05 Javascript
js实现数字每三位加逗号的方法
2015/02/05 Javascript
jQuery 1.9.1源码分析系列(十)事件系统之主动触发事件和模拟冒泡处理
2015/11/24 Javascript
解决Node.js使用MySQL出现connect ECONNREFUSED 127.0.0.1:3306的问题
2017/03/09 Javascript
关于javascript获取内联样式与嵌入式样式的实例
2017/06/01 Javascript
使用Vue构建可重用的分页组件
2018/03/26 Javascript
小程序开发基础之view视图容器
2018/08/21 Javascript
Node.js中package.json中库的版本号(~和^)
2019/04/02 Javascript
layui字体图标 loading图标静止不旋转的解决方法
2019/09/23 Javascript
简单使用webpack打包文件的实现
2019/10/29 Javascript
Vue中引入svg图标的两种方式
2021/01/14 Vue.js
[36:43]NB vs Optic 2018国际邀请赛小组赛BO1 B组加赛 8.19
2018/08/21 DOTA
python使用cStringIO实现临时内存文件访问的方法
2015/03/26 Python
Python抓取电影天堂电影信息的代码
2016/04/07 Python
Sanic框架配置操作分析
2018/07/17 Python
Python访问MongoDB,并且转换成Dataframe的方法
2018/10/15 Python
对python特殊函数 __call__()的使用详解
2019/07/02 Python
对Django中static(静态)文件详解以及{% static %}标签的使用方法
2019/07/28 Python
python3.7+selenium模拟淘宝登录功能的实现
2020/05/26 Python
Python环境配置实现pip加速过程解析
2020/11/27 Python
浅析数据存储的三种方式 cookie sessionstorage localstorage 的异同
2020/06/04 HTML / CSS
作为网站管理者应当如何防范XSS
2014/08/16 面试题
怎样在 Applet 中建立自己的菜单(MenuBar/Menu)?
2012/06/20 面试题
精彩的大学生自我评价
2013/11/17 职场文书
仓管员岗位职责范本
2015/04/01 职场文书
运动会闭幕式致辞
2015/07/29 职场文书
医生行业员工的辞职信
2019/06/24 职场文书
PyTorch的Debug指南
2021/05/07 Python
详解Android中的TimePickerView(时间选择器)的用法
2022/04/30 Java/Android