PyTorch搭建多项式回归模型(三)


Posted in Python onMay 22, 2019

PyTorch基础入门三:PyTorch搭建多项式回归模型 

1)理论简介

对于一般的线性回归模型,由于该函数拟合出来的是一条直线,所以精度欠佳,我们可以考虑多项式回归来拟合更多的模型。所谓多项式回归,其本质也是线性回归。也就是说,我们采取的方法是,提高每个属性的次数来增加维度数。比如,请看下面这样的例子:

如果我们想要拟合方程:

PyTorch搭建多项式回归模型(三)

对于输入变量PyTorch搭建多项式回归模型(三)和输出值PyTorch搭建多项式回归模型(三),我们只需要增加其平方项、三次方项系数即可。所以,我们可以设置如下参数方程:

PyTorch搭建多项式回归模型(三)

可以看到,上述方程与线性回归方程并没有本质区别。所以我们可以采用线性回归的方式来进行多项式的拟合。下面请看代码部分。

2)代码实现

当然最先要做的就是导包了,下面需要说明的只有一个:itertools中的count,这个是用来记数用的,其可以记数到无穷,第一个参数是记数的起始值,第二个参数是步长。其内部实现相当于如下代码:

def count(firstval=0, step=1):
 x = firstval
 while 1:
 yield x
 x += step

下面是导包部分代码,这里定义了一个常量POLY_DEGREE = 3用来指定多项式最高次数。

from itertools import count
import torch
import torch.autograd
import torch.nn.functional as F
 
POLY_DEGREE = 3

然后我们需要将数据处理成矩阵的形式:

PyTorch搭建多项式回归模型(三)

在PyTorch里面使用torch.cat()函数来实现Tensor的拼接:

def make_features(x):
 """Builds features i.e. a matrix with columns [x, x^2, x^3, x^4]."""
 x = x.unsqueeze(1)
 return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)

对于输入的PyTorch搭建多项式回归模型(三)个数据,我们将其扩展成上面矩阵所示的样子。

然后定义出我们需要拟合的多项式,可以随机抽取一个多项式来作为我们的目标多项式。当然,系数PyTorch搭建多项式回归模型(三)和偏置PyTorch搭建多项式回归模型(三)确定了,多项式也就确定了:

W_target = torch.randn(POLY_DEGREE, 1)
b_target = torch.randn(1)
 
def f(x):
 """Approximated function."""
 return x.mm(W_target) + b_target.item()

这里的权重已经定义好了,x.mm(W_target)表示做矩阵乘法,PyTorch搭建多项式回归模型(三)就是每次输入一个PyTorch搭建多项式回归模型(三)得到一个PyTorch搭建多项式回归模型(三)的真实函数。

在训练的时候我们需要采样一些点,可以随机生成一批数据来得到训练集。下面的函数可以让我们每次取batch_size这么多个数据,然后将其转化为矩阵形式,再把这个值通过函数之后的结果也返回作为真实的输出值:

def get_batch(batch_size=32):
 """Builds a batch i.e. (x, f(x)) pair."""
 random = torch.randn(batch_size)
 x = make_features(random)
 y = f(x)
 return x, y

接下来我们需要定义模型,这里采用一种简写的方式定义模型,torch.nn.Linear()表示定义一个线性模型,这里定义了是输入值和目标参数的行数一致(和POLY_DEGREE一致,本次实验中为3),输出值为1的模型。

# Define model
fc = torch.nn.Linear(W_target.size(0), 1)

下面开始训练模型,训练的过程让其不断优化,直到随机取出的batch_size个点中计算出来的均方误差小于0.001为止。

for batch_idx in count(1):
 # Get data
 batch_x, batch_y = get_batch()
 
 # Reset gradients
 fc.zero_grad()
 
 # Forward pass
 output = F.smooth_l1_loss(fc(batch_x), batch_y)
 loss = output.item()
 
 # Backward pass
 output.backward()
 
 # Apply gradients
 for param in fc.parameters():
 param.data.add_(-0.1 * param.grad.data)
 
 # Stop criterion
 if loss < 1e-3:
 break

这样就已经训练出了我们的多项式回归模型,为了方便观察,定义了如下打印函数来打印出我们拟合的多项式表达式:

def poly_desc(W, b):
 """Creates a string description of a polynomial."""
 result = 'y = '
 for i, w in enumerate(W):
 result += '{:+.2f} x^{} '.format(w, len(W) - i)
 result += '{:+.2f}'.format(b[0])
 return result
 
print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))
print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias))
print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target))

程序运行结果如下图所示:

PyTorch搭建多项式回归模型(三)

可以看出,真实的多项式表达式和我们拟合的多项式十分接近。现实世界中很多问题都不是简单的线性回归,涉及到很多复杂的非线性模型。但是我们可以在其特征量上进行研究,改变或者增加其特征,从而将非线性问题转化为线性问题来解决,这种处理问题的思路是我们从多项式回归的算法中应该汲取到的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解python 字符串和日期之间转换 StringAndDate
May 04 Python
Python递归实现汉诺塔算法示例
Mar 19 Python
Django中反向生成models.py的实例讲解
May 30 Python
python计算阶乘和的方法(1!+2!+3!+...+n!)
Feb 01 Python
pandas实现将dataframe满足某一条件的值选出
Jun 12 Python
python实时检测键盘输入函数的示例
Jul 17 Python
pyinstaller打包opencv和numpy程序运行错误解决
Aug 16 Python
python 如何将数据写入本地txt文本文件的实现方法
Sep 11 Python
python对验证码降噪的实现示例代码
Nov 12 Python
python GUI库图形界面开发之PyQt5结合Qt Designer创建信号与槽的详细方法与实例
Mar 08 Python
python实现数字炸弹游戏程序
Jul 17 Python
详解基于Scrapy的IP代理池搭建
Sep 29 Python
pytorch使用Variable实现线性回归
May 21 #Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
python多进程读图提取特征存npy
May 21 #Python
Python中使用pypdf2合并、分割、加密pdf文件的代码详解
May 21 #Python
You might like
php验证手机号码(支持归属地查询及编码为UTF8)
2013/02/01 PHP
一致性哈希算法以及其PHP实现详细解析
2013/08/24 PHP
php socket通信(tcp/udp)实例分析
2016/02/14 PHP
HTTP状态代码以及定义(解释)
2007/02/02 Javascript
Javascript 模式实例 观察者模式
2009/10/24 Javascript
Javascript结合css实现网页换肤功能
2009/11/02 Javascript
javascript 不间断的图片滚动并可点击
2010/01/15 Javascript
实现点击列表弹出列表索引的两种方式
2013/03/08 Javascript
JS中数组Array的用法示例介绍
2014/02/20 Javascript
用jquery模仿的a的title属性的例子
2014/10/22 Javascript
ECharts仪表盘实例代码(附源码下载)
2016/02/18 Javascript
纯JavaScript 实现flappy bird小游戏实例代码
2016/09/27 Javascript
Javascript中arguments对象的详解与使用方法
2016/10/04 Javascript
微信js-sdk界面操作接口用法示例
2016/10/12 Javascript
bootstrap laydate日期组件使用详解
2017/01/04 Javascript
jQuery验证表单格式的使用方法
2017/01/10 Javascript
详解react内联样式使用webpack将px转rem
2018/09/13 Javascript
如何在基于vue-cli的项目自定义打包环境
2018/11/10 Javascript
[45:16]完美世界DOTA2联赛循环赛 IO vs FTD BO2第二场 11.05
2020/11/06 DOTA
[01:06:54]DOTA2-DPC中国联赛 正赛 SAG vs DLG BO3 第二场 2月28日
2021/03/11 DOTA
Python实现Tab自动补全和历史命令管理的方法
2015/03/12 Python
Python制作刷网页流量工具
2017/04/23 Python
Python logging模块用法示例
2018/08/28 Python
[原创]Python入门教程1. 基本运算【四则运算、变量、math模块等】
2018/10/28 Python
python爬虫开发之Beautiful Soup模块从安装到详细使用方法与实例
2020/03/09 Python
TensorFlow keras卷积神经网络 添加L2正则化方式
2020/05/22 Python
Move Free官方海外旗舰店:美国骨关节健康专业品牌
2017/12/06 全球购物
英国Office鞋店德国网站:在线购买鞋子、靴子和运动鞋
2018/12/19 全球购物
英国美发和美容产品商城:HQhair
2019/02/08 全球购物
精选鞋类、服装和配饰的全球领先目的地:Bodega
2021/02/27 全球购物
德国富尔达运动鞋店:43einhalb
2020/12/25 全球购物
师范毕业生求职自荐信
2013/09/25 职场文书
应届生英语教师求职信
2013/11/05 职场文书
红领巾心向党广播稿
2014/01/19 职场文书
小学少先队活动方案
2014/02/18 职场文书
2015年房地产销售工作总结
2015/04/20 职场文书