详解Pytorch 使用Pytorch拟合多项式(多项式回归)


Posted in Python onMay 24, 2018

使用Pytorch来编写神经网络具有很多优势,比起Tensorflow,我认为Pytorch更加简单,结构更加清晰。

希望通过实战几个Pytorch的例子,让大家熟悉Pytorch的使用方法,包括数据集创建,各种网络层结构的定义,以及前向传播与权重更新方式。

比如这里给出

详解Pytorch 使用Pytorch拟合多项式(多项式回归)    

很显然,这里我们只需要假定

详解Pytorch 使用Pytorch拟合多项式(多项式回归)

这里我们只需要设置一个合适尺寸的全连接网络,根据不断迭代,求出最接近的参数即可。

但是这里需要思考一个问题,使用全连接网络结构是毫无疑问的,但是我们的输入与输出格式是什么样的呢?

只将一个x作为输入合理吗?显然是不合理的,因为每一个神经元其实模拟的是wx+b的计算过程,无法模拟幂运算,所以显然我们需要将x,x的平方,x的三次方,x的四次方组合成一个向量作为输入,假设有n个不同的x值,我们就可以将n个组合向量合在一起组成输入矩阵。

这一步代码如下:

def make_features(x): 
 x = x.unsqueeze(1) 
 return torch.cat([x ** i for i in range(1,4)] , 1)

我们需要生成一些随机数作为网络输入:

def get_batch(batch_size=32): 
 random = torch.randn(batch_size) 
 x = make_features(random) 
 '''Compute the actual results''' 
 y = f(x) 
 if torch.cuda.is_available(): 
  return Variable(x).cuda(), Variable(y).cuda() 
 else: 
  return Variable(x), Variable(y)

其中的f(x)定义如下:

w_target = torch.FloatTensor([0.5,3,2.4]).unsqueeze(1) 
b_target = torch.FloatTensor([0.9]) 
 
def f(x): 
 return x.mm(w_target)+b_target[0]

接下来定义模型:

class poly_model(nn.Module): 
 def __init__(self): 
  super(poly_model, self).__init__() 
  self.poly = nn.Linear(3,1) 
 
 def forward(self, x): 
  out = self.poly(x) 
  return out
if torch.cuda.is_available(): 
 model = poly_model().cuda() 
else: 
 model = poly_model()

接下来我们定义损失函数和优化器:

criterion = nn.MSELoss() 
optimizer = optim.SGD(model.parameters(), lr = 1e-3)

网络部件定义完后,开始训练:

epoch = 0 
while True: 
 batch_x,batch_y = get_batch() 
 output = model(batch_x) 
 loss = criterion(output,batch_y) 
 print_loss = loss.data[0] 
 optimizer.zero_grad() 
 loss.backward() 
 optimizer.step() 
 epoch+=1 
 if print_loss < 1e-3: 
  break

到此我们的所有代码就敲完了,接下来我们开始详细了解一下其中的一些代码。

在make_features()定义中,torch.cat是将计算出的向量拼接成矩阵。unsqueeze是作一个维度上的变化。

get_batch中,torch.randn是产生指定维度的随机数,如果你的机器支持GPU加速,可以将Variable放在GPU上进行运算,类似语句含义相通。

x.mm是作矩阵乘法。

模型定义是重中之重,其实当你掌握Pytorch之后,你会发现模型定义是十分简单的,各种基本的层结构都已经为你封装好了。所有的层结构和损失函数都来自torch.nn,所有的模型构建都是从这个基类 nn.Module继承的。模型定义中,__init__与forward是有模板的,大家可以自己体会。

nn.Linear是做一个线性的运算,参数的含义代表了输入层与输出层的结构,即3*1;在训练阶段,有几行是Pytorch不同于别的框架的,首先loss是一个Variable,通过loss.data可以取出一个Tensor,再通过data[0]可以得到一个int或者float类型的值,我们才可以进行基本运算或者显示。每次计算梯度之前,都需要将梯度归零,否则梯度会叠加。个人觉得别的语句还是比较好懂的,如果有疑问可以在下方评论。

下面是我们的拟合结果

详解Pytorch 使用Pytorch拟合多项式(多项式回归)

其实效果肯定会很好,因为只是一个非常简单的全连接网络,希望大家通过这个小例子可以学到Pytorch的一些基本操作。往后我们会继续更新,完整代码请戳,https://github.com/ZhichaoDuan/PytorchCourse

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python追加元素到列表的方法
Jul 28 Python
python 创建弹出式菜单的实现代码
Jul 11 Python
Python3 操作符重载方法示例
Nov 23 Python
Python解决N阶台阶走法问题的方法分析
Dec 28 Python
教你用一行Python代码实现并行任务(附代码)
Feb 02 Python
几种实用的pythonic语法实例代码
Feb 24 Python
对Pyhon实现静态变量全局变量的方法详解
Jan 11 Python
Python实现12306火车票抢票系统
Jul 04 Python
python之array赋值技巧分享
Nov 28 Python
Python基于Dlib的人脸识别系统的实现
Feb 26 Python
Python读取excel文件中带公式的值的实现
Apr 17 Python
Django框架安装及项目创建过程解析
Sep 14 Python
Python获取系统所有进程PID及进程名称的方法示例
May 24 #Python
好的Python培训机构应该具备哪些条件
May 23 #Python
Python实现的根据IP地址计算子网掩码位数功能示例
May 23 #Python
Python加载带有注释的Json文件实例
May 23 #Python
Python实现判断一行代码是否为注释的方法
May 23 #Python
对python的文件内注释 help注释方法
May 23 #Python
Python基于生成器迭代实现的八皇后问题示例
May 23 #Python
You might like
透析PHP的配置文件php.ini
2006/10/09 PHP
Thinkphp实现MySQL读写分离操作示例
2014/06/25 PHP
php微信公众开发之获取周边酒店信息的方法
2014/12/22 PHP
php通过执行CutyCapt命令实现网页截图的方法
2016/09/30 PHP
PHP数组操作实例分析【添加,删除,计算,反转,排序,查找等】
2016/12/24 PHP
CI(CodeIgniter)框架实现图片上传的方法
2017/03/24 PHP
20款超赞的jQuery插件 Web开发人员必备
2011/02/26 Javascript
捕获键盘事件(且兼容各浏览器)
2013/07/03 Javascript
jQuery实现点击小图片淡入淡出显示大图片特效
2015/09/09 Javascript
浅析jquery如何判断滚动条滚到页面底部并执行事件
2016/04/29 Javascript
JavaScript结合Bootstrap仿微信后台多图文界面管理
2016/07/22 Javascript
ES6所改良的javascript“缺陷”问题
2016/08/23 Javascript
Vue2递归组件实现树形菜单
2017/04/10 Javascript
Javascript实现数组中的元素上下移动
2017/04/28 Javascript
详解vue移动端日期选择组件
2018/02/22 Javascript
js序列化和反序列化的使用讲解
2019/01/19 Javascript
ES6中字符串的使用方法扩展
2019/06/04 Javascript
layui清除radio的选中状态实例
2019/11/14 Javascript
Vue如何循环提取对象数组中的值
2020/11/18 Vue.js
[08:04]TI4西雅图DOTA2前线报道 海涛探访各路人马
2014/07/09 DOTA
python使用chardet判断字符串编码的方法
2015/03/13 Python
Python实现简单过滤文本段的方法
2017/05/24 Python
一道python走迷宫算法题
2018/01/22 Python
Python实现的维尼吉亚密码算法示例
2018/04/12 Python
超简单使用Python换脸实例
2019/03/27 Python
详解Python 实现 ZeroMQ 的三种基本工作模式
2020/03/24 Python
CSS实现定位元素居中的方法
2015/06/23 HTML / CSS
如何使用html5与css3完成google涂鸦动画
2012/12/16 HTML / CSS
一站式跨境收款解决方案:Payoneer(派安盈)
2018/09/06 全球购物
俄罗斯名牌服装网上商店:UNIQUE FABRIC
2019/07/25 全球购物
Python里面如何实现tuple和list的转换
2012/06/13 面试题
奥林匹克的口号
2014/06/13 职场文书
慈善捐赠倡议书
2014/08/30 职场文书
邓小平理论心得体会
2014/09/09 职场文书
初中历史教学反思
2016/02/19 职场文书
MongoDB数据库常用的10条操作命令
2021/06/18 MongoDB