使用Pytorch来拟合函数方式


Posted in Python onJanuary 14, 2020

其实各大深度学习框架背后的原理都可以理解为拟合一个参数数量特别庞大的函数,所以各框架都能用来拟合任意函数,Pytorch也能。

在这篇博客中,就以拟合y = ax + b为例(a和b为需要拟合的参数),说明在Pytorch中如何拟合一个函数。

一、定义拟合网络

1、观察普通的神经网络的优化流程

# 定义网络
net = ...
# 定义优化器
optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=0.0005)
# 定义损失函数
loss_op = torch.nn.MSELoss(reduction='sum')
# 优化
for step, (inputs, tag) in enumerate(dataset_loader):
 # 向前传播
 outputs = net(inputs)
 # 计算损失
 loss = loss_op(tag, outputs)
 # 清空梯度
 optimizer.zero_grad()
 # 向后传播
 loss.backward()
 # 更新梯度
 optimizer.step()

上面的代码就是一般情况下的流程。为了能使用Pytorch内置的优化器,所以我们需要定义一个一个网络,实现函数parameters(返回需要优化的参数)和forward(向前传播);为了能支持GPU优化,还需要实现cuda和cpu两个函数,把参数从内存复制到GPU上和从GPU复制回内存。

基于以上要求,网络的定义就类似于:

class Net:
  def __init__(self):
    # 在这里定义要求的参数
    pass

  def cuda(self):
    # 传输参数到GPU
    pass

  def cpu(self):
    # 把参数传输回内存
    pass

  def forward(self, inputs):
   # 实现向前传播,就是根据输入inputs计算一遍输出
    pass

  def parameters(self):
   # 返回参数
    pass

在拟合数据量很大时,还可以使用GPU来加速;如果没有英伟达显卡,则可以不实现cuda和cpu这两个函数。

2、初始化网络

回顾本文目的,拟合: y = ax + b, 所以在__init__函数中就需要定义a和b两个参数,另外为了实现parameters、cpu和cuda,还需要定义属性__parameters和__gpu:

def __init__(self):
    # y = a*x + b
    self.a = torch.rand(1, requires_grad=True) # 参数a
    self.b = torch.rand(1, requires_grad=True) # 参数b
    self.__parameters = dict(a=self.a, b=self.b) # 参数字典
    self.___gpu = False # 是否使用gpu来拟合

要拟合的参数,不能初始化为0! ,一般使用随机值即可。还需要把requires_grad参数设置为True,这是为了支持向后传播。

3、实现向前传播

def forward(self, inputs):
    return self.a * inputs + self.b

非常的简单,就是根据输入inputs计算一遍输出,在本例中,就是计算一下 y = ax + b。计算完了要记得返回计算的结果。

4、把参数传送到GPU

为了支持GPU来加速拟合,需要把参数传输到GPU,且需要更新参数字典__parameters:

def cuda(self):
    if not self.___gpu:
      self.a = self.a.cuda().detach().requires_grad_(True) # 把a传输到gpu
      self.b = self.b.cuda().detach().requires_grad_(True) # 把b传输到gpu
      self.__parameters = dict(a=self.a, b=self.b) # 更新参数
      self.___gpu = True # 更新标志,表示参数已经传输到gpu了
    # 返回self,以支持链式调用
    return self

参数a和b,都是先调用detach再调用requires_grad_,是为了避免错误raise ValueError("can't optimize a non-leaf Tensor")(参考:ValueError: can't optimize a non-leaf Tensor?)。

4、把参数传输回内存

类似于cuda函数,不做过多解释。

def cpu(self):
    if self.___gpu:
      self.a = self.a.cpu().detach().requires_grad_(True)
      self.b = self.b.cpu().detach().requires_grad_(True)
      self.__parameters = dict(a=self.a, b=self.b)
      self.___gpu = False
    return self

5、返回网络参数

为了能使用Pytorch内置的优化器,就要实现parameters函数,观察Pytorch里面的实现:

def parameters(self, recurse=True):
    r"""...
    """
    for name, param in self.named_parameters(recurse=recurse):
      yield param

实际上就是使用yield返回网络的所有参数,因此本例中的实现如下:

def parameters(self):
    for name, param in self.__parameters.items():
      yield param

完整的实现将会放在后面。

二、测试

1、生成测试数据

def main():
  # 生成虚假数据
  x = np.linspace(1, 50, 50)
  # 系数a、b
  a = 2
  b = 1
  # 生成y
  y = a * x + b
  # 转换为Tensor
  x = torch.from_numpy(x.astype(np.float32))
  y = torch.from_numpy(y.astype(np.float32))

2、定义网络

# 定义网络
  net = Net()
  # 定义优化器
  optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=0.0005)
  # 定义损失函数
  loss_op = torch.nn.MSELoss(reduction='sum')

3、把数据传输到GPU(可选)

# 传输到GPU
  if torch.cuda.is_available():
    x = x.cuda()
    y = y.cuda()
    net = net.cuda()

4、定义优化器和损失函数

如果要使用GPU加速,优化器必须要在网络的参数传输到GPU之后在定义,否则优化器里的参数还是内存里的那些参数,传到GPU里面的参数不能被更新。 可以根据代码来理解这句话。

# 定义优化器
  optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=0.0005)
  # 定义损失函数
  loss_op = torch.nn.MSELoss(reduction='sum')

5、拟合(也是优化)

# 最多优化20001次
  for i in range(1, 20001, 1):
   # 向前传播
    out = net.forward(x)
 # 计算损失
    loss = loss_op(y, out)
 # 清空梯度(非常重要)
    optimizer.zero_grad()
 # 向后传播,计算梯度
    loss.backward()
 # 更新参数
    optimizer.step()
 # 得到损失的numpy值
    loss_numpy = loss.cpu().detach().numpy()
    if i % 1000 == 0: # 每1000次打印一下损失
      print(i, loss_numpy)

    if loss_numpy < 0.00001: # 如果损失小于0.00001
     # 打印参数
     a = net.a.cpu().detach().numpy()
     b = net.b.cpu().detach().numpy()
      print(a, b)
      # 退出
      exit()

6、完整示例代码

# coding=utf-8
from __future__ import absolute_import, division, print_function
import torch
import numpy as np


class Net:
  def __init__(self):
    # y = a*x + b
    self.a = torch.rand(1, requires_grad=True) # 参数a
    self.b = torch.rand(1, requires_grad=True) # 参数b
    self.__parameters = dict(a=self.a, b=self.b) # 参数字典
    self.___gpu = False # 是否使用gpu来拟合

  def cuda(self):
    if not self.___gpu:
      self.a = self.a.cuda().detach().requires_grad_(True) # 把a传输到gpu
      self.b = self.b.cuda().detach().requires_grad_(True) # 把b传输到gpu
      self.__parameters = dict(a=self.a, b=self.b) # 更新参数
      self.___gpu = True # 更新标志,表示参数已经传输到gpu了
    # 返回self,以支持链式调用
    return self

  def cpu(self):
    if self.___gpu:
      self.a = self.a.cpu().detach().requires_grad_(True)
      self.b = self.b.cpu().detach().requires_grad_(True)
      self.__parameters = dict(a=self.a, b=self.b) # 更新参数
      self.___gpu = False
    return self

  def forward(self, inputs):
    return self.a * inputs + self.b

  def parameters(self):
    for name, param in self.__parameters.items():
      yield param


def main():

  # 生成虚假数据
  x = np.linspace(1, 50, 50)

  # 系数a、b
  a = 2
  b = 1

  # 生成y
  y = a * x + b

  # 转换为Tensor
  x = torch.from_numpy(x.astype(np.float32))
  y = torch.from_numpy(y.astype(np.float32))

  # 定义网络
  net = Net()

  # 传输到GPU
  if torch.cuda.is_available():
    x = x.cuda()
    y = y.cuda()
    net = net.cuda()

  # 定义优化器
  optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=0.0005)

  # 定义损失函数
  loss_op = torch.nn.MSELoss(reduction='sum')

  # 最多优化20001次
  for i in range(1, 20001, 1):
    # 向前传播
    out = net.forward(x)
    # 计算损失
    loss = loss_op(y, out)
    # 清空梯度(非常重要)
    optimizer.zero_grad()
    # 向后传播,计算梯度
    loss.backward()
    # 更新参数
    optimizer.step()
    # 得到损失的numpy值
    loss_numpy = loss.cpu().detach().numpy()
    if i % 1000 == 0: # 每1000次打印一下损失
      print(i, loss_numpy)

    if loss_numpy < 0.00001: # 如果损失小于0.00001
      # 打印参数
      a = net.a.cpu().detach().numpy()
      b = net.b.cpu().detach().numpy()
      print(a, b)
      # 退出
      exit()


if __name__ == '__main__':
  main()

以上这篇使用Pytorch来拟合函数方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
50行代码实现贪吃蛇(具体思路及代码)
Apr 27 Python
在Apache服务器上同时运行多个Django程序的方法
Jul 22 Python
python实现将内容分行输出
Nov 05 Python
python使用matplotlib绘制折线图教程
Feb 08 Python
python range()函数取反序遍历sequence的方法
Jun 25 Python
Python实现FTP文件传输的实例
Jul 07 Python
Python一键安装全部依赖包的方法
Aug 12 Python
Django中自定义模型管理器(Manager)及方法
Sep 23 Python
使用Python画出小人发射爱心的代码
Nov 23 Python
Python2与Python3的区别点整理
Dec 12 Python
PyQt5 界面显示无响应的实现
Mar 26 Python
关于python的缩进规则的知识点详解
Jun 22 Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
Pytorch 实现focal_loss 多类别和二分类示例
Jan 14 #Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
基于pytorch的lstm参数使用详解
Jan 14 #Python
You might like
邮箱正则表达式实现代码(针对php)
2013/06/21 PHP
php生成shtml类用法实例
2014/12/09 PHP
Laravel中七个非常有用但很少人知道的Carbon方法
2017/09/21 PHP
基于swoole实现多人聊天室
2018/06/14 PHP
javascript判断ie浏览器6/7版本加载不同样式表的实现代码
2011/12/26 Javascript
jquery解决图片路径不存在执行替换路径
2013/02/06 Javascript
HTML页面登录时的JS验证方法
2014/05/28 Javascript
javascript 回调函数详解
2014/11/11 Javascript
jQuery实现平滑滚动到指定锚点的方法
2015/03/20 Javascript
JavaScript和HTML DOM的区别与联系及Javascript和DOM的关系
2015/11/15 Javascript
javascript的几种写法总结
2016/09/30 Javascript
常用的js方法合集
2017/03/10 Javascript
详解AngularJs HTTP响应拦截器实现登陆、权限校验
2017/04/11 Javascript
深入理解js 中async 函数的含义和用法
2018/05/13 Javascript
详解javascript设计模式三:代理模式
2019/03/25 Javascript
JS实现简单贪吃蛇小游戏
2020/10/28 Javascript
vant时间控件使用方法详解
2020/12/24 Javascript
Python中格式化format()方法详解
2017/04/01 Python
利用python微信库itchat实现微信自动回复功能
2017/05/18 Python
python事件驱动event实现详解
2018/11/21 Python
Python中的heapq模块源码详析
2019/01/08 Python
python函数map()和partial()的知识点总结
2020/05/26 Python
使用Html5、CSS实现文字阴影效果
2018/01/17 HTML / CSS
香港优质食材和美酒专门店:FoodWise
2017/09/01 全球购物
三个Unix的命令面试题
2015/04/12 面试题
一道Delphi面试题
2016/10/28 面试题
大学生优秀团员事迹材料
2014/01/30 职场文书
技术比武方案
2014/05/19 职场文书
关于随地扔垃圾的检讨书
2014/09/30 职场文书
2014班子成员自我剖析材料思想汇报
2014/10/01 职场文书
2014年纳税评估工作总结
2014/12/23 职场文书
银行员工考核评语
2014/12/31 职场文书
2014年终个人总结报告
2015/03/09 职场文书
MySQL之DML语言
2021/04/05 MySQL
vue实现移动端div拖动效果
2022/03/03 Vue.js
Win Server2016远程桌面如何允许多用户同时登录
2022/06/10 Servers