PyTorch搭建一维线性回归模型(二)


Posted in Python onMay 22, 2019

PyTorch基础入门二:PyTorch搭建一维线性回归模型

1)一维线性回归模型的理论基础

给定数据集PyTorch搭建一维线性回归模型(二),线性回归希望能够优化出一个好的函数PyTorch搭建一维线性回归模型(二),使得PyTorch搭建一维线性回归模型(二)能够和PyTorch搭建一维线性回归模型(二)尽可能接近。

如何才能学习到参数PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二)呢?很简单,只需要确定如何衡量PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二)之间的差别,我们一般通过损失函数(Loss Funciton)来衡量:PyTorch搭建一维线性回归模型(二)。取平方是因为距离有正有负,我们于是将它们变为全是正的。这就是著名的均方误差。我们要做的事情就是希望能够找到PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二),使得:

PyTorch搭建一维线性回归模型(二)

PyTorch搭建一维线性回归模型(二)

均方差误差非常直观,也有着很好的几何意义,对应了常用的欧式距离。现在要求解这个连续函数的最小值,我们很自然想到的方法就是求它的偏导数,让它的偏导数等于0来估计它的参数,即:

PyTorch搭建一维线性回归模型(二)

PyTorch搭建一维线性回归模型(二)

求解以上两式,我们就可以得到最优解。

2)代码实现

首先,我们需要“制造”出一些数据集:

import torch
import matplotlib.pyplot as plt
 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
 
# 画图
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

我们想要拟合的一维回归模型是PyTorch搭建一维线性回归模型(二)。上面制造的数据集也是比较接近这个模型的,但是为了达到学习效果,人为地加上了torch.rand()值增加一些干扰。

上面人为制造出来的数据集的分布如下:

PyTorch搭建一维线性回归模型(二)

有了数据,我们就要开始定义我们的模型,这里定义的是一个输入层和输出层都只有一维的模型,并且使用了“先判断后使用”的基本结构来合理使用GPU加速。

class LinearRegression(nn.Module):
  def __init__(self):
    super(LinearRegression, self).__init__()
    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
  def forward(self, x):
    out = self.linear(x)
    return out
 
if torch.cuda.is_available():
  model = LinearRegression().cuda()
else:
  model = LinearRegression()

然后我们定义出损失函数和优化函数,这里使用均方误差作为损失函数,使用梯度下降进行优化:

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

接下来,开始进行模型的训练。

num_epochs = 1000
for epoch in range(num_epochs):
  if torch.cuda.is_available():
    inputs = Variable(x).cuda()
    target = Variable(y).cuda()
  else:
    inputs = Variable(x)
    target = Variable(y)
 
  # 向前传播
  out = model(inputs)
  loss = criterion(out, target)
 
  # 向后传播
  optimizer.zero_grad() # 注意每次迭代都需要清零
  loss.backward()
  optimizer.step()
 
  if (epoch+1) %20 == 0:
    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))

首先定义了迭代的次数,这里为1000次,先向前传播计算出损失函数,然后向后传播计算梯度,这里需要注意的是,每次计算梯度前都要记得将梯度归零,不然梯度会累加到一起造成结果不收敛。为了便于看到结果,每隔一段时间输出当前的迭代轮数和损失函数。

接下来,我们通过model.eval()函数将模型变为测试模式,然后将数据放入模型中进行预测。最后,通过画图工具matplotlib看一下我们拟合的结果,代码如下:

model.eval()
if torch.cuda.is_available():
  predict = model(Variable(x).cuda())
  predict = predict.data.cpu().numpy()
else:
  predict = model(Variable(x))
  predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()

其拟合结果如下图:

PyTorch搭建一维线性回归模型(二)

附上完整代码:

# !/usr/bin/python
# coding: utf8
# @Time  : 2018-07-28 18:40
# @Author : Liam
# @Email  : luyu.real@qq.com
# @Software: PyCharm
#            .::::.
#           .::::::::.
#           :::::::::::
#         ..:::::::::::'
#        '::::::::::::'
#         .::::::::::
#      '::::::::::::::..
#         ..::::::::::::.
#        ``::::::::::::::::
#        ::::``:::::::::'    .:::.
#        ::::'  ':::::'    .::::::::.
#       .::::'   ::::   .:::::::'::::.
#      .:::'    ::::: .:::::::::' ':::::.
#      .::'    :::::.:::::::::'   ':::::.
#     .::'     ::::::::::::::'     ``::::.
#   ...:::      ::::::::::::'       ``::.
#   ```` ':.     ':::::::::'         ::::..
#            '.:::::'          ':'````..
#           美女保佑 永无BUG
 
import torch
from torch.autograd import Variable
import numpy as np
import random
import matplotlib.pyplot as plt
from torch import nn
 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
 
# 画图
# plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show()
class LinearRegression(nn.Module):
  def __init__(self):
    super(LinearRegression, self).__init__()
    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
  def forward(self, x):
    out = self.linear(x)
    return out
 
if torch.cuda.is_available():
  model = LinearRegression().cuda()
else:
  model = LinearRegression()
 
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
  if torch.cuda.is_available():
    inputs = Variable(x).cuda()
    target = Variable(y).cuda()
  else:
    inputs = Variable(x)
    target = Variable(y)
 
  # 向前传播
  out = model(inputs)
  loss = criterion(out, target)
 
  # 向后传播
  optimizer.zero_grad() # 注意每次迭代都需要清零
  loss.backward()
  optimizer.step()
 
  if (epoch+1) %20 == 0:
    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))
model.eval()
if torch.cuda.is_available():
  predict = model(Variable(x).cuda())
  predict = predict.data.cpu().numpy()
else:
  predict = model(Variable(x))
  predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python函数any()和all()的用法及区别介绍
Sep 14 Python
浅谈python脚本设置运行参数的方法
Dec 03 Python
python实现自动解数独小程序
Jan 21 Python
Python列表常见操作详解(获取,增加,删除,修改,排序等)
Feb 18 Python
基于python 微信小程序之获取已存在模板消息列表
Aug 05 Python
Python 共享变量加锁、释放详解
Aug 28 Python
Python序列类型的打包和解包实例
Dec 21 Python
python使用paramiko实现ssh的功能详解
Mar 06 Python
Window版下在Jupyter中编写TensorFlow的环境搭建
Apr 10 Python
15个Pythonic的代码示例(值得收藏)
Oct 29 Python
python两种获取剪贴板内容的方法
Nov 06 Python
Python之字符串的遍历的4种方式
Dec 08 Python
PyTorch基本数据类型(一)
May 22 #Python
PyTorch搭建多项式回归模型(三)
May 22 #Python
pytorch使用Variable实现线性回归
May 21 #Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
You might like
比较discuz和ecshop的截取字符串函数php版
2012/09/03 PHP
PHP连接MySQL查询结果中文显示乱码解决方法
2013/10/25 PHP
在php和MySql中计算时间差的方法详解
2015/03/27 PHP
php resizeimage 部分jpg文件 生成缩略图失败的原因分析及解决办法
2016/03/23 PHP
学习PHP Cookie处理函数
2016/08/09 PHP
基于swoole实现多人聊天室
2018/06/14 PHP
PHP PDOStatement::nextRowset讲解
2019/02/01 PHP
showModelessDialog()使用详解
2006/09/07 Javascript
JavaScript 计算当天是本年本月的第几周
2009/03/22 Javascript
JS 控制小数位数的实现代码
2011/08/02 Javascript
javascript工具库代码
2012/03/29 Javascript
jQuery中last()方法用法实例
2015/01/06 Javascript
jquery判断复选框是否被选中的方法
2015/10/16 Javascript
jQuery实现的网页右下角tab样式在线客服效果代码
2015/10/23 Javascript
JavaScript实现多种排序算法
2016/02/24 Javascript
jQuery解析与处理服务器端返回xml格式数据的方法详解
2016/07/04 Javascript
AngularJS 表达式详解及实例代码
2016/09/14 Javascript
JavaScript中利用for循环遍历数组
2017/01/15 Javascript
bootstrap常用组件之头部导航实现代码
2017/04/20 Javascript
js 取消页面可以选中文字的功能方法
2018/01/02 Javascript
使用Vue开发自己的Chrome扩展程序过程详解
2019/06/21 Javascript
vue项目强制清除页面缓存的例子
2019/11/06 Javascript
Element实现表格嵌套、多个表格共用一个表头的方法
2020/05/09 Javascript
详解JavaScript中的数据类型,以及检测数据类型的方法
2020/09/17 Javascript
[01:04:32]DOTA2-DPC中国联赛 正赛 Aster vs LBZS BO3 第二场 2月23日
2021/03/11 DOTA
深入讨论Python函数的参数的默认值所引发的问题的原因
2015/03/30 Python
为什么选择python编程语言入门黑客攻防 给你几个理由!
2018/02/02 Python
pytorch 模拟关系拟合——回归实例
2020/01/14 Python
Selenium+BeautifulSoup+json获取Script标签内的json数据
2020/12/07 Python
世界上最大的二手相机店:KEN
2017/05/17 全球购物
阿联酋航空丹麦官方网站:Emirates DK
2019/08/25 全球购物
英国排名第一的冲浪店:Ann’s Cottage
2020/06/21 全球购物
学生党员思想汇报
2013/12/28 职场文书
员工安全责任书范本
2014/07/24 职场文书
2014年学校法制宣传日活动总结
2014/11/01 职场文书
2015年超市收银员工作总结
2015/04/25 职场文书