用tensorflow构建线性回归模型的示例代码


Posted in Python onMarch 05, 2018

用tensorflow构建简单的线性回归模型是tensorflow的一个基础样例,但是原有的样例存在一些问题,我在实际调试的过程中做了一点自己的改进,并且有一些体会。

首先总结一下tf构建模型的总体套路

1、先定义模型的整体图结构,未知的部分,比如输入就用placeholder来代替。

2、再定义最后与目标的误差函数。

3、最后选择优化方法。

另外几个值得注意的地方是:

1、tensorflow构建模型第一步是先用代码搭建图模型,此时图模型是静止的,是不产生任何运算结果的,必须使用Session来驱动。

2、第二步根据问题的不同要求构建不同的误差函数,这个函数就是要求优化的函数。

3、调用合适的优化器优化误差函数,注意,此时反向传播调整参数的过程隐藏在了图模型当中,并没有显式显现出来。

4、tensorflow的中文意思是张量流动,也就是说有两个意思,一个是参与运算的不仅仅是标量或是矩阵,甚至可以是具有很高维度的张量,第二个意思是这些数据在图模型中流动,不停地更新。

5、session的run函数中,按照传入的操作向上查找,凡是操作中涉及的无论是变量、常量都要参与运算,占位符则要在run过程中以字典形式传入。

以上时tensorflow的一点认识,下面是关于梯度下降的一点新认识。

1、梯度下降法分为批量梯度下降和随机梯度下降法,第一种是所有数据都参与运算后,计算误差函数,根据此误差函数来更新模型参数,实际调试发现,如果定义误差函数为平方误差函数,这个值很快就会飞掉,原因是,批量平方误差都加起来可能会很大,如果此时学习率比较高,那么调整就会过,造成模型参数向一个方向大幅调整,造成最终结果发散。所以这个时候要降低学习率,让参数变化不要太快。

2、随机梯度下降法,每次用一个数据计算误差函数,然后更新模型参数,这个方法有可能会造成结果出现震荡,而且麻烦的是由于要一个个取出数据参与运算,而不是像批量计算那样采用了广播或者向量化乘法的机制,收敛会慢一些。但是速度要比使用批量梯度下降要快,原因是不需要每次计算全部数据的梯度了。比较折中的办法是mini-batch,也就是每次选用一小部分数据做梯度下降,目前这也是最为常用的方法了。

3、epoch概念:所有样本集过完一轮,就是一个epoch,很明显,如果是严格的随机梯度下降法,一个epoch内更新了样本个数这么多次参数,而批量法只更新了一次。

以上是我个人的一点认识,希望大家看到有不对的地方及时批评指针,不胜感激!

#encoding=utf-8 
__author__ = 'freedom' 
import tensorflow as tf 
import numpy as np 
 
def createData(dataNum,w,b,sigma): 
 train_x = np.arange(dataNum) 
 train_y = w*train_x+b+np.random.randn()*sigma 
 #print train_x 
 #print train_y 
 return train_x,train_y 
 
def linerRegression(train_x,train_y,epoch=100000,rate = 0.000001): 
 train_x = np.array(train_x) 
 train_y = np.array(train_y) 
 n = train_x.shape[0] 
 x = tf.placeholder("float") 
 y = tf.placeholder("float") 
 w = tf.Variable(tf.random_normal([1])) # 生成随机权重 
 b = tf.Variable(tf.random_normal([1])) 
 
 pred = tf.add(tf.mul(x,w),b) 
 loss = tf.reduce_sum(tf.pow(pred-y,2)) 
 optimizer = tf.train.GradientDescentOptimizer(rate).minimize(loss) 
 init = tf.initialize_all_variables() 
 
 sess = tf.Session() 
 sess.run(init) 
 print 'w start is ',sess.run(w) 
 print 'b start is ',sess.run(b) 
 for index in range(epoch): 
  #for tx,ty in zip(train_x,train_y): 
   #sess.run(optimizer,{x:tx,y:ty}) 
  sess.run(optimizer,{x:train_x,y:train_y}) 
  # print 'w is ',sess.run(w) 
  # print 'b is ',sess.run(b) 
  # print 'pred is ',sess.run(pred,{x:train_x}) 
  # print 'loss is ',sess.run(loss,{x:train_x,y:train_y}) 
  #print '------------------' 
 print 'loss is ',sess.run(loss,{x:train_x,y:train_y}) 
 w = sess.run(w) 
 b = sess.run(b) 
 return w,b 
 
def predictionTest(test_x,test_y,w,b): 
 W = tf.placeholder(tf.float32) 
 B = tf.placeholder(tf.float32) 
 X = tf.placeholder(tf.float32) 
 Y = tf.placeholder(tf.float32) 
 n = test_x.shape[0] 
 pred = tf.add(tf.mul(X,W),B) 
 loss = tf.reduce_mean(tf.pow(pred-Y,2)) 
 sess = tf.Session() 
 loss = sess.run(loss,{X:test_x,Y:test_y,W:w,B:b}) 
 return loss 
 
if __name__ == "__main__": 
 train_x,train_y = createData(50,2.0,7.0,1.0) 
 test_x,test_y = createData(20,2.0,7.0,1.0) 
 w,b = linerRegression(train_x,train_y) 
 print 'weights',w 
 print 'bias',b 
 loss = predictionTest(test_x,test_y,w,b) 
 print loss

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python单链表实现代码实例
Nov 21 Python
python爬虫教程之爬取百度贴吧并下载的示例
Mar 07 Python
python用reduce和map把字符串转为数字的方法
Dec 19 Python
python实现windows下文件备份脚本
May 27 Python
Python 16进制与中文相互转换的实现方法
Jul 09 Python
python数据处理 根据颜色对图片进行分类的方法
Dec 08 Python
python字典的常用方法总结
Jul 31 Python
ORM Django 终端打印 SQL 语句实现解析
Aug 09 Python
Python3 tkinter 实现文件读取及保存功能
Sep 12 Python
Python谱减法语音降噪实例
Dec 18 Python
vscode+PyQt5安装详解步骤
Aug 12 Python
利用python制作拼图小游戏的全过程
Dec 04 Python
详解python实现线程安全的单例模式
Mar 05 #Python
分析python动态规划的递归、非递归实现
Mar 04 #Python
python3.x上post发送json数据
Mar 04 #Python
python数据封装json格式数据
Mar 04 #Python
Python爬虫实例扒取2345天气预报
Mar 04 #Python
Python爬虫设置代理IP的方法(爬虫技巧)
Mar 04 #Python
浅析python实现scrapy定时执行爬虫
Mar 04 #Python
You might like
分享下PHP register_globals 值为on与off的理解
2013/09/26 PHP
PHP addcslashes()函数讲解
2019/02/03 PHP
php 的多进程操作实践案例分析
2020/02/28 PHP
onclick与listeners的执行先后问题详细解剖
2013/01/07 Javascript
js中创建对象的几种方式示例介绍
2014/01/26 Javascript
jquery mobile动态添加元素之后不能正确渲染解决方法说明
2014/03/05 Javascript
JQuery插入DOM节点的方法
2015/06/11 Javascript
jquery实现左右滑动菜单效果代码
2015/08/27 Javascript
深入浅析同源策略和跨域访问
2015/11/26 Javascript
JavaScript的ExtJS框架中数面板TreePanel的使用实例解析
2016/05/21 Javascript
Highcharts 多个Y轴动态刷新数据的实现代码
2016/05/28 Javascript
js图片上传前预览功能(兼容所有浏览器)
2016/08/24 Javascript
js发送短信倒计时的简单实现方法
2016/09/08 Javascript
ionic cordova一次上传多张图片(类似input file提交表单)的实现方法
2016/12/16 Javascript
jQuery实现复制到粘贴板功能
2017/02/11 Javascript
js从输入框读取内容,比较两个数字的大小方法
2017/03/13 Javascript
Vue+ElementUI实现表单动态渲染、可视化配置的方法
2018/03/07 Javascript
详解Chart.js轻量级图表库的使用经验
2018/05/22 Javascript
微信小程序授权登陆及每次检查是否授权实例代码
2019/09/18 Javascript
javascript History对象原理解析
2020/02/17 Javascript
[02:33]2014DOTA2 TI每日综述 LGD涉险晋级DK闯入胜者组
2014/07/14 DOTA
python动态加载包的方法小结
2016/04/18 Python
Python3连接MySQL(pymysql)模拟转账实现代码
2016/05/24 Python
Python+Pika+RabbitMQ环境部署及实现工作队列的实例教程
2016/06/29 Python
python实现分页效果
2017/10/25 Python
人工智能最火编程语言 Python大战Java!
2017/11/13 Python
python微信公众号之关注公众号自动回复
2018/10/25 Python
python绘制地震散点图
2019/06/18 Python
python-django中的APPEND_SLASH实现方法
2019/06/21 Python
python判断一个对象是否可迭代的例子
2019/07/22 Python
Michael Kors美国官网:美式奢侈生活风格的代表
2016/11/25 全球购物
英国最大的运动营养公司之一:LA Muscle
2018/07/02 全球购物
IdealFit官方网站:女性蛋白质、补充剂和运动服装
2019/03/24 全球购物
关爱留守儿童倡议书
2014/04/15 职场文书
宣传活动总结范文
2014/07/01 职场文书
pytorch 实现多个Dataloader同时训练
2021/05/29 Python