TensorFlow实现iris数据集线性回归


Posted in Python onSeptember 07, 2018

本文将遍历批量数据点并让TensorFlow更新斜率和y截距。这次将使用Scikit Learn的内建iris数据集。特别地,我们将用数据点(x值代表花瓣宽度,y值代表花瓣长度)找到最优直线。选择这两种特征是因为它们具有线性关系,在后续结果中将会看到。本文将使用L2正则损失函数。

# 用TensorFlow实现线性回归算法
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve linear regression.
# y = Ax + b
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Petal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# 批量大小
batch_size = 25

# Initialize 占位符
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 模型变量
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加线性模型,y=Ax+b
model_output = tf.add(tf.matmul(x_data, A), b)

# 声明L2损失函数,其为批量损失的平均值。
loss = tf.reduce_mean(tf.square(y_target - model_output))

# 声明优化器 学习率设为0.05
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(loss)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 批量训练遍历迭代
# 迭代100次,每25次迭代输出变量值和损失值
loss_vec = []
for i in range(100):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))

# 抽取系数
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# 创建最佳拟合直线
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)

# 绘制两幅图
# 拟合的直线
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
# 迭代100次的L2正则损失函数
plt.plot(loss_vec, 'k-')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.show()

结果:

Step #25 A = [[ 1.93474162]] b = [[ 3.11190438]]
Loss = 1.21364
Step #50 A = [[ 1.48641717]] b = [[ 3.81004381]]
Loss = 0.945256
Step #75 A = [[ 1.26089203]] b = [[ 4.221035]]
Loss = 0.254756
Step #100 A = [[ 1.1693294]] b = [[ 4.47258472]]
Loss = 0.281654

TensorFlow实现iris数据集线性回归

TensorFlow实现iris数据集线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 字符串定义
Sep 25 Python
Python的Bottle框架中获取制定cookie的教程
Apr 24 Python
Python增量循环删除MySQL表数据的方法
Sep 23 Python
Python如何抓取天猫商品详细信息及交易记录
Feb 23 Python
Python内置函数reversed()用法分析
Mar 20 Python
详解django中使用定时任务的方法
Sep 27 Python
python 日期排序的实例代码
Jul 11 Python
pytorch中nn.Conv1d的用法详解
Dec 31 Python
python获取依赖包和安装依赖包教程
Feb 13 Python
如何解决安装python3.6.1失败
Jul 01 Python
python smtplib发送多个email联系人的实现
Oct 09 Python
详解Python中的进程和线程
Jun 23 Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
You might like
PHP.MVC的模板标签系统(一)
2006/09/05 PHP
php获取textarea的值并处理回车换行的方法
2014/10/20 PHP
PHP获取数组中重复最多的元素的实现方法
2014/11/11 PHP
PHP Yaf框架的简单安装使用教程(推荐)
2016/06/08 PHP
php基于ob_start(ob_gzhandler)实现网页压缩功能的方法
2017/02/18 PHP
详细对比php中类继承和接口继承
2018/10/11 PHP
PHP设计模式之适配器模式(Adapter)原理与用法详解
2019/12/12 PHP
php操作redis数据库常见方法实例总结
2020/02/20 PHP
发现的以前不知道的函数
2006/09/19 Javascript
基于jquery的获取mouse坐标插件的实现代码
2010/04/01 Javascript
jquery提升性能最佳实践小结
2010/12/06 Javascript
javascript 弹出窗口中是否显示地址栏的实现代码
2011/04/14 Javascript
如何用JavaScript定义一个类
2014/09/12 Javascript
jquery实现简单的表单验证
2015/11/17 Javascript
AngularJS的脏检查深入分析
2017/04/22 Javascript
jquery获取链接地址和跳转详解(推荐)
2017/08/15 jQuery
vue中的scope使用详解
2017/10/29 Javascript
bootstrap3中container与container_fluid外层容器的区别讲解
2017/12/04 Javascript
使用vue-cli创建项目的图文教程(新手入门篇)
2018/05/02 Javascript
JavaScript设计模式之装饰者模式定义与应用示例
2018/07/25 Javascript
angular的输入和输出的使用方法
2018/09/22 Javascript
小程序异步问题之多个网络请求依次执行并依次收集请求结果
2019/05/05 Javascript
vue中实现拖动调整左右两侧div的宽度的示例代码
2020/07/22 Javascript
使用Python的Bottle框架写一个简单的服务接口的示例
2015/08/25 Python
python SQLAlchemy的Mapping与Declarative详解
2019/07/04 Python
python中eval与int的区别浅析
2019/08/11 Python
python机器学习库xgboost的使用
2020/01/20 Python
Python并发请求下限制QPS(每秒查询率)的实现代码
2020/06/05 Python
Python unittest装饰器实现原理及代码
2020/09/08 Python
学校办公室主任职责
2013/12/27 职场文书
读书演讲主持词
2014/03/18 职场文书
婚纱摄影师求职信范文
2014/04/17 职场文书
中文专业求职信
2014/06/20 职场文书
2015年班主任德育工作总结
2015/05/21 职场文书
分享几个JavaScript运算符的使用技巧
2021/04/24 Javascript
CentOS7和8下安装Maven3.8.4
2022/04/07 Servers