PyTorch搭建一维线性回归模型(二)


Posted in Python onMay 22, 2019

PyTorch基础入门二:PyTorch搭建一维线性回归模型

1)一维线性回归模型的理论基础

给定数据集PyTorch搭建一维线性回归模型(二),线性回归希望能够优化出一个好的函数PyTorch搭建一维线性回归模型(二),使得PyTorch搭建一维线性回归模型(二)能够和PyTorch搭建一维线性回归模型(二)尽可能接近。

如何才能学习到参数PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二)呢?很简单,只需要确定如何衡量PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二)之间的差别,我们一般通过损失函数(Loss Funciton)来衡量:PyTorch搭建一维线性回归模型(二)。取平方是因为距离有正有负,我们于是将它们变为全是正的。这就是著名的均方误差。我们要做的事情就是希望能够找到PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二),使得:

PyTorch搭建一维线性回归模型(二)

PyTorch搭建一维线性回归模型(二)

均方差误差非常直观,也有着很好的几何意义,对应了常用的欧式距离。现在要求解这个连续函数的最小值,我们很自然想到的方法就是求它的偏导数,让它的偏导数等于0来估计它的参数,即:

PyTorch搭建一维线性回归模型(二)

PyTorch搭建一维线性回归模型(二)

求解以上两式,我们就可以得到最优解。

2)代码实现

首先,我们需要“制造”出一些数据集:

import torch
import matplotlib.pyplot as plt
 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
 
# 画图
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

我们想要拟合的一维回归模型是PyTorch搭建一维线性回归模型(二)。上面制造的数据集也是比较接近这个模型的,但是为了达到学习效果,人为地加上了torch.rand()值增加一些干扰。

上面人为制造出来的数据集的分布如下:

PyTorch搭建一维线性回归模型(二)

有了数据,我们就要开始定义我们的模型,这里定义的是一个输入层和输出层都只有一维的模型,并且使用了“先判断后使用”的基本结构来合理使用GPU加速。

class LinearRegression(nn.Module):
  def __init__(self):
    super(LinearRegression, self).__init__()
    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
  def forward(self, x):
    out = self.linear(x)
    return out
 
if torch.cuda.is_available():
  model = LinearRegression().cuda()
else:
  model = LinearRegression()

然后我们定义出损失函数和优化函数,这里使用均方误差作为损失函数,使用梯度下降进行优化:

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

接下来,开始进行模型的训练。

num_epochs = 1000
for epoch in range(num_epochs):
  if torch.cuda.is_available():
    inputs = Variable(x).cuda()
    target = Variable(y).cuda()
  else:
    inputs = Variable(x)
    target = Variable(y)
 
  # 向前传播
  out = model(inputs)
  loss = criterion(out, target)
 
  # 向后传播
  optimizer.zero_grad() # 注意每次迭代都需要清零
  loss.backward()
  optimizer.step()
 
  if (epoch+1) %20 == 0:
    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))

首先定义了迭代的次数,这里为1000次,先向前传播计算出损失函数,然后向后传播计算梯度,这里需要注意的是,每次计算梯度前都要记得将梯度归零,不然梯度会累加到一起造成结果不收敛。为了便于看到结果,每隔一段时间输出当前的迭代轮数和损失函数。

接下来,我们通过model.eval()函数将模型变为测试模式,然后将数据放入模型中进行预测。最后,通过画图工具matplotlib看一下我们拟合的结果,代码如下:

model.eval()
if torch.cuda.is_available():
  predict = model(Variable(x).cuda())
  predict = predict.data.cpu().numpy()
else:
  predict = model(Variable(x))
  predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()

其拟合结果如下图:

PyTorch搭建一维线性回归模型(二)

附上完整代码:

# !/usr/bin/python
# coding: utf8
# @Time  : 2018-07-28 18:40
# @Author : Liam
# @Email  : luyu.real@qq.com
# @Software: PyCharm
#            .::::.
#           .::::::::.
#           :::::::::::
#         ..:::::::::::'
#        '::::::::::::'
#         .::::::::::
#      '::::::::::::::..
#         ..::::::::::::.
#        ``::::::::::::::::
#        ::::``:::::::::'    .:::.
#        ::::'  ':::::'    .::::::::.
#       .::::'   ::::   .:::::::'::::.
#      .:::'    ::::: .:::::::::' ':::::.
#      .::'    :::::.:::::::::'   ':::::.
#     .::'     ::::::::::::::'     ``::::.
#   ...:::      ::::::::::::'       ``::.
#   ```` ':.     ':::::::::'         ::::..
#            '.:::::'          ':'````..
#           美女保佑 永无BUG
 
import torch
from torch.autograd import Variable
import numpy as np
import random
import matplotlib.pyplot as plt
from torch import nn
 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
 
# 画图
# plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show()
class LinearRegression(nn.Module):
  def __init__(self):
    super(LinearRegression, self).__init__()
    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
  def forward(self, x):
    out = self.linear(x)
    return out
 
if torch.cuda.is_available():
  model = LinearRegression().cuda()
else:
  model = LinearRegression()
 
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
  if torch.cuda.is_available():
    inputs = Variable(x).cuda()
    target = Variable(y).cuda()
  else:
    inputs = Variable(x)
    target = Variable(y)
 
  # 向前传播
  out = model(inputs)
  loss = criterion(out, target)
 
  # 向后传播
  optimizer.zero_grad() # 注意每次迭代都需要清零
  loss.backward()
  optimizer.step()
 
  if (epoch+1) %20 == 0:
    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))
model.eval()
if torch.cuda.is_available():
  predict = model(Variable(x).cuda())
  predict = predict.data.cpu().numpy()
else:
  predict = model(Variable(x))
  predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现将html表格转换成CSV文件的方法
Jun 28 Python
Python的Django框架中的表单处理示例
Jul 17 Python
python操作mysql代码总结
Jun 01 Python
对python制作自己的数据集实例讲解
Dec 12 Python
CentOS6.9 Python环境配置(python2.7、pip、virtualenv)
May 06 Python
Python 20行简单实现有道在线翻译的详解
May 15 Python
Python OpenCV之图片缩放的实现(cv2.resize)
Jun 28 Python
python 浅谈serial与stm32通信的编码问题
Dec 18 Python
Python爬虫库requests获取响应内容、响应状态码、响应头
Jan 25 Python
python实现sm2和sm4国密(国家商用密码)算法的示例
Sep 26 Python
python requests库的使用
Jan 06 Python
opencv深入浅出了解机器学习和深度学习
Mar 17 Python
PyTorch基本数据类型(一)
May 22 #Python
PyTorch搭建多项式回归模型(三)
May 22 #Python
pytorch使用Variable实现线性回归
May 21 #Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
You might like
关于PHP语言构造器介绍
2013/07/08 PHP
php使用curl发送json格式数据实例
2013/12/17 PHP
WordPress开发中的get_post_custom()函数使用解析
2016/01/04 PHP
Yii2 assets清除缓存的方法
2016/05/16 PHP
既简单又安全的PHP验证码 附调用方法
2016/06/02 PHP
Laravel框架实现model层的增删改查(CURD)操作示例
2018/05/12 PHP
js操作label给label赋值及取label的值示例
2013/11/07 Javascript
nodeType属性返回被选节点的节点类型介绍
2013/11/22 Javascript
详解JavaScript中shift()方法的使用
2015/06/09 Javascript
js实现带农历和八字等信息的日历特效
2016/05/16 Javascript
jQuery实现页面下拉100像素出现悬浮窗口的方法
2016/09/05 Javascript
vuejs如何配置less
2017/04/25 Javascript
vue系列之动态路由详解【原创】
2017/09/10 Javascript
JS中的两种数据类型及实现引用类型的深拷贝的方法
2018/08/12 Javascript
VUE 直接通过JS 修改html对象的值导致没有更新到数据中解决方法分析
2019/12/02 Javascript
Node.js设置定时任务之node-schedule模块的使用详解
2020/04/28 Javascript
vue 扩展现有组件的操作
2020/08/14 Javascript
Python的collections模块中namedtuple结构使用示例
2016/07/07 Python
python算法演练_One Rule 算法(详解)
2017/05/17 Python
pandas数据预处理之dataframe的groupby操作方法
2018/04/13 Python
python实现定时提取实时日志程序
2018/06/22 Python
对Python函数设计规范详解
2019/07/19 Python
pandas通过字典生成dataframe的方法步骤
2019/07/23 Python
python点击鼠标获取坐标(Graphics)
2019/08/10 Python
Python模块汇总(常用第三方库)
2019/10/07 Python
python实现一个猜拳游戏
2020/04/05 Python
Python分析微信好友性别比例和省份城市分布比例的方法示例【基于itchat模块】
2020/05/29 Python
为什么称python为胶水语言
2020/06/16 Python
简单整理HTML5的基本特性和语法
2016/02/18 HTML / CSS
关于canvas.toDataURL 在iOS运行失败的问题解决
2020/09/16 HTML / CSS
施华洛世奇巴西官网:SWAROVSKI巴西
2019/12/03 全球购物
五月的鲜花活动方案
2014/08/21 职场文书
后进生评语大全
2015/01/04 职场文书
课外活动实习计划
2015/01/19 职场文书
煤矿隐患排查制度
2015/08/05 职场文书
Django REST framework 限流功能的使用
2021/06/24 Python