TensorFlow实现iris数据集线性回归


Posted in Python onSeptember 07, 2018

本文将遍历批量数据点并让TensorFlow更新斜率和y截距。这次将使用Scikit Learn的内建iris数据集。特别地,我们将用数据点(x值代表花瓣宽度,y值代表花瓣长度)找到最优直线。选择这两种特征是因为它们具有线性关系,在后续结果中将会看到。本文将使用L2正则损失函数。

# 用TensorFlow实现线性回归算法
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve linear regression.
# y = Ax + b
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Petal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# 批量大小
batch_size = 25

# Initialize 占位符
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 模型变量
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加线性模型,y=Ax+b
model_output = tf.add(tf.matmul(x_data, A), b)

# 声明L2损失函数,其为批量损失的平均值。
loss = tf.reduce_mean(tf.square(y_target - model_output))

# 声明优化器 学习率设为0.05
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(loss)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 批量训练遍历迭代
# 迭代100次,每25次迭代输出变量值和损失值
loss_vec = []
for i in range(100):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))

# 抽取系数
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# 创建最佳拟合直线
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)

# 绘制两幅图
# 拟合的直线
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
# 迭代100次的L2正则损失函数
plt.plot(loss_vec, 'k-')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.show()

结果:

Step #25 A = [[ 1.93474162]] b = [[ 3.11190438]]
Loss = 1.21364
Step #50 A = [[ 1.48641717]] b = [[ 3.81004381]]
Loss = 0.945256
Step #75 A = [[ 1.26089203]] b = [[ 4.221035]]
Loss = 0.254756
Step #100 A = [[ 1.1693294]] b = [[ 4.47258472]]
Loss = 0.281654

TensorFlow实现iris数据集线性回归

TensorFlow实现iris数据集线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python help()函数用法详解
Mar 11 Python
Python深入学习之装饰器
Aug 31 Python
python中遍历文件的3个方法
Sep 02 Python
python获取当前日期和时间的方法
Apr 30 Python
python中的错误处理
Apr 10 Python
python3实现暴力穷举博客园密码
Jun 19 Python
利用aardio给python编写图形界面
Aug 21 Python
Python:Numpy 求平均向量的实例
Jun 29 Python
python实现简单银行管理系统
Oct 25 Python
Python实现序列化及csv文件读取
Jan 19 Python
Python如何对齐字符串
Jul 30 Python
python+pygame实现坦克大战小游戏的示例代码(可以自定义子弹速度)
Aug 11 Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
You might like
控制PHP的输出:缓存并压缩动态页面
2013/06/11 PHP
ThinkPHP调用common/common.php函数提示错误function undefined的解决方法
2014/08/25 PHP
php实现通用的从数据库表读取数据到数组的函数实例
2015/03/21 PHP
输入密码检测大写是否锁定js实现代码
2012/12/03 Javascript
jquery easyui滚动条部分设置介绍
2013/09/12 Javascript
简单常用的幻灯片播放实现代码
2013/09/25 Javascript
jquery获取radio值实例
2014/10/16 Javascript
jQuery实现点击该行即可删除HTML表格行
2014/10/17 Javascript
javascript实现控制的多级下拉菜单
2015/07/05 Javascript
jQuery和hwSlider实现内容响应式可触控滑动切换效果附源码下载(二)
2016/06/22 Javascript
NodeJs的优势和适合开发的程序
2016/08/14 NodeJs
AngularJS 过滤与排序详解及实例代码
2016/09/14 Javascript
自己封装的一个原生JS拖动方法(推荐)
2016/11/22 Javascript
vue2.0全局组件之pdf详解
2017/06/26 Javascript
详解使用webpack打包编写一个vue-toast插件
2017/11/08 Javascript
node.js读取Excel数据(下载图片)的方法示例
2018/08/02 Javascript
jQuery实现的简单日历组件定义与用法示例
2018/12/24 jQuery
原生js实现弹幕效果
2020/11/29 Javascript
python日志记录模块实例及改进
2017/02/12 Python
使用Python Pandas处理亿级数据的方法
2019/06/24 Python
python 申请内存空间,用于创建多维数组的实例
2019/12/02 Python
django为Form生成的label标签添加class方式
2020/05/20 Python
django 模型字段设置默认值代码
2020/07/15 Python
安装Anaconda3及使用Jupyter的方法
2020/10/27 Python
解决Python import .pyd 可能遇到路径的问题
2021/03/04 Python
用css3实现转换过渡和动画效果
2020/03/13 HTML / CSS
加拿大最大的书店:Indigo
2017/01/01 全球购物
thinkphp5 redis缓存新增方法实例讲解
2021/03/24 PHP
《太阳》教学反思
2014/02/21 职场文书
网络营销计划
2015/01/17 职场文书
员工担保书范本
2015/09/22 职场文书
2016感恩父亲节主题广播稿
2015/12/18 职场文书
阳光体育运动标语口号
2015/12/26 职场文书
浅谈Python 中的复数问题
2021/05/19 Python
JavaScript实现班级抽签小程序
2021/05/19 Javascript
关于Oracle12C默认用户名system密码不正确的解决方案
2021/10/16 Oracle