TensorFlow实现iris数据集线性回归


Posted in Python onSeptember 07, 2018

本文将遍历批量数据点并让TensorFlow更新斜率和y截距。这次将使用Scikit Learn的内建iris数据集。特别地,我们将用数据点(x值代表花瓣宽度,y值代表花瓣长度)找到最优直线。选择这两种特征是因为它们具有线性关系,在后续结果中将会看到。本文将使用L2正则损失函数。

# 用TensorFlow实现线性回归算法
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve linear regression.
# y = Ax + b
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Petal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# 批量大小
batch_size = 25

# Initialize 占位符
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 模型变量
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加线性模型,y=Ax+b
model_output = tf.add(tf.matmul(x_data, A), b)

# 声明L2损失函数,其为批量损失的平均值。
loss = tf.reduce_mean(tf.square(y_target - model_output))

# 声明优化器 学习率设为0.05
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(loss)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 批量训练遍历迭代
# 迭代100次,每25次迭代输出变量值和损失值
loss_vec = []
for i in range(100):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))

# 抽取系数
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# 创建最佳拟合直线
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)

# 绘制两幅图
# 拟合的直线
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
# 迭代100次的L2正则损失函数
plt.plot(loss_vec, 'k-')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.show()

结果:

Step #25 A = [[ 1.93474162]] b = [[ 3.11190438]]
Loss = 1.21364
Step #50 A = [[ 1.48641717]] b = [[ 3.81004381]]
Loss = 0.945256
Step #75 A = [[ 1.26089203]] b = [[ 4.221035]]
Loss = 0.254756
Step #100 A = [[ 1.1693294]] b = [[ 4.47258472]]
Loss = 0.281654

TensorFlow实现iris数据集线性回归

TensorFlow实现iris数据集线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python的单元测试
Apr 28 Python
python实现数值积分的Simpson方法实例分析
Jun 05 Python
Python3 XML 获取雅虎天气的实现方法
Feb 01 Python
Python读取Word(.docx)正文信息的方法
Mar 15 Python
对python opencv 添加文字 cv2.putText 的各参数介绍
Dec 05 Python
python将list转为matrix的方法
Dec 12 Python
Python使用post及get方式提交数据的实例
Jan 24 Python
python函数的万能参数传参详解
Jul 26 Python
在Python IDLE 下调用anaconda中的库教程
Mar 09 Python
基于PyTorch的permute和reshape/view的区别介绍
Jun 18 Python
python爬不同图片分别保存在不同文件夹中的实现
Apr 02 Python
Python实现科学占卜 让视频自动打码
Apr 09 Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
You might like
Php output buffering缓存及程序缓存深入解析
2013/07/15 PHP
php简单备份与还原MySql的方法
2016/05/09 PHP
Javascript this关键字使用分析
2008/10/21 Javascript
jQuery源码分析-04 选择器-Sizzle-工作原理分析
2011/11/14 Javascript
不同的jQuery API来处理不同的浏览器事件
2012/12/09 Javascript
js的alert样式如何更改如背景颜色
2014/01/22 Javascript
jQuery学习笔记之总体架构
2014/06/03 Javascript
JS+DIV+CSS实现的经典标签切换效果代码
2015/09/14 Javascript
JS日期格式化之javascript Date format
2015/10/01 Javascript
js实现无缝循环滚动
2020/06/23 Javascript
JS实现页面进入和返回定位到具体位置
2016/12/08 Javascript
Bootstrap CSS布局之图像
2016/12/17 Javascript
微信小程序 新建登录页并实现tabBar隐藏
2017/06/13 Javascript
使用Angular-CLI构建NPM包的方法
2018/09/07 Javascript
vue element table 表格请求后台排序的方法
2018/09/28 Javascript
微信小程序实现弹出菜单动画
2019/06/21 Javascript
javascript执行上下文、变量对象实例分析
2020/04/25 Javascript
vue路由权限校验功能的实现代码
2020/06/07 Javascript
nodeJs项目在阿里云的简单部署
2020/11/27 NodeJs
python使用PythonMagick将jpg图片转换成ico图片的方法
2015/03/26 Python
python按照多个条件排序的方法
2019/02/08 Python
pyqt5 禁止窗口最大化和禁止窗口拉伸的方法
2019/06/18 Python
Python3显示当前时间、计算时间差及时间加减法示例代码
2019/09/07 Python
用Python做一个久坐提醒小助手的示例代码
2020/02/10 Python
python GUI库图形界面开发之PyQt5开发环境配置与基础使用
2020/02/25 Python
Pythonic版二分查找实现过程原理解析
2020/08/11 Python
你可能不熟练的十个前端HTML5经典面试题
2018/07/03 HTML / CSS
关于运动会的稿件
2014/02/02 职场文书
捐献物资倡议书范文
2014/05/19 职场文书
学校督导评估方案
2014/06/10 职场文书
公司户外活动总结
2014/07/04 职场文书
办理信用卡工作证明
2014/09/30 职场文书
我们的节日中秋节活动总结
2015/03/23 职场文书
幼儿园端午节活动总结
2015/05/05 职场文书
海上钢琴师观后感
2015/06/03 职场文书
2015年小学实验室工作总结
2015/07/28 职场文书