PyTorch搭建一维线性回归模型(二)


Posted in Python onMay 22, 2019

PyTorch基础入门二:PyTorch搭建一维线性回归模型

1)一维线性回归模型的理论基础

给定数据集PyTorch搭建一维线性回归模型(二),线性回归希望能够优化出一个好的函数PyTorch搭建一维线性回归模型(二),使得PyTorch搭建一维线性回归模型(二)能够和PyTorch搭建一维线性回归模型(二)尽可能接近。

如何才能学习到参数PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二)呢?很简单,只需要确定如何衡量PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二)之间的差别,我们一般通过损失函数(Loss Funciton)来衡量:PyTorch搭建一维线性回归模型(二)。取平方是因为距离有正有负,我们于是将它们变为全是正的。这就是著名的均方误差。我们要做的事情就是希望能够找到PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二),使得:

PyTorch搭建一维线性回归模型(二)

PyTorch搭建一维线性回归模型(二)

均方差误差非常直观,也有着很好的几何意义,对应了常用的欧式距离。现在要求解这个连续函数的最小值,我们很自然想到的方法就是求它的偏导数,让它的偏导数等于0来估计它的参数,即:

PyTorch搭建一维线性回归模型(二)

PyTorch搭建一维线性回归模型(二)

求解以上两式,我们就可以得到最优解。

2)代码实现

首先,我们需要“制造”出一些数据集:

import torch
import matplotlib.pyplot as plt
 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
 
# 画图
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

我们想要拟合的一维回归模型是PyTorch搭建一维线性回归模型(二)。上面制造的数据集也是比较接近这个模型的,但是为了达到学习效果,人为地加上了torch.rand()值增加一些干扰。

上面人为制造出来的数据集的分布如下:

PyTorch搭建一维线性回归模型(二)

有了数据,我们就要开始定义我们的模型,这里定义的是一个输入层和输出层都只有一维的模型,并且使用了“先判断后使用”的基本结构来合理使用GPU加速。

class LinearRegression(nn.Module):
  def __init__(self):
    super(LinearRegression, self).__init__()
    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
  def forward(self, x):
    out = self.linear(x)
    return out
 
if torch.cuda.is_available():
  model = LinearRegression().cuda()
else:
  model = LinearRegression()

然后我们定义出损失函数和优化函数,这里使用均方误差作为损失函数,使用梯度下降进行优化:

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

接下来,开始进行模型的训练。

num_epochs = 1000
for epoch in range(num_epochs):
  if torch.cuda.is_available():
    inputs = Variable(x).cuda()
    target = Variable(y).cuda()
  else:
    inputs = Variable(x)
    target = Variable(y)
 
  # 向前传播
  out = model(inputs)
  loss = criterion(out, target)
 
  # 向后传播
  optimizer.zero_grad() # 注意每次迭代都需要清零
  loss.backward()
  optimizer.step()
 
  if (epoch+1) %20 == 0:
    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))

首先定义了迭代的次数,这里为1000次,先向前传播计算出损失函数,然后向后传播计算梯度,这里需要注意的是,每次计算梯度前都要记得将梯度归零,不然梯度会累加到一起造成结果不收敛。为了便于看到结果,每隔一段时间输出当前的迭代轮数和损失函数。

接下来,我们通过model.eval()函数将模型变为测试模式,然后将数据放入模型中进行预测。最后,通过画图工具matplotlib看一下我们拟合的结果,代码如下:

model.eval()
if torch.cuda.is_available():
  predict = model(Variable(x).cuda())
  predict = predict.data.cpu().numpy()
else:
  predict = model(Variable(x))
  predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()

其拟合结果如下图:

PyTorch搭建一维线性回归模型(二)

附上完整代码:

# !/usr/bin/python
# coding: utf8
# @Time  : 2018-07-28 18:40
# @Author : Liam
# @Email  : luyu.real@qq.com
# @Software: PyCharm
#            .::::.
#           .::::::::.
#           :::::::::::
#         ..:::::::::::'
#        '::::::::::::'
#         .::::::::::
#      '::::::::::::::..
#         ..::::::::::::.
#        ``::::::::::::::::
#        ::::``:::::::::'    .:::.
#        ::::'  ':::::'    .::::::::.
#       .::::'   ::::   .:::::::'::::.
#      .:::'    ::::: .:::::::::' ':::::.
#      .::'    :::::.:::::::::'   ':::::.
#     .::'     ::::::::::::::'     ``::::.
#   ...:::      ::::::::::::'       ``::.
#   ```` ':.     ':::::::::'         ::::..
#            '.:::::'          ':'````..
#           美女保佑 永无BUG
 
import torch
from torch.autograd import Variable
import numpy as np
import random
import matplotlib.pyplot as plt
from torch import nn
 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
 
# 画图
# plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show()
class LinearRegression(nn.Module):
  def __init__(self):
    super(LinearRegression, self).__init__()
    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
  def forward(self, x):
    out = self.linear(x)
    return out
 
if torch.cuda.is_available():
  model = LinearRegression().cuda()
else:
  model = LinearRegression()
 
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
  if torch.cuda.is_available():
    inputs = Variable(x).cuda()
    target = Variable(y).cuda()
  else:
    inputs = Variable(x)
    target = Variable(y)
 
  # 向前传播
  out = model(inputs)
  loss = criterion(out, target)
 
  # 向后传播
  optimizer.zero_grad() # 注意每次迭代都需要清零
  loss.backward()
  optimizer.step()
 
  if (epoch+1) %20 == 0:
    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))
model.eval()
if torch.cuda.is_available():
  predict = model(Variable(x).cuda())
  predict = predict.data.cpu().numpy()
else:
  predict = model(Variable(x))
  predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
提升Python程序运行效率的6个方法
Mar 31 Python
Anaconda 离线安装 python 包的操作方法
Jun 11 Python
python2和python3的输入和输出区别介绍
Nov 20 Python
python中实现控制小数点位数的方法
Jan 24 Python
python 列表转为字典的两个小方法(小结)
Jun 28 Python
django页面跳转问题及注意事项
Jul 18 Python
python利用dlib获取人脸的68个landmark
Nov 27 Python
python 实现将Numpy数组保存为图像
Jan 09 Python
Python实现投影法分割图像示例(二)
Jan 17 Python
python绘制封闭多边形教程
Feb 18 Python
解决django接口无法通过ip进行访问的问题
Mar 27 Python
关于jupyter打开之后不能直接跳转到浏览器的解决方式
Apr 13 Python
PyTorch基本数据类型(一)
May 22 #Python
PyTorch搭建多项式回归模型(三)
May 22 #Python
pytorch使用Variable实现线性回归
May 21 #Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
You might like
php方法调用模式与函数调用模式简例
2011/09/20 PHP
详谈PHP文件目录基础操作
2014/11/11 PHP
php去除头尾空格的2种方法
2015/03/16 PHP
php封装的单文件(图片)上传类完整实例
2016/10/18 PHP
PHP用户管理中常用接口调用实例及解析(含源码)
2017/03/09 PHP
Laravel框架Request、Response及Session操作示例
2019/05/06 PHP
jQuery ajax cache缓存问题
2010/07/01 Javascript
使用jquery与图片美化checkbox和radio控件的代码(打包下载)
2010/11/11 Javascript
利用JS自动打开页面上链接的实现代码
2011/09/25 Javascript
jQuery使用一个按钮控制图片的伸缩实现思路
2013/04/19 Javascript
JQuery事件e参数的方法preventDefault()取消默认行为
2013/09/26 Javascript
原生js获取宽高与jquery获取宽高的方法关系对比
2014/04/04 Javascript
JavaScript实现文本框中默认显示背景图片在获得焦点后消失的方法
2015/07/01 Javascript
基于jQuery的网页影音播放器jPlayer的基本使用教程
2016/03/08 Javascript
EasyUI布局 高度自适应
2016/06/04 Javascript
详解使用React进行组件库开发
2018/02/06 Javascript
jQuery+ajax读取json数据并按照价格排序示例
2018/03/28 jQuery
webpack项目使用eslint建立代码规范实现
2019/05/16 Javascript
vue.js 打包时出现空白页和路径错误问题及解决方法
2019/06/26 Javascript
jQuery实现视频展示效果
2020/05/30 jQuery
Vue项目前后端联调(使用proxyTable实现跨域方式)
2020/07/18 Javascript
elementui实现预览图片组件二次封装
2020/12/29 Javascript
[47:38]Optic vs VGJ.S 2018国际邀请赛小组赛BO2 第二场 8.17
2018/08/20 DOTA
python集成开发环境配置(pycharm)
2020/02/14 Python
基于Python的Jenkins的二次开发操作
2020/05/12 Python
Python3.9 beta2版本发布了,看看这7个新的PEP都是什么
2020/06/10 Python
Python通过类的组合模拟街道红绿灯
2020/09/16 Python
哈利波特商店:Harry Potter Shop
2018/11/30 全球购物
客服服务心得体会
2013/12/30 职场文书
职业生涯规划设计步骤
2014/01/12 职场文书
春节晚会主持词
2014/03/24 职场文书
祝寿主持词
2015/07/02 职场文书
写给同事的离职感言
2015/08/04 职场文书
python开发飞机大战游戏
2021/07/15 Python
如何利用 CSS Overview 面板重构优化你的网站
2021/10/24 HTML / CSS
Python内置数据类型中的集合详解
2022/03/18 Python