PyTorch搭建一维线性回归模型(二)


Posted in Python onMay 22, 2019

PyTorch基础入门二:PyTorch搭建一维线性回归模型

1)一维线性回归模型的理论基础

给定数据集PyTorch搭建一维线性回归模型(二),线性回归希望能够优化出一个好的函数PyTorch搭建一维线性回归模型(二),使得PyTorch搭建一维线性回归模型(二)能够和PyTorch搭建一维线性回归模型(二)尽可能接近。

如何才能学习到参数PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二)呢?很简单,只需要确定如何衡量PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二)之间的差别,我们一般通过损失函数(Loss Funciton)来衡量:PyTorch搭建一维线性回归模型(二)。取平方是因为距离有正有负,我们于是将它们变为全是正的。这就是著名的均方误差。我们要做的事情就是希望能够找到PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二),使得:

PyTorch搭建一维线性回归模型(二)

PyTorch搭建一维线性回归模型(二)

均方差误差非常直观,也有着很好的几何意义,对应了常用的欧式距离。现在要求解这个连续函数的最小值,我们很自然想到的方法就是求它的偏导数,让它的偏导数等于0来估计它的参数,即:

PyTorch搭建一维线性回归模型(二)

PyTorch搭建一维线性回归模型(二)

求解以上两式,我们就可以得到最优解。

2)代码实现

首先,我们需要“制造”出一些数据集:

import torch
import matplotlib.pyplot as plt
 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
 
# 画图
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

我们想要拟合的一维回归模型是PyTorch搭建一维线性回归模型(二)。上面制造的数据集也是比较接近这个模型的,但是为了达到学习效果,人为地加上了torch.rand()值增加一些干扰。

上面人为制造出来的数据集的分布如下:

PyTorch搭建一维线性回归模型(二)

有了数据,我们就要开始定义我们的模型,这里定义的是一个输入层和输出层都只有一维的模型,并且使用了“先判断后使用”的基本结构来合理使用GPU加速。

class LinearRegression(nn.Module):
  def __init__(self):
    super(LinearRegression, self).__init__()
    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
  def forward(self, x):
    out = self.linear(x)
    return out
 
if torch.cuda.is_available():
  model = LinearRegression().cuda()
else:
  model = LinearRegression()

然后我们定义出损失函数和优化函数,这里使用均方误差作为损失函数,使用梯度下降进行优化:

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

接下来,开始进行模型的训练。

num_epochs = 1000
for epoch in range(num_epochs):
  if torch.cuda.is_available():
    inputs = Variable(x).cuda()
    target = Variable(y).cuda()
  else:
    inputs = Variable(x)
    target = Variable(y)
 
  # 向前传播
  out = model(inputs)
  loss = criterion(out, target)
 
  # 向后传播
  optimizer.zero_grad() # 注意每次迭代都需要清零
  loss.backward()
  optimizer.step()
 
  if (epoch+1) %20 == 0:
    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))

首先定义了迭代的次数,这里为1000次,先向前传播计算出损失函数,然后向后传播计算梯度,这里需要注意的是,每次计算梯度前都要记得将梯度归零,不然梯度会累加到一起造成结果不收敛。为了便于看到结果,每隔一段时间输出当前的迭代轮数和损失函数。

接下来,我们通过model.eval()函数将模型变为测试模式,然后将数据放入模型中进行预测。最后,通过画图工具matplotlib看一下我们拟合的结果,代码如下:

model.eval()
if torch.cuda.is_available():
  predict = model(Variable(x).cuda())
  predict = predict.data.cpu().numpy()
else:
  predict = model(Variable(x))
  predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()

其拟合结果如下图:

PyTorch搭建一维线性回归模型(二)

附上完整代码:

# !/usr/bin/python
# coding: utf8
# @Time  : 2018-07-28 18:40
# @Author : Liam
# @Email  : luyu.real@qq.com
# @Software: PyCharm
#            .::::.
#           .::::::::.
#           :::::::::::
#         ..:::::::::::'
#        '::::::::::::'
#         .::::::::::
#      '::::::::::::::..
#         ..::::::::::::.
#        ``::::::::::::::::
#        ::::``:::::::::'    .:::.
#        ::::'  ':::::'    .::::::::.
#       .::::'   ::::   .:::::::'::::.
#      .:::'    ::::: .:::::::::' ':::::.
#      .::'    :::::.:::::::::'   ':::::.
#     .::'     ::::::::::::::'     ``::::.
#   ...:::      ::::::::::::'       ``::.
#   ```` ':.     ':::::::::'         ::::..
#            '.:::::'          ':'````..
#           美女保佑 永无BUG
 
import torch
from torch.autograd import Variable
import numpy as np
import random
import matplotlib.pyplot as plt
from torch import nn
 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
 
# 画图
# plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show()
class LinearRegression(nn.Module):
  def __init__(self):
    super(LinearRegression, self).__init__()
    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
  def forward(self, x):
    out = self.linear(x)
    return out
 
if torch.cuda.is_available():
  model = LinearRegression().cuda()
else:
  model = LinearRegression()
 
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
  if torch.cuda.is_available():
    inputs = Variable(x).cuda()
    target = Variable(y).cuda()
  else:
    inputs = Variable(x)
    target = Variable(y)
 
  # 向前传播
  out = model(inputs)
  loss = criterion(out, target)
 
  # 向后传播
  optimizer.zero_grad() # 注意每次迭代都需要清零
  loss.backward()
  optimizer.step()
 
  if (epoch+1) %20 == 0:
    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))
model.eval()
if torch.cuda.is_available():
  predict = model(Variable(x).cuda())
  predict = predict.data.cpu().numpy()
else:
  predict = model(Variable(x))
  predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用urllib模块和pyquery实现阿里巴巴排名查询
Jan 16 Python
10种检测Python程序运行时间、CPU和内存占用的方法
Apr 01 Python
Python中的lstrip()方法使用简介
May 19 Python
Python两个内置函数 locals 和globals(学习笔记)
Aug 28 Python
Python迭代和迭代器详解
Nov 10 Python
浅谈Python实现Apriori算法介绍
Dec 20 Python
Django添加KindEditor富文本编辑器的使用
Oct 24 Python
python3使用matplotlib绘制散点图
Mar 19 Python
python中update的基本使用方法详解
Jul 17 Python
Python 实现数组相减示例
Dec 27 Python
python实现交并比IOU教程
Apr 16 Python
python 利用panda 实现列联表(交叉表)
Feb 06 Python
PyTorch基本数据类型(一)
May 22 #Python
PyTorch搭建多项式回归模型(三)
May 22 #Python
pytorch使用Variable实现线性回归
May 21 #Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
You might like
2020显卡排行榜天梯图 显卡天梯图2020年3月最新版
2020/04/02 数码科技
如何使用PHP获取网络上文件
2006/10/09 PHP
执行、获取远程代码返回:file_get_contents 超时处理的问题详解
2013/06/25 PHP
一个完整的PHP类包含的七种语法说明
2015/06/04 PHP
把JS与CSS写在同一个文件里的书写方法
2007/06/02 Javascript
基于jQuery实现复选框的全选 全不选 反选功能
2014/11/24 Javascript
jQuery中的jQuery()方法用法分析
2014/12/27 Javascript
javascript模拟评分控件实现方法
2015/05/13 Javascript
详解JS中Array对象扩展与String对象扩展
2016/01/07 Javascript
浏览器环境下JavaScript脚本加载与执行探析之动态脚本与Ajax脚本注入
2016/01/19 Javascript
JavaScript中const、var和let区别浅析
2016/10/11 Javascript
浅谈jquery上下滑动的注意事项
2016/10/13 Javascript
深入理解选择框脚本[推荐]
2016/12/13 Javascript
js实现表格筛选功能
2017/01/18 Javascript
vue.js的手脚架vue-cli项目搭建的步骤
2017/08/30 Javascript
用vuex写了一个购物车H5页面的示例代码
2018/12/04 Javascript
详解微信小程序实现跑马灯效果(附完整代码)
2019/04/29 Javascript
uni-app之APP和小程序微信授权方法
2019/05/09 Javascript
vue全屏事件开发详解
2020/06/17 Javascript
[11:33]DAC2018 4.5SOLO赛决赛 MidOne vs Paparazi第二场
2018/04/06 DOTA
[01:38]完美世界DOTA2联赛(PWL)宣传片:第一站
2020/10/26 DOTA
python实现逻辑回归的方法示例
2017/05/02 Python
Python实现的FTP通信客户端与服务器端功能示例
2018/03/28 Python
Python实现模拟浏览器请求及会话保持操作示例
2018/07/30 Python
Python使用pydub库对mp3与wav格式进行互转的方法
2019/01/10 Python
解决pycharm下os.system执行命令返回有中文乱码的问题
2019/07/07 Python
通过PHP与Python代码对比的语法差异详解
2019/07/10 Python
python中web框架的自定义创建
2019/09/08 Python
python实现对列表中的元素进行倒序打印
2019/11/23 Python
PHP面试题及答案一
2012/06/18 面试题
请编程遍历页面上所有 TextBox 控件并给它赋值为 string.Empty
2015/12/03 面试题
安全演讲稿大全
2014/05/09 职场文书
写给同学的新学期寄语
2015/02/27 职场文书
python中Matplotlib绘制直线的实例代码
2021/07/04 Python
Python音乐爬虫完美绕过反爬
2021/08/30 Python
Java Redisson多策略注解限流
2022/09/23 Java/Android