详解Pytorch 使用Pytorch拟合多项式(多项式回归)


Posted in Python onMay 24, 2018

使用Pytorch来编写神经网络具有很多优势,比起Tensorflow,我认为Pytorch更加简单,结构更加清晰。

希望通过实战几个Pytorch的例子,让大家熟悉Pytorch的使用方法,包括数据集创建,各种网络层结构的定义,以及前向传播与权重更新方式。

比如这里给出

详解Pytorch 使用Pytorch拟合多项式(多项式回归)    

很显然,这里我们只需要假定

详解Pytorch 使用Pytorch拟合多项式(多项式回归)

这里我们只需要设置一个合适尺寸的全连接网络,根据不断迭代,求出最接近的参数即可。

但是这里需要思考一个问题,使用全连接网络结构是毫无疑问的,但是我们的输入与输出格式是什么样的呢?

只将一个x作为输入合理吗?显然是不合理的,因为每一个神经元其实模拟的是wx+b的计算过程,无法模拟幂运算,所以显然我们需要将x,x的平方,x的三次方,x的四次方组合成一个向量作为输入,假设有n个不同的x值,我们就可以将n个组合向量合在一起组成输入矩阵。

这一步代码如下:

def make_features(x): 
 x = x.unsqueeze(1) 
 return torch.cat([x ** i for i in range(1,4)] , 1)

我们需要生成一些随机数作为网络输入:

def get_batch(batch_size=32): 
 random = torch.randn(batch_size) 
 x = make_features(random) 
 '''Compute the actual results''' 
 y = f(x) 
 if torch.cuda.is_available(): 
  return Variable(x).cuda(), Variable(y).cuda() 
 else: 
  return Variable(x), Variable(y)

其中的f(x)定义如下:

w_target = torch.FloatTensor([0.5,3,2.4]).unsqueeze(1) 
b_target = torch.FloatTensor([0.9]) 
 
def f(x): 
 return x.mm(w_target)+b_target[0]

接下来定义模型:

class poly_model(nn.Module): 
 def __init__(self): 
  super(poly_model, self).__init__() 
  self.poly = nn.Linear(3,1) 
 
 def forward(self, x): 
  out = self.poly(x) 
  return out
if torch.cuda.is_available(): 
 model = poly_model().cuda() 
else: 
 model = poly_model()

接下来我们定义损失函数和优化器:

criterion = nn.MSELoss() 
optimizer = optim.SGD(model.parameters(), lr = 1e-3)

网络部件定义完后,开始训练:

epoch = 0 
while True: 
 batch_x,batch_y = get_batch() 
 output = model(batch_x) 
 loss = criterion(output,batch_y) 
 print_loss = loss.data[0] 
 optimizer.zero_grad() 
 loss.backward() 
 optimizer.step() 
 epoch+=1 
 if print_loss < 1e-3: 
  break

到此我们的所有代码就敲完了,接下来我们开始详细了解一下其中的一些代码。

在make_features()定义中,torch.cat是将计算出的向量拼接成矩阵。unsqueeze是作一个维度上的变化。

get_batch中,torch.randn是产生指定维度的随机数,如果你的机器支持GPU加速,可以将Variable放在GPU上进行运算,类似语句含义相通。

x.mm是作矩阵乘法。

模型定义是重中之重,其实当你掌握Pytorch之后,你会发现模型定义是十分简单的,各种基本的层结构都已经为你封装好了。所有的层结构和损失函数都来自torch.nn,所有的模型构建都是从这个基类 nn.Module继承的。模型定义中,__init__与forward是有模板的,大家可以自己体会。

nn.Linear是做一个线性的运算,参数的含义代表了输入层与输出层的结构,即3*1;在训练阶段,有几行是Pytorch不同于别的框架的,首先loss是一个Variable,通过loss.data可以取出一个Tensor,再通过data[0]可以得到一个int或者float类型的值,我们才可以进行基本运算或者显示。每次计算梯度之前,都需要将梯度归零,否则梯度会叠加。个人觉得别的语句还是比较好懂的,如果有疑问可以在下方评论。

下面是我们的拟合结果

详解Pytorch 使用Pytorch拟合多项式(多项式回归)

其实效果肯定会很好,因为只是一个非常简单的全连接网络,希望大家通过这个小例子可以学到Pytorch的一些基本操作。往后我们会继续更新,完整代码请戳,https://github.com/ZhichaoDuan/PytorchCourse

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用BeautifulSoup分页网页中超链接的方法
Apr 04 Python
Python科学画图代码分享
Nov 29 Python
Python内置模块ConfigParser实现配置读写功能的方法
Feb 12 Python
Python使用Selenium+BeautifulSoup爬取淘宝搜索页
Feb 24 Python
Python使用gRPC传输协议教程
Oct 16 Python
python 使用正则表达式按照多个空格分割字符的实例
Dec 20 Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 Python
Pyinstaller 打包发布经验总结
Jun 02 Python
Python使用pyexecjs代码案例解析
Jul 13 Python
Python如何获取文件路径/目录
Sep 22 Python
python基础详解之if循环语句
Apr 24 Python
深入理解Pytorch微调torchvision模型
Nov 11 Python
Python获取系统所有进程PID及进程名称的方法示例
May 24 #Python
好的Python培训机构应该具备哪些条件
May 23 #Python
Python实现的根据IP地址计算子网掩码位数功能示例
May 23 #Python
Python加载带有注释的Json文件实例
May 23 #Python
Python实现判断一行代码是否为注释的方法
May 23 #Python
对python的文件内注释 help注释方法
May 23 #Python
Python基于生成器迭代实现的八皇后问题示例
May 23 #Python
You might like
磨咖啡豆的密诀
2021/03/03 冲泡冲煮
那些年一起学习的PHP(一)
2012/03/21 PHP
php 模拟get_headers函数的代码示例
2013/04/27 PHP
ThinkPHP框架实现session跨域问题的解决方法
2014/07/01 PHP
PHP中构造函数和析构函数解析
2014/10/10 PHP
PHP实现Soap通讯的方法
2014/11/03 PHP
PHP生成短网址方法汇总
2016/07/12 PHP
PHP实现RSA签名生成订单功能【支付宝示例】
2017/06/06 PHP
什么是PHP文件?如何打开PHP文件?
2017/06/27 PHP
gearman中任务的优先级和返回状态实例分析
2020/02/27 PHP
javascript下阻止表单重复提交、防刷新、防后退
2007/08/17 Javascript
js window.event对象详尽解析
2009/02/17 Javascript
页面只能打开一次Cooike如何实现
2012/12/04 Javascript
jquery封装的对话框简单实现
2013/07/21 Javascript
使用CSS和jQuery模拟select并附提交后取得数据的代码
2013/10/18 Javascript
javascript 上下banner替换具体实现
2013/11/14 Javascript
javascript实现youku的视频代码自适应宽度
2015/05/25 Javascript
在AngularJS中使用jQuery的zTree插件的方法
2016/04/21 Javascript
基于BootStrap Metronic开发框架经验小结【五】Bootstrap File Input文件上传插件的用法详解
2016/05/12 Javascript
jQuery提示插件qTip2用法分析(支持ajax及多种样式)
2016/06/08 Javascript
JavaScript中日常收集常见的10种错误(推荐)
2017/01/08 Javascript
jQuery实现别踩白块儿网页版小游戏
2017/01/18 Javascript
JS 使用 window对象的print方法实现分页打印功能
2018/05/16 Javascript
vue调试工具vue-devtools安装及使用方法
2018/11/07 Javascript
[13:38]2015国际邀请赛中国战队出征仪式
2015/05/29 DOTA
在Python中使用NLTK库实现对词干的提取的教程
2015/04/08 Python
Python匹配中文的正则表达式
2016/05/11 Python
使用 Python 实现简单的 switch/case 语句的方法
2018/09/17 Python
window环境pip切换国内源(pip安装异常缓慢的问题)
2019/12/31 Python
Tensorflow训练模型越来越慢的2种解决方案
2020/02/07 Python
房地产融资计划书
2014/01/10 职场文书
铁路工务反思材料
2014/02/07 职场文书
财务学生的职业生涯发展
2014/02/11 职场文书
优秀家长事迹材料
2014/05/17 职场文书
邹越感恩父母演讲稿
2014/08/28 职场文书
pt-archiver 主键自增
2022/04/26 MySQL