python 还原梯度下降算法实现一维线性回归


Posted in Python onOctober 22, 2020

首先我们看公式:

python 还原梯度下降算法实现一维线性回归

这个是要拟合的函数

然后我们求出它的损失函数, 注意:这里的n和m均为数据集的长度,写的时候忘了

python 还原梯度下降算法实现一维线性回归

注意,前面的theta0-theta1x是实际值,后面的y是期望值
接着我们求出损失函数的偏导数:

python 还原梯度下降算法实现一维线性回归

最终,梯度下降的算法:

python 还原梯度下降算法实现一维线性回归

学习率一般小于1,当损失函数是0时,我们输出theta0和theta1.
接下来上代码!

class LinearRegression():

  def __init__(self, data, theta0, theta1, learning_rate):
    self.data = data
    self.theta0 = theta0
    self.theta1 = theta1
    self.learning_rate = learning_rate
    self.length = len(data)

  # hypothesis
  def h_theta(self, x):
    return self.theta0 + self.theta1 * x

  # cost function
  def J(self):
    temp = 0
    for i in range(self.length):
      temp += pow(self.h_theta(self.data[i][0]) - self.data[i][1], 2)
    return 1 / (2 * self.m) * temp

  # partial derivative
  def pd_theta0_J(self):
    temp = 0
    for i in range(self.length):
      temp += self.h_theta(self.data[i][0]) - self.data[i][1]
    return 1 / self.m * temp

  def pd_theta1_J(self):
    temp = 0
    for i in range(self.length):
      temp += (self.h_theta(data[i][0]) - self.data[i][1]) * self.data[i][0]
    return 1 / self.m * temp

  # gradient descent
  def gd(self):
    min_cost = 0.00001
    round = 1
    max_round = 10000
    while min_cost < abs(self.J()) and round <= max_round:
      self.theta0 = self.theta0 - self.learning_rate * self.pd_theta0_J()
      self.theta1 = self.theta1 - self.learning_rate * self.pd_theta1_J()

      print('round', round, ':\t theta0=%.16f' % self.theta0, '\t theta1=%.16f' % self.theta1)
      round += 1
    return self.theta0, self.theta1

def main():
	data = [[1, 2], [2, 5], [4, 8], [5, 9], [8, 15]] # 这里换成你想拟合的数[x, y]
	 # plot scatter
  x = []
  y = []
  for i in range(len(data)):
    x.append(data[i][0])
    y.append(data[i][1])
  plt.scatter(x, y)

  # gradient descent
  linear_regression = LinearRegression(data, theta0, theta1, learning_rate)
  theta0, theta1 = linear_regression.gd()

  # plot returned linear
  x = np.arange(0, 10, 0.01)
  y = theta0 + theta1 * x
  plt.plot(x, y)
  plt.show()

到此这篇关于python 还原梯度下降算法实现一维线性回归 的文章就介绍到这了,更多相关python 一维线性回归 内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python连接MySQL数据库实例分析
May 12 Python
Python中基础的socket编程实战攻略
Jun 01 Python
利用Python爬虫给孩子起个好名字
Feb 14 Python
Python3编程实现获取阿里云ECS实例及监控的方法
Aug 18 Python
详解Python核心编程中的浅拷贝与深拷贝
Jan 07 Python
Python读取excel中的图片完美解决方法
Jul 27 Python
Python csv模块使用方法代码实例
Aug 29 Python
关于Python Tkinter Button控件command传参问题的解决方式
Mar 04 Python
python接口自动化之ConfigParser配置文件的使用详解
Aug 03 Python
使用BeautifulSoup4解析XML的方法小结
Dec 07 Python
Python实现socket库网络通信套接字
Jun 04 Python
Python中文分词库jieba(结巴分词)详细使用介绍
Apr 07 Python
利用Pycharm + Django搭建一个简单Python Web项目的步骤
Oct 22 #Python
python处理写入数据代码讲解
Oct 22 #Python
基于Python爬取股票数据过程详解
Oct 21 #Python
OpenCV利用python来实现图像的直方图均衡化
Oct 21 #Python
Python实现手势识别
Oct 21 #Python
利用Python优雅的登录校园网
Oct 21 #Python
python 使用三引号时容易犯的小错误
Oct 21 #Python
You might like
javascript,php获取函数参数对象的代码
2011/02/03 PHP
一个php生成16位随机数的代码(两种方法)
2014/09/16 PHP
php画图实例
2014/11/05 PHP
动手学习无线电
2021/03/10 无线电
jquery实现文本框鼠标右击无效以及不能输入的代码
2010/11/05 Javascript
javascript中运用闭包和自执行函数解决大量的全局变量问题
2010/12/30 Javascript
jquery load事件(callback/data)使用方法及注意事项
2013/02/06 Javascript
JavaScript 函数参数是传值(byVal)还是传址(byRef) 分享
2013/07/02 Javascript
使用控制台破解百小度一个月只准改一次名字
2015/08/13 Javascript
JS根据key值获取URL中的参数值及把URL的参数转换成json对象
2015/08/26 Javascript
Seajs是什么及sea.js 由来,特点以及优势
2016/10/13 Javascript
jQuery实现表格与ckeckbox的全选与单选功能
2016/11/24 Javascript
JS实现间歇滚动的运动效果实例
2016/12/22 Javascript
Javascript中引用类型传递的知识点小结
2017/03/06 Javascript
从对象列表中获取一个对象的方法,依据关键字和值
2017/09/20 Javascript
Node做中转服务器转发接口
2017/10/18 Javascript
详解微信图片防盗链“此图片来自微信公众平台 未经允许不得引用”的解决方案
2019/04/04 Javascript
详解vue中使用微信jssdk
2019/04/19 Javascript
VUE注册全局组件和局部组件过程解析
2019/10/10 Javascript
50行代码实现贪吃蛇(具体思路及代码)
2013/04/27 Python
使用python3.5仿微软记事本notepad
2016/06/15 Python
使用Python操作excel文件的实例代码
2017/10/15 Python
Django ORM框架的定时任务如何使用详解
2017/10/19 Python
Python基于Flask框架配置依赖包信息的项目迁移部署
2018/03/02 Python
使用Python读取二进制文件的实例讲解
2018/07/09 Python
Python小程序 控制鼠标循环点击代码实例
2019/10/08 Python
numpy 矩阵形状调整:拉伸、变成一位数组的实例
2020/06/18 Python
matplotlib实现数据实时刷新的示例代码
2021/01/05 Python
美国专业级皮肤病和spa品质护肤品的高级零售网站:SkinCareRx
2017/02/06 全球购物
PHP面试题附答案
2015/11/28 面试题
什么是测试驱动开发(TDD)
2012/02/15 面试题
生物技术毕业生自荐信
2013/10/23 职场文书
爱与责任演讲稿
2014/05/20 职场文书
加班费申请报告
2015/05/15 职场文书
2015年社区宣传工作总结
2015/05/20 职场文书
JS 4个超级实用的小技巧 提升开发效率
2021/10/05 Javascript