python 还原梯度下降算法实现一维线性回归


Posted in Python onOctober 22, 2020

首先我们看公式:

python 还原梯度下降算法实现一维线性回归

这个是要拟合的函数

然后我们求出它的损失函数, 注意:这里的n和m均为数据集的长度,写的时候忘了

python 还原梯度下降算法实现一维线性回归

注意,前面的theta0-theta1x是实际值,后面的y是期望值
接着我们求出损失函数的偏导数:

python 还原梯度下降算法实现一维线性回归

最终,梯度下降的算法:

python 还原梯度下降算法实现一维线性回归

学习率一般小于1,当损失函数是0时,我们输出theta0和theta1.
接下来上代码!

class LinearRegression():

  def __init__(self, data, theta0, theta1, learning_rate):
    self.data = data
    self.theta0 = theta0
    self.theta1 = theta1
    self.learning_rate = learning_rate
    self.length = len(data)

  # hypothesis
  def h_theta(self, x):
    return self.theta0 + self.theta1 * x

  # cost function
  def J(self):
    temp = 0
    for i in range(self.length):
      temp += pow(self.h_theta(self.data[i][0]) - self.data[i][1], 2)
    return 1 / (2 * self.m) * temp

  # partial derivative
  def pd_theta0_J(self):
    temp = 0
    for i in range(self.length):
      temp += self.h_theta(self.data[i][0]) - self.data[i][1]
    return 1 / self.m * temp

  def pd_theta1_J(self):
    temp = 0
    for i in range(self.length):
      temp += (self.h_theta(data[i][0]) - self.data[i][1]) * self.data[i][0]
    return 1 / self.m * temp

  # gradient descent
  def gd(self):
    min_cost = 0.00001
    round = 1
    max_round = 10000
    while min_cost < abs(self.J()) and round <= max_round:
      self.theta0 = self.theta0 - self.learning_rate * self.pd_theta0_J()
      self.theta1 = self.theta1 - self.learning_rate * self.pd_theta1_J()

      print('round', round, ':\t theta0=%.16f' % self.theta0, '\t theta1=%.16f' % self.theta1)
      round += 1
    return self.theta0, self.theta1

def main():
	data = [[1, 2], [2, 5], [4, 8], [5, 9], [8, 15]] # 这里换成你想拟合的数[x, y]
	 # plot scatter
  x = []
  y = []
  for i in range(len(data)):
    x.append(data[i][0])
    y.append(data[i][1])
  plt.scatter(x, y)

  # gradient descent
  linear_regression = LinearRegression(data, theta0, theta1, learning_rate)
  theta0, theta1 = linear_regression.gd()

  # plot returned linear
  x = np.arange(0, 10, 0.01)
  y = theta0 + theta1 * x
  plt.plot(x, y)
  plt.show()

到此这篇关于python 还原梯度下降算法实现一维线性回归 的文章就介绍到这了,更多相关python 一维线性回归 内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python使用Socket(Https)Post登录百度的实现代码
May 18 Python
python定时采集摄像头图像上传ftp服务器功能实现
Dec 23 Python
使用wxPython获取系统剪贴板中的数据的教程
May 06 Python
django接入新浪微博OAuth的方法
Jun 29 Python
Python中defaultdict与lambda表达式用法实例小结
Apr 09 Python
Linux下多个Python版本安装教程
Aug 15 Python
python 将list转成字符串,中间用符号分隔的方法
Oct 23 Python
Python实现SQL注入检测插件实例代码
Feb 02 Python
python 实现单通道转3通道
Dec 03 Python
关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)
Feb 20 Python
使用Python第三方库pygame写个贪吃蛇小游戏
Mar 06 Python
基于keras中的回调函数用法说明
Jun 17 Python
利用Pycharm + Django搭建一个简单Python Web项目的步骤
Oct 22 #Python
python处理写入数据代码讲解
Oct 22 #Python
基于Python爬取股票数据过程详解
Oct 21 #Python
OpenCV利用python来实现图像的直方图均衡化
Oct 21 #Python
Python实现手势识别
Oct 21 #Python
利用Python优雅的登录校园网
Oct 21 #Python
python 使用三引号时容易犯的小错误
Oct 21 #Python
You might like
实现分十页分向前十页向后十页的处理
2006/10/09 PHP
php函数array_merge用法一例(合并同类数组)
2013/02/03 PHP
PHP解析目录路径的3个函数总结
2014/11/18 PHP
PHP十六进制颜色随机生成器功能示例
2017/07/24 PHP
PHP中用Trait封装单例模式的实现
2019/12/18 PHP
php gethostbyname获取域名ip地址函数详解
2010/01/24 Javascript
ExtJs使用IFrame的实现代码
2010/03/24 Javascript
ScrollDown的基本操作示例
2013/06/09 Javascript
javascript代码运行不出来执行错误的可能情况整理
2013/10/18 Javascript
javascript获取checkbox复选框获取选中的选项
2014/08/12 Javascript
JavaScript使用Math.Min返回两个数中较小数的方法
2015/04/06 Javascript
JavaScript使用cookie实现记住账号密码功能
2015/04/27 Javascript
jQuery实现鼠标双击Table单元格变成文本框及输入内容后更新到数据库的方法
2015/11/25 Javascript
AngularJS 实现弹性盒子布局的方法
2016/08/30 Javascript
使用react实现手机号的数据同步显示功能的示例代码
2018/04/03 Javascript
微信小程序实现长按删除图片的示例
2018/05/18 Javascript
vue动态改变背景图片demo分享
2018/09/13 Javascript
vue双向绑定数据限制长度的方法
2019/11/04 Javascript
JS数组扁平化、去重、排序操作实例详解
2020/02/24 Javascript
微信小程序开发(三):返回上一级页面并刷新操作示例【页面栈】
2020/06/01 Javascript
vue设置全局访问接口API地址操作
2020/08/14 Javascript
antd中table展开行默认展示,且不需要前边的加号操作
2020/11/02 Javascript
[50:29]2014 DOTA2华西杯精英邀请赛 5 24 DK VS iG
2014/05/26 DOTA
Python2.5/2.6实用教程 入门基础篇
2009/11/29 Python
Python实现全角半角字符互转的方法
2016/11/28 Python
numpy中实现ndarray数组返回符合特定条件的索引方法
2018/04/17 Python
给Python学习者的文件读写指南(含基础与进阶)
2020/01/29 Python
python中os.remove()用法及注意事项
2021/01/31 Python
用纯css3实现的图片放大镜特效效果非常不错
2014/09/02 HTML / CSS
Kathmandu澳洲户外商店:新西兰户外运动品牌
2017/11/12 全球购物
汽车维修专业自荐书
2014/05/26 职场文书
幸福家庭标语
2014/06/27 职场文书
2015年暑期社会实践活动总结
2015/03/27 职场文书
教育读书笔记
2015/07/02 职场文书
Python 如何将integer转化为罗马数(3999以内)
2021/06/05 Python
Java异常处理try catch的基本用法
2021/12/06 Java/Android