TensorFlow实现iris数据集线性回归


Posted in Python onSeptember 07, 2018

本文将遍历批量数据点并让TensorFlow更新斜率和y截距。这次将使用Scikit Learn的内建iris数据集。特别地,我们将用数据点(x值代表花瓣宽度,y值代表花瓣长度)找到最优直线。选择这两种特征是因为它们具有线性关系,在后续结果中将会看到。本文将使用L2正则损失函数。

# 用TensorFlow实现线性回归算法
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve linear regression.
# y = Ax + b
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Petal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# 批量大小
batch_size = 25

# Initialize 占位符
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 模型变量
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加线性模型,y=Ax+b
model_output = tf.add(tf.matmul(x_data, A), b)

# 声明L2损失函数,其为批量损失的平均值。
loss = tf.reduce_mean(tf.square(y_target - model_output))

# 声明优化器 学习率设为0.05
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(loss)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 批量训练遍历迭代
# 迭代100次,每25次迭代输出变量值和损失值
loss_vec = []
for i in range(100):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))

# 抽取系数
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# 创建最佳拟合直线
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)

# 绘制两幅图
# 拟合的直线
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
# 迭代100次的L2正则损失函数
plt.plot(loss_vec, 'k-')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.show()

结果:

Step #25 A = [[ 1.93474162]] b = [[ 3.11190438]]
Loss = 1.21364
Step #50 A = [[ 1.48641717]] b = [[ 3.81004381]]
Loss = 0.945256
Step #75 A = [[ 1.26089203]] b = [[ 4.221035]]
Loss = 0.254756
Step #100 A = [[ 1.1693294]] b = [[ 4.47258472]]
Loss = 0.281654

TensorFlow实现iris数据集线性回归

TensorFlow实现iris数据集线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之有容乃大的list(3)
Sep 15 Python
老生常谈Python startswith()函数与endswith函数
Sep 08 Python
Numpy 将二维图像矩阵转换为一维向量的方法
Jun 05 Python
Python3.7 新特性之dataclass装饰器
May 27 Python
浅谈python图片处理Image和skimage的区别
Aug 04 Python
Python判断字符串是否xx开始或结尾的示例
Aug 08 Python
python中树与树的表示知识点总结
Sep 14 Python
python使用信号量动态更新配置文件的操作
Apr 01 Python
基于Python的OCR实现示例
Apr 03 Python
Pycharm安装并配置jupyter notebook的实现
May 18 Python
python 合并多个excel中同名的sheet
Jan 22 Python
Python利器openpyxl之操作excel表格
Apr 17 Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
You might like
菜鸟学PHP之Smarty入门
2007/01/04 PHP
PHP实现使用DOM将XML数据存入数组的方法示例
2017/09/27 PHP
Javascript document.referrer判断访客来源网址
2020/05/15 Javascript
关于juqery radio写法的兼容性问题(新老版本jquery)
2010/06/14 Javascript
JavaScript中判断函数是new还是()调用的区别说明
2011/04/07 Javascript
jQuery点击后一组图片左右滑动的实现代码
2012/08/16 Javascript
jquery怎样实现ajax联动框(一)
2013/03/08 Javascript
jQuery 遍历- 关于closest() 的方法介绍以及与parents()的方法区别分析
2013/04/26 Javascript
js 剪切板应用clipboardData详细解析
2013/12/17 Javascript
在javascript中执行任意html代码的方法示例解读
2013/12/25 Javascript
Nodejs极简入门教程(三):进程
2014/10/27 NodeJs
Javascript 计算字符串在localStorage中所占字节数
2015/10/21 Javascript
JavaScript直播评论发弹幕切图功能点集合效果代码
2016/06/26 Javascript
详解BootStrap中Affix控件的使用及保持布局的美观的方法
2016/07/08 Javascript
jQuery的extend方法【三种】
2016/12/14 Javascript
如何获取vue单文件自身源码路径
2019/05/06 Javascript
解决layui表格的表头不滚动的问题
2019/09/04 Javascript
js实现课堂随机点名系统
2019/11/21 Javascript
小程序跳转H5页面的方法步骤
2020/03/06 Javascript
微信小程序用canvas画图并分享
2020/03/09 Javascript
vue-router重写push方法,解决相同路径跳转报错问题
2020/08/07 Javascript
[01:18]PWL开团时刻DAY4——圣剑与抢盾
2020/11/03 DOTA
Python Web框架Flask中使用百度云存储BCS实例
2015/02/08 Python
python显示生日是星期几的方法
2015/05/27 Python
VTK与Python实现机械臂三维模型可视化详解
2017/12/13 Python
Python学习笔记之错误和异常及访问错误消息详解
2019/08/08 Python
使用OpenCV实现仿射变换—旋转功能
2019/08/29 Python
Django ORM判断查询结果是否为空,判断django中的orm为空实例
2020/07/09 Python
瑞贝卡·明可弗包包官网:Rebecca Minkoff
2016/07/21 全球购物
编写用C语言实现的求n阶阶乘问题的递归算法
2014/10/21 面试题
Sony C++笔试题
2013/03/10 面试题
安全生产汇报材料
2014/02/17 职场文书
2015年医务人员医德医风自我评价
2015/03/03 职场文书
卫生院艾滋病宣传活动总结
2015/05/09 职场文书
单位接收证明格式
2015/06/18 职场文书
励志语录:你若不勇敢,谁替你坚强
2019/11/08 职场文书