TensorFlow实现创建分类器


Posted in Python onFebruary 06, 2018

本文实例为大家分享了TensorFlow实现创建分类器的具体代码,供大家参考,具体内容如下

创建一个iris数据集的分类器。

加载样本数据集,实现一个简单的二值分类器来预测一朵花是否为山鸢尾。iris数据集有三类花,但这里仅预测是否是山鸢尾。导入iris数据集和工具库,相应地对原数据集进行转换。

# Combining Everything Together
#----------------------------------
# This file will perform binary classification on the
# iris dataset. We will only predict if a flower is
# I.setosa or not.
#
# We will create a simple binary classifier by creating a line
# and running everything through a sigmoid to get a binary predictor.
# The two features we will use are pedal length and pedal width.
#
# We will use batch training, but this can be easily
# adapted to stochastic training.

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 导入iris数据集
# 根据目标数据是否为山鸢尾将其转换成1或者0。
# 由于iris数据集将山鸢尾标记为0,我们将其从0置为1,同时把其他物种标记为0。
# 本次训练只使用两种特征:花瓣长度和花瓣宽度,这两个特征在x-value的第三列和第四列
# iris.target = {0, 1, 2}, where '0' is setosa
# iris.data ~ [sepal.width, sepal.length, pedal.width, pedal.length]
iris = datasets.load_iris()
binary_target = np.array([1. if x==0 else 0. for x in iris.target])
iris_2d = np.array([[x[2], x[3]] for x in iris.data])

# 声明批量训练大小
batch_size = 20

# 初始化计算图
sess = tf.Session()

# 声明数据占位符
x1_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
x2_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 声明模型变量
# Create variables A and b (0 = x1 - A*x2 + b)
A = tf.Variable(tf.random_normal(shape=[1, 1]))
b = tf.Variable(tf.random_normal(shape=[1, 1]))

# 定义线性模型:
# 如果找到的数据点在直线以上,则将数据点代入x2-x1*A-b计算出的结果大于0;
# 同理找到的数据点在直线以下,则将数据点代入x2-x1*A-b计算出的结果小于0。
# x1 - A*x2 + b
my_mult = tf.matmul(x2_data, A)
my_add = tf.add(my_mult, b)
my_output = tf.subtract(x1_data, my_add)

# 增加TensorFlow的sigmoid交叉熵损失函数(cross entropy)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output, labels=y_target)

# 声明优化器方法
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)

# 创建一个变量初始化操作
init = tf.global_variables_initializer()
sess.run(init)

# 运行迭代1000次
for i in range(1000):
  rand_index = np.random.choice(len(iris_2d), size=batch_size)
  # rand_x = np.transpose([iris_2d[rand_index]])
  # 传入三种数据:花瓣长度、花瓣宽度和目标变量
  rand_x = iris_2d[rand_index]
  rand_x1 = np.array([[x[0]] for x in rand_x])
  rand_x2 = np.array([[x[1]] for x in rand_x])
  #rand_y = np.transpose([binary_target[rand_index]])
  rand_y = np.array([[y] for y in binary_target[rand_index]])
  sess.run(train_step, feed_dict={x1_data: rand_x1, x2_data: rand_x2, y_target: rand_y})
  if (i+1)%200==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ', b = ' + str(sess.run(b)))


# 绘图
# 获取斜率/截距
# Pull out slope/intercept
[[slope]] = sess.run(A)
[[intercept]] = sess.run(b)

# 创建拟合线
x = np.linspace(0, 3, num=50)
ablineValues = []
for i in x:
 ablineValues.append(slope*i+intercept)

# 绘制拟合曲线
setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==1]
setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==1]
non_setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==0]
non_setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==0]
plt.plot(setosa_x, setosa_y, 'rx', ms=10, mew=2, label='setosa')
plt.plot(non_setosa_x, non_setosa_y, 'ro', label='Non-setosa')
plt.plot(x, ablineValues, 'b-')
plt.xlim([0.0, 2.7])
plt.ylim([0.0, 7.1])
plt.suptitle('Linear Separator For I.setosa', fontsize=20)
plt.xlabel('Petal Length')
plt.ylabel('Petal Width')
plt.legend(loc='lower right')
plt.show()

输出:

Step #200 A = [[ 8.70572948]], b = [[-3.46638322]]
Step #400 A = [[ 10.21302414]], b = [[-4.720438]]
Step #600 A = [[ 11.11844635]], b = [[-5.53361702]]
Step #800 A = [[ 11.86427212]], b = [[-6.0110755]]
Step #1000 A = [[ 12.49524498]], b = [[-6.29990339]]

TensorFlow实现创建分类器

 以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python定时采集摄像头图像上传ftp服务器功能实现
Dec 23 Python
python实现simhash算法实例
Apr 25 Python
Python中的日期时间处理详解
Nov 17 Python
Python小游戏之300行代码实现俄罗斯方块
Jan 04 Python
python TF-IDF算法实现文本关键词提取
May 29 Python
python覆盖写入,追加写入的实例
Jun 26 Python
Python shelve模块实现解析
Aug 28 Python
pymysql模块的使用(增删改查)详解
Sep 09 Python
pymysql 开启调试模式的实现
Sep 24 Python
在pycharm中实现删除bookmark
Feb 14 Python
自定义Django默认的sitemap站点地图样式
Mar 04 Python
python中tab键是什么意思
Jun 18 Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
Python yield与实现方法代码分析
Feb 06 #Python
Django中间件工作流程及写法实例代码
Feb 06 #Python
You might like
微博短链接算法php版本实现代码
2012/09/15 PHP
调整PHP的性能
2013/10/30 PHP
PHP编程文件处理类SplFileObject和SplFileInfo用法实例分析
2017/07/22 PHP
Thinkphp 框架扩展之行为扩展原理与实现方法分析
2020/04/23 PHP
Prototype 学习 工具函数学习($方法)
2009/07/12 Javascript
jQuery each()方法的使用方法
2010/03/18 Javascript
jquerymobile checkbox及时刷新才能获取其准确值
2012/04/14 Javascript
js操纵dom生成下拉列表框的方法
2014/02/24 Javascript
js截取中英文字符串、标点符号无乱码示例解读
2014/04/17 Javascript
JavaScript中获取鼠标位置相关属性总结
2014/10/11 Javascript
原生JavaScript编写俄罗斯方块
2015/03/30 Javascript
jQuery实现横向带缓冲的水平运动效果(附demo源码下载)
2016/01/29 Javascript
基于Node.js的JavaScript项目构建工具gulp的使用教程
2016/05/20 Javascript
jQuery表单元素选择器代码实例
2017/02/06 Javascript
Angular使用 ng-img-max 调整浏览器中的图片的示例代码
2017/08/17 Javascript
jQuery实现鼠标滑过商品小图片上显示对应大图片功能【测试可用】
2018/04/27 jQuery
Vue2实时监听表单变化的示例讲解
2018/08/30 Javascript
小程序:授权、登录、session_key、unionId的详解
2019/05/15 Javascript
深入学习JavaScript 高阶函数
2019/06/11 Javascript
vue开发简单上传图片功能
2020/06/30 Javascript
在Python 不同级目录之间模块的调用方法
2019/01/19 Python
由面试题加深对Django的认识理解
2019/07/19 Python
解决Python对齐文本字符串问题
2019/08/28 Python
PYTHON如何读取和写入EXCEL里面的数据
2019/10/28 Python
django中间键重定向实例方法
2019/11/10 Python
python实现图片二值化及灰度处理方式
2019/12/07 Python
Pycharm pyuic5实现将ui文件转为py文件,让UI界面成功显示
2020/04/08 Python
CSS3制作3D立方体loading特效
2020/11/09 HTML / CSS
美国轻奢时尚购物网站:REVOLVE(支持中文)
2020/07/18 全球购物
幼儿园感恩节活动方案
2014/10/06 职场文书
入股协议书范本
2014/11/01 职场文书
网站文案策划岗位职责
2015/04/14 职场文书
遗嘱格式范本
2015/08/07 职场文书
Python如何快速找到多个字典中的公共键(key)
2022/04/29 Python
PostgreSQL常用字符串分割函数整理汇总
2022/07/07 PostgreSQL
Python中np.random.randint()参数详解及用法实例
2022/09/23 Python