详解Pytorch 使用Pytorch拟合多项式(多项式回归)


Posted in Python onMay 24, 2018

使用Pytorch来编写神经网络具有很多优势,比起Tensorflow,我认为Pytorch更加简单,结构更加清晰。

希望通过实战几个Pytorch的例子,让大家熟悉Pytorch的使用方法,包括数据集创建,各种网络层结构的定义,以及前向传播与权重更新方式。

比如这里给出

详解Pytorch 使用Pytorch拟合多项式(多项式回归)    

很显然,这里我们只需要假定

详解Pytorch 使用Pytorch拟合多项式(多项式回归)

这里我们只需要设置一个合适尺寸的全连接网络,根据不断迭代,求出最接近的参数即可。

但是这里需要思考一个问题,使用全连接网络结构是毫无疑问的,但是我们的输入与输出格式是什么样的呢?

只将一个x作为输入合理吗?显然是不合理的,因为每一个神经元其实模拟的是wx+b的计算过程,无法模拟幂运算,所以显然我们需要将x,x的平方,x的三次方,x的四次方组合成一个向量作为输入,假设有n个不同的x值,我们就可以将n个组合向量合在一起组成输入矩阵。

这一步代码如下:

def make_features(x): 
 x = x.unsqueeze(1) 
 return torch.cat([x ** i for i in range(1,4)] , 1)

我们需要生成一些随机数作为网络输入:

def get_batch(batch_size=32): 
 random = torch.randn(batch_size) 
 x = make_features(random) 
 '''Compute the actual results''' 
 y = f(x) 
 if torch.cuda.is_available(): 
  return Variable(x).cuda(), Variable(y).cuda() 
 else: 
  return Variable(x), Variable(y)

其中的f(x)定义如下:

w_target = torch.FloatTensor([0.5,3,2.4]).unsqueeze(1) 
b_target = torch.FloatTensor([0.9]) 
 
def f(x): 
 return x.mm(w_target)+b_target[0]

接下来定义模型:

class poly_model(nn.Module): 
 def __init__(self): 
  super(poly_model, self).__init__() 
  self.poly = nn.Linear(3,1) 
 
 def forward(self, x): 
  out = self.poly(x) 
  return out
if torch.cuda.is_available(): 
 model = poly_model().cuda() 
else: 
 model = poly_model()

接下来我们定义损失函数和优化器:

criterion = nn.MSELoss() 
optimizer = optim.SGD(model.parameters(), lr = 1e-3)

网络部件定义完后,开始训练:

epoch = 0 
while True: 
 batch_x,batch_y = get_batch() 
 output = model(batch_x) 
 loss = criterion(output,batch_y) 
 print_loss = loss.data[0] 
 optimizer.zero_grad() 
 loss.backward() 
 optimizer.step() 
 epoch+=1 
 if print_loss < 1e-3: 
  break

到此我们的所有代码就敲完了,接下来我们开始详细了解一下其中的一些代码。

在make_features()定义中,torch.cat是将计算出的向量拼接成矩阵。unsqueeze是作一个维度上的变化。

get_batch中,torch.randn是产生指定维度的随机数,如果你的机器支持GPU加速,可以将Variable放在GPU上进行运算,类似语句含义相通。

x.mm是作矩阵乘法。

模型定义是重中之重,其实当你掌握Pytorch之后,你会发现模型定义是十分简单的,各种基本的层结构都已经为你封装好了。所有的层结构和损失函数都来自torch.nn,所有的模型构建都是从这个基类 nn.Module继承的。模型定义中,__init__与forward是有模板的,大家可以自己体会。

nn.Linear是做一个线性的运算,参数的含义代表了输入层与输出层的结构,即3*1;在训练阶段,有几行是Pytorch不同于别的框架的,首先loss是一个Variable,通过loss.data可以取出一个Tensor,再通过data[0]可以得到一个int或者float类型的值,我们才可以进行基本运算或者显示。每次计算梯度之前,都需要将梯度归零,否则梯度会叠加。个人觉得别的语句还是比较好懂的,如果有疑问可以在下方评论。

下面是我们的拟合结果

详解Pytorch 使用Pytorch拟合多项式(多项式回归)

其实效果肯定会很好,因为只是一个非常简单的全连接网络,希望大家通过这个小例子可以学到Pytorch的一些基本操作。往后我们会继续更新,完整代码请戳,https://github.com/ZhichaoDuan/PytorchCourse

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python的tkinter布局之简单的聊天窗口实现方法
Sep 03 Python
python使用socket进行简单网络连接的方法
Apr 29 Python
Python简明入门教程
Aug 04 Python
Python Requests安装与简单运用
Apr 07 Python
Python多进程multiprocessing用法实例分析
Aug 18 Python
python特性语法之遍历、公共方法、引用
Aug 08 Python
Puppeteer使用示例详解
Jun 20 Python
Python 进程之间共享数据(全局变量)的方法
Jul 16 Python
如何基于Python Matplotlib实现网格动画
Jul 20 Python
matplotlib基础绘图命令之bar的使用方法
Aug 13 Python
用Python提取PDF表格的方法
Apr 11 Python
教你怎么用Python实现GIF动图的提取及合成
Jun 15 Python
Python获取系统所有进程PID及进程名称的方法示例
May 24 #Python
好的Python培训机构应该具备哪些条件
May 23 #Python
Python实现的根据IP地址计算子网掩码位数功能示例
May 23 #Python
Python加载带有注释的Json文件实例
May 23 #Python
Python实现判断一行代码是否为注释的方法
May 23 #Python
对python的文件内注释 help注释方法
May 23 #Python
Python基于生成器迭代实现的八皇后问题示例
May 23 #Python
You might like
PHP定时自动生成静态HTML的实现代码
2010/06/20 PHP
PHPMailer邮件发送的实现代码
2013/05/04 PHP
thinkPHP下ueditor的使用方法详解
2015/12/26 PHP
php socket通信(tcp/udp)实例分析
2016/02/14 PHP
php版银联支付接口开发简明教程
2016/10/14 PHP
注释PHP和html混合代码的小技巧(分享)
2016/11/03 PHP
Dojo 学习笔记入门篇 First Dojo Example
2009/11/15 Javascript
javascript动态加载二
2012/08/22 Javascript
jquery offset函数应用实例
2012/11/14 Javascript
利用JS实现浏览器的title闪烁
2013/07/08 Javascript
js中reverse函数的用法详解
2013/12/26 Javascript
JavaScript中setUTCMilliseconds()方法的使用详解
2015/06/12 Javascript
javascript实现五星评价代码(源码下载)
2015/08/11 Javascript
JS实现无限级网页折叠菜单(类似树形菜单)效果代码
2015/09/17 Javascript
前端面试知识点锦集(JavaScript篇)
2016/12/28 Javascript
JavaScript判断浏览器及其版本信息
2017/01/20 Javascript
微信小程序promsie.all和promise顺序执行
2017/10/27 Javascript
three.js实现3D影院的原理的代码分析
2017/12/18 Javascript
vue-router懒加载速度缓慢问题及解决方法
2018/11/25 Javascript
jQuery实现的3D版图片轮播示例【滑动轮播】
2019/01/18 jQuery
如何利用vue+vue-router+elementUI实现简易通讯录
2019/05/13 Javascript
Javascript新手入门之字符串拼接与变量的应用
2020/12/03 Javascript
pyqt4教程之messagebox使用示例分享
2014/03/07 Python
详解小白之KMP算法及python实现
2019/04/04 Python
python pytest进阶之conftest.py详解
2019/06/27 Python
Django 请求Request的具体使用方法
2019/11/11 Python
Python turtle库绘制菱形的3种方式小结
2019/11/23 Python
selenium与xpath之获取指定位置的元素的实现
2021/01/26 Python
Ratchet 模态框的实现
2020/08/19 HTML / CSS
For Art’s Sake官网:手工制作的奢华眼镜
2018/12/15 全球购物
祖国在我心中演讲稿300字
2014/05/04 职场文书
煤矿安全演讲稿
2014/05/09 职场文书
党的群众路线教育实践活动通讯稿
2014/09/10 职场文书
主题班会开场白
2015/06/01 职场文书
MySQL5.7并行复制原理及实现
2021/06/03 MySQL
入门学习Go的基本语法
2021/07/07 Golang