PyTorch搭建一维线性回归模型(二)


Posted in Python onMay 22, 2019

PyTorch基础入门二:PyTorch搭建一维线性回归模型

1)一维线性回归模型的理论基础

给定数据集PyTorch搭建一维线性回归模型(二),线性回归希望能够优化出一个好的函数PyTorch搭建一维线性回归模型(二),使得PyTorch搭建一维线性回归模型(二)能够和PyTorch搭建一维线性回归模型(二)尽可能接近。

如何才能学习到参数PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二)呢?很简单,只需要确定如何衡量PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二)之间的差别,我们一般通过损失函数(Loss Funciton)来衡量:PyTorch搭建一维线性回归模型(二)。取平方是因为距离有正有负,我们于是将它们变为全是正的。这就是著名的均方误差。我们要做的事情就是希望能够找到PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二),使得:

PyTorch搭建一维线性回归模型(二)

PyTorch搭建一维线性回归模型(二)

均方差误差非常直观,也有着很好的几何意义,对应了常用的欧式距离。现在要求解这个连续函数的最小值,我们很自然想到的方法就是求它的偏导数,让它的偏导数等于0来估计它的参数,即:

PyTorch搭建一维线性回归模型(二)

PyTorch搭建一维线性回归模型(二)

求解以上两式,我们就可以得到最优解。

2)代码实现

首先,我们需要“制造”出一些数据集:

import torch
import matplotlib.pyplot as plt
 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
 
# 画图
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

我们想要拟合的一维回归模型是PyTorch搭建一维线性回归模型(二)。上面制造的数据集也是比较接近这个模型的,但是为了达到学习效果,人为地加上了torch.rand()值增加一些干扰。

上面人为制造出来的数据集的分布如下:

PyTorch搭建一维线性回归模型(二)

有了数据,我们就要开始定义我们的模型,这里定义的是一个输入层和输出层都只有一维的模型,并且使用了“先判断后使用”的基本结构来合理使用GPU加速。

class LinearRegression(nn.Module):
  def __init__(self):
    super(LinearRegression, self).__init__()
    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
  def forward(self, x):
    out = self.linear(x)
    return out
 
if torch.cuda.is_available():
  model = LinearRegression().cuda()
else:
  model = LinearRegression()

然后我们定义出损失函数和优化函数,这里使用均方误差作为损失函数,使用梯度下降进行优化:

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

接下来,开始进行模型的训练。

num_epochs = 1000
for epoch in range(num_epochs):
  if torch.cuda.is_available():
    inputs = Variable(x).cuda()
    target = Variable(y).cuda()
  else:
    inputs = Variable(x)
    target = Variable(y)
 
  # 向前传播
  out = model(inputs)
  loss = criterion(out, target)
 
  # 向后传播
  optimizer.zero_grad() # 注意每次迭代都需要清零
  loss.backward()
  optimizer.step()
 
  if (epoch+1) %20 == 0:
    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))

首先定义了迭代的次数,这里为1000次,先向前传播计算出损失函数,然后向后传播计算梯度,这里需要注意的是,每次计算梯度前都要记得将梯度归零,不然梯度会累加到一起造成结果不收敛。为了便于看到结果,每隔一段时间输出当前的迭代轮数和损失函数。

接下来,我们通过model.eval()函数将模型变为测试模式,然后将数据放入模型中进行预测。最后,通过画图工具matplotlib看一下我们拟合的结果,代码如下:

model.eval()
if torch.cuda.is_available():
  predict = model(Variable(x).cuda())
  predict = predict.data.cpu().numpy()
else:
  predict = model(Variable(x))
  predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()

其拟合结果如下图:

PyTorch搭建一维线性回归模型(二)

附上完整代码:

# !/usr/bin/python
# coding: utf8
# @Time  : 2018-07-28 18:40
# @Author : Liam
# @Email  : luyu.real@qq.com
# @Software: PyCharm
#            .::::.
#           .::::::::.
#           :::::::::::
#         ..:::::::::::'
#        '::::::::::::'
#         .::::::::::
#      '::::::::::::::..
#         ..::::::::::::.
#        ``::::::::::::::::
#        ::::``:::::::::'    .:::.
#        ::::'  ':::::'    .::::::::.
#       .::::'   ::::   .:::::::'::::.
#      .:::'    ::::: .:::::::::' ':::::.
#      .::'    :::::.:::::::::'   ':::::.
#     .::'     ::::::::::::::'     ``::::.
#   ...:::      ::::::::::::'       ``::.
#   ```` ':.     ':::::::::'         ::::..
#            '.:::::'          ':'````..
#           美女保佑 永无BUG
 
import torch
from torch.autograd import Variable
import numpy as np
import random
import matplotlib.pyplot as plt
from torch import nn
 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
 
# 画图
# plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show()
class LinearRegression(nn.Module):
  def __init__(self):
    super(LinearRegression, self).__init__()
    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
  def forward(self, x):
    out = self.linear(x)
    return out
 
if torch.cuda.is_available():
  model = LinearRegression().cuda()
else:
  model = LinearRegression()
 
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
  if torch.cuda.is_available():
    inputs = Variable(x).cuda()
    target = Variable(y).cuda()
  else:
    inputs = Variable(x)
    target = Variable(y)
 
  # 向前传播
  out = model(inputs)
  loss = criterion(out, target)
 
  # 向后传播
  optimizer.zero_grad() # 注意每次迭代都需要清零
  loss.backward()
  optimizer.step()
 
  if (epoch+1) %20 == 0:
    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))
model.eval()
if torch.cuda.is_available():
  predict = model(Variable(x).cuda())
  predict = predict.data.cpu().numpy()
else:
  predict = model(Variable(x))
  predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django实现自定义404,500页面教程
Mar 26 Python
Python中用字符串调用函数或方法示例代码
Aug 04 Python
python爬虫之urllib3的使用示例
Jul 09 Python
Pandas DataFrame 取一行数据会得到Series的方法
Nov 10 Python
Python编程flask使用页面模版的方法
Dec 28 Python
Python制作exe文件简单流程
Jan 24 Python
int在python中的含义以及用法
Jun 27 Python
Django打印出在数据库中执行的语句问题
Jul 25 Python
Python爬虫实现使用beautifulSoup4爬取名言网功能案例
Sep 15 Python
手动安装python3.6的操作过程详解
Jan 13 Python
keras .h5转移动端的.tflite文件实现方式
May 25 Python
python爬取股票最新数据并用excel绘制树状图的示例
Mar 01 Python
PyTorch基本数据类型(一)
May 22 #Python
PyTorch搭建多项式回归模型(三)
May 22 #Python
pytorch使用Variable实现线性回归
May 21 #Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
You might like
基于curl数据采集之单页面采集函数get_html的使用
2013/04/28 PHP
基于PHP CURL获取邮箱地址的详解
2013/06/03 PHP
PHP面向对象程序设计高级特性详解(接口,继承,抽象类,析构,克隆等)
2016/12/02 PHP
详解PHP5.6.30与Apache2.4.x配置
2017/06/02 PHP
cnblogs 代码高亮显示后的代码复制问题解决实现代码
2011/12/14 Javascript
JQuery 中几个类选择器的简单使用介绍
2013/03/14 Javascript
『jQuery』名称冲突使用noConflict方法解决
2013/04/22 Javascript
捕获浏览器关闭、刷新事件不同情况下的处理方法
2013/06/02 Javascript
JS实现匀速运动的代码实例
2013/11/29 Javascript
JS字符串拼接在ie中都报错的解决方法
2014/03/27 Javascript
原生javascript实现获取指定元素下所有后代元素的方法
2014/10/28 Javascript
兼容Firefox的Javascript XSLT 处理XML文件
2014/12/31 Javascript
JS动画效果打开、关闭层的实现方法
2015/05/09 Javascript
jQuery实现表单动态添加数据并提交的方法
2018/07/19 jQuery
Vue-Quill-Editor富文本编辑器的使用教程
2018/09/21 Javascript
Vue中错误图片的处理的实现代码
2019/11/07 Javascript
JS实现的进制转换,浮点数相加,数字判断操作示例
2019/11/09 Javascript
JavaScript实现无限轮播效果
2020/11/19 Javascript
[01:03:13]VG vs Pain 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
python生成指定长度的随机数密码
2014/01/23 Python
决策树的python实现方法
2014/11/18 Python
python实现网站的模拟登录
2016/01/04 Python
利用Tkinter(python3.6)实现一个简单计算器
2017/12/21 Python
python中的数据结构比较
2019/05/13 Python
Python循环中else,break和continue的用法实例详解
2019/07/11 Python
python 实现生成均匀分布的点
2019/12/05 Python
Python处理PDF与CDF实例
2020/02/26 Python
python 使用递归的方式实现语义图片分割功能
2020/07/16 Python
Python变量及数据类型用法原理汇总
2020/08/06 Python
基于HTML5+Webkit实现树叶飘落动画
2017/12/28 HTML / CSS
双立人美国官方商店:ZWILLING集团餐具和炊具
2020/05/07 全球购物
C#实现对任一张表的数据进行增,删,改,查要求,运用Webservice,体现出三层架构
2014/07/11 面试题
计算机专业自我鉴定
2013/10/15 职场文书
学生自我评价范文
2014/02/02 职场文书
小学毕业典礼主持词
2014/03/27 职场文书
预防艾滋病宣传活动总结
2015/05/09 职场文书