TensorFlow实现iris数据集线性回归


Posted in Python onSeptember 07, 2018

本文将遍历批量数据点并让TensorFlow更新斜率和y截距。这次将使用Scikit Learn的内建iris数据集。特别地,我们将用数据点(x值代表花瓣宽度,y值代表花瓣长度)找到最优直线。选择这两种特征是因为它们具有线性关系,在后续结果中将会看到。本文将使用L2正则损失函数。

# 用TensorFlow实现线性回归算法
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve linear regression.
# y = Ax + b
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Petal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# 批量大小
batch_size = 25

# Initialize 占位符
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 模型变量
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加线性模型,y=Ax+b
model_output = tf.add(tf.matmul(x_data, A), b)

# 声明L2损失函数,其为批量损失的平均值。
loss = tf.reduce_mean(tf.square(y_target - model_output))

# 声明优化器 学习率设为0.05
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(loss)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 批量训练遍历迭代
# 迭代100次,每25次迭代输出变量值和损失值
loss_vec = []
for i in range(100):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))

# 抽取系数
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# 创建最佳拟合直线
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)

# 绘制两幅图
# 拟合的直线
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
# 迭代100次的L2正则损失函数
plt.plot(loss_vec, 'k-')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.show()

结果:

Step #25 A = [[ 1.93474162]] b = [[ 3.11190438]]
Loss = 1.21364
Step #50 A = [[ 1.48641717]] b = [[ 3.81004381]]
Loss = 0.945256
Step #75 A = [[ 1.26089203]] b = [[ 4.221035]]
Loss = 0.254756
Step #100 A = [[ 1.1693294]] b = [[ 4.47258472]]
Loss = 0.281654

TensorFlow实现iris数据集线性回归

TensorFlow实现iris数据集线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python动态加载模块的3种方法
Nov 22 Python
JSONLINT:python的json数据验证库实例解析
Nov 28 Python
python如何定义带参数的装饰器
Mar 20 Python
浅谈python的深浅拷贝以及fromkeys的用法
Mar 08 Python
python实现两个dict合并与计算操作示例
Jul 01 Python
windows下python虚拟环境virtualenv安装和使用详解
Jul 16 Python
Python抓新型冠状病毒肺炎疫情数据并绘制全国疫情分布的代码实例
Feb 05 Python
python读取当前目录下的CSV文件数据
Mar 11 Python
keras-siamese用自己的数据集实现详解
Jun 10 Python
PyTorch的torch.cat用法
Jun 28 Python
利用Opencv实现图片的油画特效实例
Feb 28 Python
Python函数中apply、map、applymap的区别
Nov 27 Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
You might like
php模板原理讲解
2013/11/13 PHP
ThinkPHP之A方法实例讲解
2014/06/20 PHP
PHP实现生成透明背景的PNG缩略图函数分享
2014/07/08 PHP
PHP进阶学习之Geo的地图定位算法详解
2019/06/19 PHP
关于__defineGetter__ 和__defineSetter__的说明
2007/05/12 Javascript
可兼容IE的获取及设置cookie的jquery.cookie函数方法
2013/09/02 Javascript
js实现文章文字大小字号功能完整实例
2014/11/01 Javascript
JavaScript输出当前时间Unix时间戳的方法
2015/04/06 Javascript
jquery实现适用于门户站的导航下拉菜单效果代码
2015/08/24 Javascript
JavaScript 性能优化小结
2015/10/12 Javascript
使用vue.js2.0 + ElementUI开发后台管理系统详细教程(一)
2017/01/21 Javascript
JavaScript无操作后屏保功能的实现方法
2017/07/04 Javascript
jQuery条件分页 代替离线查询(附代码)
2017/08/17 jQuery
Vue 2.5.2下axios + express 本地请求404的解决方法
2018/02/21 Javascript
Vue 页面跳转不用router-link的实现代码
2018/04/12 Javascript
bootstrap 路径导航 分页 进度条的实例代码
2018/08/06 Javascript
js实现文件上传功能 后台使用MultipartFile
2018/09/08 Javascript
原生JS实现前端本地文件上传
2018/09/08 Javascript
移动端JS实现拖拽两种方法解析
2020/10/12 Javascript
通过源码分析Python中的切片赋值
2017/05/08 Python
详解python算法之冒泡排序
2019/03/05 Python
python2.7使用plotly绘制本地散点图和折线图
2019/04/02 Python
python爬虫神器Pyppeteer入门及使用
2019/07/13 Python
python 使用pygame工具包实现贪吃蛇游戏(多彩版)
2019/10/30 Python
Python基础之字典常见操作经典实例详解
2020/02/26 Python
python 使用递归回溯完美解决八皇后的问题
2020/02/26 Python
Python matplotlib图例放在外侧保存时显示不完整问题解决
2020/07/28 Python
python mock测试的示例
2020/10/19 Python
详解Java中一维、二维数组在内存中的结构
2021/02/11 Python
python爬虫scrapy基于CrawlSpider类的全站数据爬取示例解析
2021/02/20 Python
非常详细的C#面试题集
2016/07/13 面试题
个人函授自我鉴定
2014/03/25 职场文书
先进个人总结范文
2015/02/15 职场文书
Nginx tp3.2.3 404问题解决方案
2021/03/31 Servers
如何搭建 MySQL 高可用高性能集群
2021/06/21 MySQL
Python中np.random.randint()参数详解及用法实例
2022/09/23 Python