用tensorflow构建线性回归模型的示例代码


Posted in Python onMarch 05, 2018

用tensorflow构建简单的线性回归模型是tensorflow的一个基础样例,但是原有的样例存在一些问题,我在实际调试的过程中做了一点自己的改进,并且有一些体会。

首先总结一下tf构建模型的总体套路

1、先定义模型的整体图结构,未知的部分,比如输入就用placeholder来代替。

2、再定义最后与目标的误差函数。

3、最后选择优化方法。

另外几个值得注意的地方是:

1、tensorflow构建模型第一步是先用代码搭建图模型,此时图模型是静止的,是不产生任何运算结果的,必须使用Session来驱动。

2、第二步根据问题的不同要求构建不同的误差函数,这个函数就是要求优化的函数。

3、调用合适的优化器优化误差函数,注意,此时反向传播调整参数的过程隐藏在了图模型当中,并没有显式显现出来。

4、tensorflow的中文意思是张量流动,也就是说有两个意思,一个是参与运算的不仅仅是标量或是矩阵,甚至可以是具有很高维度的张量,第二个意思是这些数据在图模型中流动,不停地更新。

5、session的run函数中,按照传入的操作向上查找,凡是操作中涉及的无论是变量、常量都要参与运算,占位符则要在run过程中以字典形式传入。

以上时tensorflow的一点认识,下面是关于梯度下降的一点新认识。

1、梯度下降法分为批量梯度下降和随机梯度下降法,第一种是所有数据都参与运算后,计算误差函数,根据此误差函数来更新模型参数,实际调试发现,如果定义误差函数为平方误差函数,这个值很快就会飞掉,原因是,批量平方误差都加起来可能会很大,如果此时学习率比较高,那么调整就会过,造成模型参数向一个方向大幅调整,造成最终结果发散。所以这个时候要降低学习率,让参数变化不要太快。

2、随机梯度下降法,每次用一个数据计算误差函数,然后更新模型参数,这个方法有可能会造成结果出现震荡,而且麻烦的是由于要一个个取出数据参与运算,而不是像批量计算那样采用了广播或者向量化乘法的机制,收敛会慢一些。但是速度要比使用批量梯度下降要快,原因是不需要每次计算全部数据的梯度了。比较折中的办法是mini-batch,也就是每次选用一小部分数据做梯度下降,目前这也是最为常用的方法了。

3、epoch概念:所有样本集过完一轮,就是一个epoch,很明显,如果是严格的随机梯度下降法,一个epoch内更新了样本个数这么多次参数,而批量法只更新了一次。

以上是我个人的一点认识,希望大家看到有不对的地方及时批评指针,不胜感激!

#encoding=utf-8 
__author__ = 'freedom' 
import tensorflow as tf 
import numpy as np 
 
def createData(dataNum,w,b,sigma): 
 train_x = np.arange(dataNum) 
 train_y = w*train_x+b+np.random.randn()*sigma 
 #print train_x 
 #print train_y 
 return train_x,train_y 
 
def linerRegression(train_x,train_y,epoch=100000,rate = 0.000001): 
 train_x = np.array(train_x) 
 train_y = np.array(train_y) 
 n = train_x.shape[0] 
 x = tf.placeholder("float") 
 y = tf.placeholder("float") 
 w = tf.Variable(tf.random_normal([1])) # 生成随机权重 
 b = tf.Variable(tf.random_normal([1])) 
 
 pred = tf.add(tf.mul(x,w),b) 
 loss = tf.reduce_sum(tf.pow(pred-y,2)) 
 optimizer = tf.train.GradientDescentOptimizer(rate).minimize(loss) 
 init = tf.initialize_all_variables() 
 
 sess = tf.Session() 
 sess.run(init) 
 print 'w start is ',sess.run(w) 
 print 'b start is ',sess.run(b) 
 for index in range(epoch): 
  #for tx,ty in zip(train_x,train_y): 
   #sess.run(optimizer,{x:tx,y:ty}) 
  sess.run(optimizer,{x:train_x,y:train_y}) 
  # print 'w is ',sess.run(w) 
  # print 'b is ',sess.run(b) 
  # print 'pred is ',sess.run(pred,{x:train_x}) 
  # print 'loss is ',sess.run(loss,{x:train_x,y:train_y}) 
  #print '------------------' 
 print 'loss is ',sess.run(loss,{x:train_x,y:train_y}) 
 w = sess.run(w) 
 b = sess.run(b) 
 return w,b 
 
def predictionTest(test_x,test_y,w,b): 
 W = tf.placeholder(tf.float32) 
 B = tf.placeholder(tf.float32) 
 X = tf.placeholder(tf.float32) 
 Y = tf.placeholder(tf.float32) 
 n = test_x.shape[0] 
 pred = tf.add(tf.mul(X,W),B) 
 loss = tf.reduce_mean(tf.pow(pred-Y,2)) 
 sess = tf.Session() 
 loss = sess.run(loss,{X:test_x,Y:test_y,W:w,B:b}) 
 return loss 
 
if __name__ == "__main__": 
 train_x,train_y = createData(50,2.0,7.0,1.0) 
 test_x,test_y = createData(20,2.0,7.0,1.0) 
 w,b = linerRegression(train_x,train_y) 
 print 'weights',w 
 print 'bias',b 
 loss = predictionTest(test_x,test_y,w,b) 
 print loss

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python的Django框架中的模版相关知识
Jul 15 Python
Python的Flask框架中的Jinja2模板引擎学习教程
Jun 30 Python
Python使用迭代器捕获Generator返回值的方法
Apr 05 Python
tensorflow实现KNN识别MNIST
Mar 12 Python
详解Django+Uwsgi+Nginx的生产环境部署
Jun 25 Python
用python一行代码得到数组中某个元素的个数方法
Jan 28 Python
pytorch 更改预训练模型网络结构的方法
Aug 19 Python
python应用文件读取与登录注册功能
Sep 23 Python
flask框架json数据的拿取和返回操作示例
Nov 28 Python
Django更新models数据库结构步骤
Apr 01 Python
详解Python中的编码问题(encoding与decode、str与bytes)
Sep 30 Python
Python+Tkinter制作专属图形化界面
Apr 01 Python
详解python实现线程安全的单例模式
Mar 05 #Python
分析python动态规划的递归、非递归实现
Mar 04 #Python
python3.x上post发送json数据
Mar 04 #Python
python数据封装json格式数据
Mar 04 #Python
Python爬虫实例扒取2345天气预报
Mar 04 #Python
Python爬虫设置代理IP的方法(爬虫技巧)
Mar 04 #Python
浅析python实现scrapy定时执行爬虫
Mar 04 #Python
You might like
PHP网页游戏学习之Xnova(ogame)源码解读(十四)
2014/06/26 PHP
PHPUnit安装及使用示例
2014/10/29 PHP
PHP实现服务器状态监控的方法
2014/12/09 PHP
php操作xml入门之xml标签的属性分析
2015/01/23 PHP
PHP获取某个月最大天数(最后一天)的方法
2015/07/29 PHP
PHP的Yii框架中行为的定义与绑定方法讲解
2016/03/18 PHP
php mysqli查询语句返回值类型实例分析
2016/06/29 PHP
PHP Post获取不到非表单数据的问题解决办法
2018/02/27 PHP
Yii 框架入口脚本示例分析
2020/05/19 PHP
如何在父窗口中得知window.open()出的子窗口关闭事件
2013/10/15 Javascript
jQuery实现限制textarea文本框输入字符数量的方法
2015/05/28 Javascript
浅谈JavaScript的Polymer框架中的behaviors对象
2015/07/29 Javascript
原生js编写焦点图效果
2016/12/08 Javascript
footer定位页面底部(代码分享)
2017/03/07 Javascript
vuex的使用及持久化state的方式详解
2018/01/23 Javascript
vue-froala-wysiwyg 富文本编辑器功能
2019/09/19 Javascript
微信小程序用户拒绝授权的处理方法详解
2019/09/20 Javascript
vue动态禁用控件绑定disable的例子
2019/10/28 Javascript
使用axios发送post请求,将JSON数据改为form类型的示例
2019/10/31 Javascript
python使用paramiko模块实现ssh远程登陆上传文件并执行
2014/01/27 Python
在Python程序中操作文件之flush()方法的使用教程
2015/05/24 Python
Python实现采用进度条实时显示处理进度的方法
2017/12/19 Python
使用python的pandas库读取csv文件保存至mysql数据库
2018/08/20 Python
wtfPython—Python中一组有趣微妙的代码【收藏】
2018/08/31 Python
不归路系列:Python入门之旅-一定要注意缩进!!!(推荐)
2019/04/16 Python
用python写爬虫简单吗
2020/07/28 Python
Belle Maison倍美丛官网:日本千趣会旗下邮购网站
2016/07/22 全球购物
植物选择:Botanic Choice
2017/02/15 全球购物
印度在线购买电子产品网站:Croma
2020/01/02 全球购物
优秀演讲稿范文
2013/12/29 职场文书
《争吵》教学反思
2014/02/15 职场文书
法人授权委托书
2014/04/03 职场文书
2014年乡镇工会工作总结
2014/12/02 职场文书
勤俭节约主题班会
2015/08/13 职场文书
工作自我评价范文
2019/03/21 职场文书
浅谈如何写好演讲稿?
2019/06/12 职场文书