基于python中theano库的线性回归


Posted in Python onAugust 31, 2018

theano库是做deep learning重要的一部分,其最吸引人的地方之一是你给出符号化的公式之后,能自动生成导数。本文使用梯度下降的方法,进行数据拟合,现在把代码贴在下方

代码块

import numpy as np 
import theano.tensor as T 
import theano 
import time 

class Linear_Reg(object): 
  def __init__(self,x): 
    self.a = theano.shared(value = np.zeros((1,), dtype=theano.config.floatX),name = 'a') 
    self.b = theano.shared(value = np.zeros((1,), 
dtype=theano.config.floatX),name = 'b') 
    self.result = self.a * x + self.b 
    self.params = [self.a,self.b] 
  def msl(self,y): 
    return T.mean((y - self.result)**2) 

def regrun(rate,data,labels): 

  X = theano.shared(np.asarray(data, 
                 dtype=theano.config.floatX),borrow = True) 
  Y = theano.shared(np.asarray(labels, 
                 dtype=theano.config.floatX),borrow = True) 

  index = T.lscalar() #定义符号化的公式
  x = T.dscalar('x')  #定义符号化的公式
  y = T.dscalar('y')  #定义符号化的公式

  reg = Linear_Reg(x = x) 
  cost = reg.msl(y) 


  a_g = T.grad(cost = cost,wrt = reg.a) #计算梯度 
  b_g = T.grad(cost = cost, wrt = reg.b) #计算梯度

  updates=[(reg.a,reg.a - rate * a_g),(reg.b,reg.b - rate * b_g)] #更新参数
  train_model = theano.function(inputs=[index], outputs = reg.msl(y),updates = updates,givens = {x:X[index], y:Y[index]}) 

  done = True 
  err = 0.0 
  count = 0 
  last = 0.0 
  start_time = time.clock() 
  while done: 
    #err_s = [train_model(i) for i in xrange(data.shape[0])] 
    for i in xxx:
      err_s = [train_model(i) ]
      err = np.mean(err_s)  

    #print err 
    count = count + 1 
    if count > 10000 or err <0.1: 
      done = False 
    last = err 
  end_time = time.clock() 
  print 'Total time is :',end_time -start_time,' s' # 5.12s 
  print 'last error :',err 
  print 'a value : ',reg.a.get_value() # [ 2.92394467]  
  print 'b value : ',reg.b.get_value() # [ 1.81334458] 

if __name__ == '__main__':  
  rate = 0.01 
  data = np.linspace(1,10,10) 
  labels = data * 3 + np.ones(data.shape[0],dtype=np.float64) +np.random.rand(data.shape[0])
  regrun(rate,data,labels)

其基本思想是随机梯度下降。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python验证码识别的方法
Jul 10 Python
基于python yield机制的异步操作同步化编程模型
Mar 18 Python
Python温度转换实例分析
Jan 17 Python
Python SQLite3简介
Feb 22 Python
mac 安装python网络请求包requests方法
Jun 13 Python
python+selenium 定位到元素,无法点击的解决方法
Jan 30 Python
Python使用matplotlib实现交换式图形显示功能示例
Sep 06 Python
Python如何使用bokeh包和geojson数据绘制地图
Mar 21 Python
AI:如何训练机器学习的模型
Apr 16 Python
Python还能这么玩之用Python做个小游戏的外挂
Jun 04 Python
用python基于appium模块开发一个自动收取能量的小助手
Sep 25 Python
Python用any()函数检查字符串中的字母以及如何使用all()函数
Apr 14 Python
基于随机梯度下降的矩阵分解推荐算法(python)
Aug 31 #Python
python实现梯度下降算法
Mar 24 #Python
wtfPython—Python中一组有趣微妙的代码【收藏】
Aug 31 #Python
opencv python 图像去噪的实现方法
Aug 31 #Python
python+numpy+matplotalib实现梯度下降法
Aug 31 #Python
python实现随机梯度下降法
Mar 24 #Python
python实现决策树分类(2)
Aug 30 #Python
You might like
免费的ip数据库淘宝IP地址库简介和PHP调用实例
2014/04/08 PHP
PHP迭代与递归实现无限级分类
2017/08/28 PHP
php删除二维数组中的重复值方法
2018/03/12 PHP
打开超链需要“确认”对话框的方法
2007/03/08 Javascript
javascript Firefox与IE 替换节点的方法
2010/02/24 Javascript
JS数组的赋值介绍
2014/03/10 Javascript
angularjs中的e2e测试实例
2014/12/06 Javascript
Javascript核心读书有感之语言核心
2015/02/01 Javascript
js传值后台中文出现乱码的解决方法
2016/06/30 Javascript
将鼠标焦点定位到文本框最后(代码分享)
2017/01/11 Javascript
基于vue.js路由参数的实例讲解——简单易懂
2017/09/07 Javascript
微信小程序canvas拖拽、截图组件功能
2018/09/04 Javascript
详解vue-cli3多页应用改造
2019/06/04 Javascript
layui的layedit富文本赋值方法
2019/09/18 Javascript
vue之组件内监控$store中定义变量的变化详解
2019/11/08 Javascript
JavaScript中使用Spread运算符的八种方法总结
2020/06/18 Javascript
微信小程序实现多图上传
2020/06/19 Javascript
Element-ui el-tree新增和删除节点后如何刷新tree的实例
2020/08/31 Javascript
Vue项目配置跨域访问和代理proxy设置方式
2020/09/08 Javascript
python提取页面内url列表的方法
2015/05/25 Python
在Python的Django框架中显示对象子集的方法
2015/07/21 Python
关于Python中异常(Exception)的汇总
2017/01/18 Python
对Python字符串中的换行符和制表符介绍
2018/05/03 Python
Python实现计算文件MD5和SHA1的方法示例
2019/06/11 Python
python3.6生成器yield用法实例分析
2019/08/23 Python
Python协程操作之gevent(yield阻塞,greenlet),协程实现多任务(有规律的交替协作执行)用法详解
2019/10/14 Python
Python求凸包及多边形面积教程
2020/04/12 Python
Python unittest框架操作实例解析
2020/04/13 Python
基于Tensorflow的MNIST手写数字识别分类
2020/06/17 Python
蒂芙尼澳大利亚官方网站:Tiffany&Co. Australia
2017/08/27 全球购物
*p++ 自增p 还是p所指向的变量
2016/07/16 面试题
自动化专业个人求职信范文
2013/12/30 职场文书
单位工作证明书格式
2014/10/04 职场文书
党员贯彻十八大精神思想汇报范文
2014/10/25 职场文书
党员群众路线学习心得体会
2014/11/04 职场文书
2015年评职称个人工作总结
2015/10/15 职场文书