TensorFlow实现iris数据集线性回归


Posted in Python onSeptember 07, 2018

本文将遍历批量数据点并让TensorFlow更新斜率和y截距。这次将使用Scikit Learn的内建iris数据集。特别地,我们将用数据点(x值代表花瓣宽度,y值代表花瓣长度)找到最优直线。选择这两种特征是因为它们具有线性关系,在后续结果中将会看到。本文将使用L2正则损失函数。

# 用TensorFlow实现线性回归算法
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve linear regression.
# y = Ax + b
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Petal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# 批量大小
batch_size = 25

# Initialize 占位符
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 模型变量
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加线性模型,y=Ax+b
model_output = tf.add(tf.matmul(x_data, A), b)

# 声明L2损失函数,其为批量损失的平均值。
loss = tf.reduce_mean(tf.square(y_target - model_output))

# 声明优化器 学习率设为0.05
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(loss)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 批量训练遍历迭代
# 迭代100次,每25次迭代输出变量值和损失值
loss_vec = []
for i in range(100):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))

# 抽取系数
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# 创建最佳拟合直线
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)

# 绘制两幅图
# 拟合的直线
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
# 迭代100次的L2正则损失函数
plt.plot(loss_vec, 'k-')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.show()

结果:

Step #25 A = [[ 1.93474162]] b = [[ 3.11190438]]
Loss = 1.21364
Step #50 A = [[ 1.48641717]] b = [[ 3.81004381]]
Loss = 0.945256
Step #75 A = [[ 1.26089203]] b = [[ 4.221035]]
Loss = 0.254756
Step #100 A = [[ 1.1693294]] b = [[ 4.47258472]]
Loss = 0.281654

TensorFlow实现iris数据集线性回归

TensorFlow实现iris数据集线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python matplotlib通过plt.scatter画空心圆标记出特定的点方法
Dec 13 Python
在Pycharm terminal中字体大小设置的方法
Jan 16 Python
python装饰器常见使用方法分析
Jun 26 Python
Django中如何使用sass的方法步骤
Jul 09 Python
Django框架之DRF 基于mixins来封装的视图详解
Jul 23 Python
python函数的作用域及关键字详解
Aug 20 Python
python 初始化一个定长的数组实例
Dec 02 Python
PyQt5+Pycharm安装和配置图文教程详解
Mar 24 Python
如何解决cmd运行python提示不是内部命令
Jul 01 Python
降低python版本的操作方法
Sep 11 Python
python判断all函数输出结果是否为true的方法
Dec 03 Python
python实现三种随机请求头方式
Jan 05 Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
You might like
输出控制类
2006/10/09 PHP
几个学习PHP的网址
2006/11/25 PHP
php利用cookie实现自动登录的方法
2014/12/10 PHP
js parentElement和offsetParent之间的区别
2010/03/23 Javascript
深入分析JSONP跨域的原理
2014/12/10 Javascript
NodeJS学习笔记之FS文件模块
2015/01/13 NodeJs
JavaScript 监控微信浏览器且自带返回按钮时间
2016/11/27 Javascript
angularjs下拉框空白的解决办法
2017/06/20 Javascript
在React中如何优雅的处理事件响应详解
2017/07/24 Javascript
Node.js学习之查询字符串解析querystring详解
2017/09/28 Javascript
js用类封装pop弹窗组件
2017/10/08 Javascript
vue 不使用select实现下拉框功能(推荐)
2018/05/17 Javascript
js中Generator函数的深入讲解
2019/04/07 Javascript
JavaScript实现的联动菜单特效示例
2019/07/08 Javascript
使用layui 的layedit定义自己的toolbar方法
2019/09/18 Javascript
[05:02]2014DOTA2 TI中国区预选赛精彩TOPPLAY第三弹
2014/06/25 DOTA
Java中重定向输出流实现用文件记录程序日志
2015/06/12 Python
Python随手笔记之标准类型内建函数
2015/12/02 Python
python 时间信息“2018-02-04 18:23:35“ 解析成字典形式的结果代码详解
2018/04/19 Python
python调用c++ ctype list传数组或者返回数组的方法
2019/02/13 Python
django 多数据库及分库实现方式
2020/04/01 Python
python处理写入数据代码讲解
2020/10/22 Python
Python爬虫Scrapy框架CrawlSpider原理及使用案例
2020/11/20 Python
iframe跨域的几种常用方法
2019/11/11 HTML / CSS
春秋航空官方网站:Spring Airlines
2017/09/27 全球购物
Skip Hop官网:好莱坞宝宝挚爱品牌
2018/06/17 全球购物
印尼极简主义和实惠的在线家具店:Fabelio
2019/03/27 全球购物
寒假实习自荐信
2014/01/26 职场文书
数学与统计学院学生个人职业生涯规划书
2014/02/10 职场文书
质量月活动总结
2014/08/26 职场文书
领导干部作风建设总结
2014/10/23 职场文书
2014年销售人员工作总结
2014/11/27 职场文书
2015年社区科普工作总结
2015/05/13 职场文书
办公室主任岗位竞聘书
2015/09/15 职场文书
MySQL的全局锁和表级锁的具体使用
2021/08/23 MySQL
NodeJs使用webpack打包项目的方法详解
2022/02/28 NodeJs