详解Pytorch 使用Pytorch拟合多项式(多项式回归)


Posted in Python onMay 24, 2018

使用Pytorch来编写神经网络具有很多优势,比起Tensorflow,我认为Pytorch更加简单,结构更加清晰。

希望通过实战几个Pytorch的例子,让大家熟悉Pytorch的使用方法,包括数据集创建,各种网络层结构的定义,以及前向传播与权重更新方式。

比如这里给出

详解Pytorch 使用Pytorch拟合多项式(多项式回归)    

很显然,这里我们只需要假定

详解Pytorch 使用Pytorch拟合多项式(多项式回归)

这里我们只需要设置一个合适尺寸的全连接网络,根据不断迭代,求出最接近的参数即可。

但是这里需要思考一个问题,使用全连接网络结构是毫无疑问的,但是我们的输入与输出格式是什么样的呢?

只将一个x作为输入合理吗?显然是不合理的,因为每一个神经元其实模拟的是wx+b的计算过程,无法模拟幂运算,所以显然我们需要将x,x的平方,x的三次方,x的四次方组合成一个向量作为输入,假设有n个不同的x值,我们就可以将n个组合向量合在一起组成输入矩阵。

这一步代码如下:

def make_features(x): 
 x = x.unsqueeze(1) 
 return torch.cat([x ** i for i in range(1,4)] , 1)

我们需要生成一些随机数作为网络输入:

def get_batch(batch_size=32): 
 random = torch.randn(batch_size) 
 x = make_features(random) 
 '''Compute the actual results''' 
 y = f(x) 
 if torch.cuda.is_available(): 
  return Variable(x).cuda(), Variable(y).cuda() 
 else: 
  return Variable(x), Variable(y)

其中的f(x)定义如下:

w_target = torch.FloatTensor([0.5,3,2.4]).unsqueeze(1) 
b_target = torch.FloatTensor([0.9]) 
 
def f(x): 
 return x.mm(w_target)+b_target[0]

接下来定义模型:

class poly_model(nn.Module): 
 def __init__(self): 
  super(poly_model, self).__init__() 
  self.poly = nn.Linear(3,1) 
 
 def forward(self, x): 
  out = self.poly(x) 
  return out
if torch.cuda.is_available(): 
 model = poly_model().cuda() 
else: 
 model = poly_model()

接下来我们定义损失函数和优化器:

criterion = nn.MSELoss() 
optimizer = optim.SGD(model.parameters(), lr = 1e-3)

网络部件定义完后,开始训练:

epoch = 0 
while True: 
 batch_x,batch_y = get_batch() 
 output = model(batch_x) 
 loss = criterion(output,batch_y) 
 print_loss = loss.data[0] 
 optimizer.zero_grad() 
 loss.backward() 
 optimizer.step() 
 epoch+=1 
 if print_loss < 1e-3: 
  break

到此我们的所有代码就敲完了,接下来我们开始详细了解一下其中的一些代码。

在make_features()定义中,torch.cat是将计算出的向量拼接成矩阵。unsqueeze是作一个维度上的变化。

get_batch中,torch.randn是产生指定维度的随机数,如果你的机器支持GPU加速,可以将Variable放在GPU上进行运算,类似语句含义相通。

x.mm是作矩阵乘法。

模型定义是重中之重,其实当你掌握Pytorch之后,你会发现模型定义是十分简单的,各种基本的层结构都已经为你封装好了。所有的层结构和损失函数都来自torch.nn,所有的模型构建都是从这个基类 nn.Module继承的。模型定义中,__init__与forward是有模板的,大家可以自己体会。

nn.Linear是做一个线性的运算,参数的含义代表了输入层与输出层的结构,即3*1;在训练阶段,有几行是Pytorch不同于别的框架的,首先loss是一个Variable,通过loss.data可以取出一个Tensor,再通过data[0]可以得到一个int或者float类型的值,我们才可以进行基本运算或者显示。每次计算梯度之前,都需要将梯度归零,否则梯度会叠加。个人觉得别的语句还是比较好懂的,如果有疑问可以在下方评论。

下面是我们的拟合结果

详解Pytorch 使用Pytorch拟合多项式(多项式回归)

其实效果肯定会很好,因为只是一个非常简单的全连接网络,希望大家通过这个小例子可以学到Pytorch的一些基本操作。往后我们会继续更新,完整代码请戳,https://github.com/ZhichaoDuan/PytorchCourse

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python写的贪吃蛇游戏例子
Jun 16 Python
python根据文件大小打log日志
Oct 09 Python
Python2.x版本中maketrans()方法的使用介绍
May 19 Python
CentOS 7下Python 2.7升级至Python3.6.1的实战教程
Jul 06 Python
详解Python用户登录接口的方法
Apr 17 Python
Python直接赋值、浅拷贝与深度拷贝实例分析
Jun 18 Python
Python单元测试工具doctest和unittest使用解析
Sep 02 Python
Pyorch之numpy与torch之间相互转换方式
Dec 31 Python
python和pywin32实现窗口查找、遍历和点击的示例代码
Apr 01 Python
python 已知平行四边形三个点,求第四个点的案例
Apr 12 Python
三步解决python PermissionError: [WinError 5]拒绝访问的情况
Apr 22 Python
细说NumPy数组的四种乘法的使用
Dec 18 Python
Python获取系统所有进程PID及进程名称的方法示例
May 24 #Python
好的Python培训机构应该具备哪些条件
May 23 #Python
Python实现的根据IP地址计算子网掩码位数功能示例
May 23 #Python
Python加载带有注释的Json文件实例
May 23 #Python
Python实现判断一行代码是否为注释的方法
May 23 #Python
对python的文件内注释 help注释方法
May 23 #Python
Python基于生成器迭代实现的八皇后问题示例
May 23 #Python
You might like
DSP接收机前端设想
2021/03/02 无线电
php格式化金额函数分享
2015/02/02 PHP
thinkphp3.2.3 分页代码分享
2016/07/28 PHP
&amp;lt;script defer&amp;gt; defer 是什么意思
2009/05/10 Javascript
JQuery的Validation插件中Remote验证的中文问题
2010/07/26 Javascript
纯文字版返回顶端的js代码
2013/08/01 Javascript
javascript制作坦克大战全纪录(1)
2014/11/27 Javascript
JavaScript中的console.trace()函数介绍
2014/12/29 Javascript
JS实现为表格动态添加标题的方法
2015/03/31 Javascript
AngularJS在IE8的不支持的解决方法
2016/05/13 Javascript
Web开发中客户端的跳转与服务器端的跳转的区别
2017/03/05 Javascript
jQuery插件FusionWidgets实现的Cylinder图效果示例【附demo源码】
2017/03/23 jQuery
详解node-ccap模块生成captcha验证码
2017/07/01 Javascript
jQuery修改DOM结构_动力节点Java学院整理
2017/07/05 jQuery
AngularJS与BootStrap模仿百度分页的示例代码
2018/05/23 Javascript
JavaScript设计模式之缓存代理模式原理与简单用法示例
2018/08/07 Javascript
详解package.json版本号规则
2019/08/01 Javascript
Python通过websocket与js客户端通信示例分析
2014/06/25 Python
Python实现统计英文单词个数及字符串分割代码
2015/05/28 Python
Using Django with GAE Python 后台抓取多个网站的页面全文
2016/02/17 Python
python绘制热力图heatmap
2020/03/23 Python
Python跳出多重循环的方法示例
2019/07/03 Python
把vgg-face.mat权重迁移到pytorch模型示例
2019/12/27 Python
Python实现的北京积分落户数据分析示例
2020/03/27 Python
django日志默认打印request请求信息的方法示例
2020/05/17 Python
PyTorch如何搭建一个简单的网络
2020/08/24 Python
超酷炫 CSS3垂直手风琴菜单
2016/06/28 HTML / CSS
自我鉴定怎么写
2014/01/12 职场文书
工作睡觉检讨书
2014/02/25 职场文书
青年志愿者活动总结
2014/04/26 职场文书
护士长2014年终工作总结
2014/11/11 职场文书
大学生团员个人总结
2015/02/14 职场文书
担保书范文
2019/07/09 职场文书
2019年警察入党转正申请书最新范文
2019/09/03 职场文书
微信小程序实现聊天室功能
2021/06/14 Javascript
python APScheduler执行定时任务介绍
2022/04/19 Python