TensorFlow实现iris数据集线性回归


Posted in Python onSeptember 07, 2018

本文将遍历批量数据点并让TensorFlow更新斜率和y截距。这次将使用Scikit Learn的内建iris数据集。特别地,我们将用数据点(x值代表花瓣宽度,y值代表花瓣长度)找到最优直线。选择这两种特征是因为它们具有线性关系,在后续结果中将会看到。本文将使用L2正则损失函数。

# 用TensorFlow实现线性回归算法
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve linear regression.
# y = Ax + b
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Petal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# 批量大小
batch_size = 25

# Initialize 占位符
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 模型变量
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加线性模型,y=Ax+b
model_output = tf.add(tf.matmul(x_data, A), b)

# 声明L2损失函数,其为批量损失的平均值。
loss = tf.reduce_mean(tf.square(y_target - model_output))

# 声明优化器 学习率设为0.05
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(loss)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 批量训练遍历迭代
# 迭代100次,每25次迭代输出变量值和损失值
loss_vec = []
for i in range(100):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))

# 抽取系数
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# 创建最佳拟合直线
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)

# 绘制两幅图
# 拟合的直线
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
# 迭代100次的L2正则损失函数
plt.plot(loss_vec, 'k-')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.show()

结果:

Step #25 A = [[ 1.93474162]] b = [[ 3.11190438]]
Loss = 1.21364
Step #50 A = [[ 1.48641717]] b = [[ 3.81004381]]
Loss = 0.945256
Step #75 A = [[ 1.26089203]] b = [[ 4.221035]]
Loss = 0.254756
Step #100 A = [[ 1.1693294]] b = [[ 4.47258472]]
Loss = 0.281654

TensorFlow实现iris数据集线性回归

TensorFlow实现iris数据集线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中类型检查的详细介绍
Feb 13 Python
python 计算两个日期相差多少个月实例代码
May 24 Python
windows下python之mysqldb模块安装方法
Sep 07 Python
PyCharm设置SSH远程调试的方法
Jul 17 Python
python3.6.3安装图文教程 TensorFlow安装配置方法
Jun 24 Python
numpy中的ndarray方法和属性详解
May 27 Python
详解使用Python下载文件的几种方法
Oct 13 Python
在Django中实现添加user到group并查看
Nov 18 Python
简单了解Python读取大文件代码实例
Dec 18 Python
Python tkinter三种布局实例详解
Jan 06 Python
Python处理PDF与CDF实例
Feb 26 Python
用 Python 制作地球仪的方法
Apr 24 Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
You might like
PHP Google的translate API代码
2008/12/10 PHP
网页游戏开发入门教程二(游戏模式+系统)
2009/11/02 PHP
PHP 第三节 变量介绍
2012/04/28 PHP
PHP处理JSON字符串key缺少双引号的解决方法
2014/09/16 PHP
合格的PHP程序员必备技能
2015/11/13 PHP
基于thinkphp6.0的success、error实现方法
2019/11/05 PHP
js 图片等比例缩放代码
2010/05/13 Javascript
一个关于jqGrid使用的小例子(行按钮)
2011/11/04 Javascript
js替换字符串的所有示例代码
2013/07/23 Javascript
js控制li的隐藏和显示实例代码
2016/10/15 Javascript
微信小程序 判断手机号的实现代码
2017/04/19 Javascript
Angular 4.x 动态创建表单实例
2017/04/25 Javascript
JS实现预加载视频音频/视频获取截图(返回canvas截图)
2017/10/09 Javascript
详解为Bootstrap Modal添加拖拽的方法
2018/01/05 Javascript
Vue实现左右菜单联动实现代码
2018/08/12 Javascript
记一次webapck4 配置文件无效的解决历程
2018/09/19 Javascript
node.js ws模块搭建websocket服务端的方法示例
2019/04/25 Javascript
JS实现数据动态渲染的竖向步骤条
2020/06/24 Javascript
Javascript表单序列化原理及实现代码详解
2020/10/30 Javascript
[03:49]显微镜下的DOTA2第十五期—VG登基之路完美团
2014/06/24 DOTA
跟老齐学Python之有容乃大的list(1)
2014/09/14 Python
Django框架中的对象列表视图使用示例
2015/07/21 Python
python 一维二维插值实例
2020/04/22 Python
matplotlib基础绘图命令之imshow的使用
2020/08/13 Python
纯css3实现思维导图样式示例
2018/11/01 HTML / CSS
Schutz鞋官方网站:Schutz Shoes
2017/12/13 全球购物
美国正宗设计师眼镜在线零售商:EYEZZ
2019/03/23 全球购物
迎新晚会邀请函
2014/02/01 职场文书
班长自荐书范文
2014/02/11 职场文书
艺术设计专业个人求职信
2014/04/10 职场文书
安全教育演讲稿
2014/05/09 职场文书
推荐信模板
2014/05/09 职场文书
投资意向书
2014/07/30 职场文书
2015年世界无烟日活动总结
2015/02/10 职场文书
公司催款律师函
2015/05/27 职场文书
vue-cli3.0修改打包后的文件名和文件地址,打包后本地运行报错解决
2022/04/06 Vue.js