TensorFlow实现iris数据集线性回归


Posted in Python onSeptember 07, 2018

本文将遍历批量数据点并让TensorFlow更新斜率和y截距。这次将使用Scikit Learn的内建iris数据集。特别地,我们将用数据点(x值代表花瓣宽度,y值代表花瓣长度)找到最优直线。选择这两种特征是因为它们具有线性关系,在后续结果中将会看到。本文将使用L2正则损失函数。

# 用TensorFlow实现线性回归算法
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve linear regression.
# y = Ax + b
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Petal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# 批量大小
batch_size = 25

# Initialize 占位符
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 模型变量
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加线性模型,y=Ax+b
model_output = tf.add(tf.matmul(x_data, A), b)

# 声明L2损失函数,其为批量损失的平均值。
loss = tf.reduce_mean(tf.square(y_target - model_output))

# 声明优化器 学习率设为0.05
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(loss)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 批量训练遍历迭代
# 迭代100次,每25次迭代输出变量值和损失值
loss_vec = []
for i in range(100):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))

# 抽取系数
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# 创建最佳拟合直线
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)

# 绘制两幅图
# 拟合的直线
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
# 迭代100次的L2正则损失函数
plt.plot(loss_vec, 'k-')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.show()

结果:

Step #25 A = [[ 1.93474162]] b = [[ 3.11190438]]
Loss = 1.21364
Step #50 A = [[ 1.48641717]] b = [[ 3.81004381]]
Loss = 0.945256
Step #75 A = [[ 1.26089203]] b = [[ 4.221035]]
Loss = 0.254756
Step #100 A = [[ 1.1693294]] b = [[ 4.47258472]]
Loss = 0.281654

TensorFlow实现iris数据集线性回归

TensorFlow实现iris数据集线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简单介绍Python中的RSS处理
Apr 13 Python
深入剖析Python的爬虫框架Scrapy的结构与运作流程
Jan 20 Python
Django的信号机制详解
May 05 Python
Python轻量级ORM框架Peewee访问sqlite数据库的方法详解
Jul 20 Python
python增加矩阵维度的实例讲解
Apr 04 Python
python smtplib模块实现发送邮件带附件sendmail
May 22 Python
基于Django与ajax之间的json传输方法
May 29 Python
Python3实现的爬虫爬取数据并存入mysql数据库操作示例
Jun 06 Python
对pandas写入读取h5文件的方法详解
Dec 28 Python
Python协程操作之gevent(yield阻塞,greenlet),协程实现多任务(有规律的交替协作执行)用法详解
Oct 14 Python
Python调用C语言程序方法解析
Jul 07 Python
python和C++共享内存传输图像的示例
Oct 27 Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
You might like
随机头像PHP版
2006/10/09 PHP
php 禁止页面缓存输出
2009/01/07 PHP
PHP中使用cURL实现Get和Post请求的方法
2013/03/13 PHP
php实现发送微信模板消息的方法
2015/03/07 PHP
Jquery 一次处理多个ajax请求的代码
2011/09/02 Javascript
Jquery模仿Baidu、Google搜索时自动补充搜索结果提示
2013/12/26 Javascript
JSONP跨域的原理解析及其实现介绍
2014/03/22 Javascript
14 个折磨人的 JavaScript 面试题
2016/08/08 Javascript
微信小程序 获取相册照片实例详解
2016/11/16 Javascript
Node.js的特点详解
2017/02/03 Javascript
React-Native 组件之 Modal的使用详解
2017/08/08 Javascript
jQuery实现点击旋转,再点击恢复初始状态动画效果示例
2018/12/11 jQuery
Vue 动态组件与 v-once 指令的实现
2019/02/12 Javascript
vue3+typescript实现图片懒加载插件
2020/10/26 Javascript
JavaScript 生成唯一ID的几种方式
2021/02/19 Javascript
[12:51]71泪洒现场!是DOTA2让经典重现
2014/03/24 DOTA
[50:59]2018DOTA2亚洲邀请赛 4.7 总决赛 LGD vs Mineski第四场
2018/04/10 DOTA
一个超级简单的python web程序
2014/09/11 Python
python统计日志ip访问数的方法
2015/07/06 Python
Python定时任务sched模块用法示例
2018/07/16 Python
利用python修改json文件的value方法
2018/12/31 Python
django model通过字典更新数据实例
2020/04/01 Python
使用opencv识别图像红色区域,并输出红色区域中心点坐标
2020/06/02 Python
美国知名生活购物网站:Goop
2017/11/03 全球购物
个人求职简历的自我评价范文
2013/10/09 职场文书
一份婚庆公司创业计划书
2014/01/11 职场文书
公司经理聘任书
2014/03/29 职场文书
文明礼仪伴我行演讲稿
2014/05/12 职场文书
幼儿园标语大全
2014/06/19 职场文书
房屋买卖授权委托书
2014/09/27 职场文书
公司2014年度工作总结
2014/12/10 职场文书
诚信考试承诺书范文
2015/04/29 职场文书
2015年银行信贷员工作总结
2015/05/19 职场文书
团结主题班会
2015/08/13 职场文书
国产动画《万圣街》日语配音版制作决定!
2022/03/20 国漫
MSSQL基本语法操作
2022/04/11 SQL Server