TensorFlow实现创建分类器


Posted in Python onFebruary 06, 2018

本文实例为大家分享了TensorFlow实现创建分类器的具体代码,供大家参考,具体内容如下

创建一个iris数据集的分类器。

加载样本数据集,实现一个简单的二值分类器来预测一朵花是否为山鸢尾。iris数据集有三类花,但这里仅预测是否是山鸢尾。导入iris数据集和工具库,相应地对原数据集进行转换。

# Combining Everything Together
#----------------------------------
# This file will perform binary classification on the
# iris dataset. We will only predict if a flower is
# I.setosa or not.
#
# We will create a simple binary classifier by creating a line
# and running everything through a sigmoid to get a binary predictor.
# The two features we will use are pedal length and pedal width.
#
# We will use batch training, but this can be easily
# adapted to stochastic training.

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 导入iris数据集
# 根据目标数据是否为山鸢尾将其转换成1或者0。
# 由于iris数据集将山鸢尾标记为0,我们将其从0置为1,同时把其他物种标记为0。
# 本次训练只使用两种特征:花瓣长度和花瓣宽度,这两个特征在x-value的第三列和第四列
# iris.target = {0, 1, 2}, where '0' is setosa
# iris.data ~ [sepal.width, sepal.length, pedal.width, pedal.length]
iris = datasets.load_iris()
binary_target = np.array([1. if x==0 else 0. for x in iris.target])
iris_2d = np.array([[x[2], x[3]] for x in iris.data])

# 声明批量训练大小
batch_size = 20

# 初始化计算图
sess = tf.Session()

# 声明数据占位符
x1_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
x2_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 声明模型变量
# Create variables A and b (0 = x1 - A*x2 + b)
A = tf.Variable(tf.random_normal(shape=[1, 1]))
b = tf.Variable(tf.random_normal(shape=[1, 1]))

# 定义线性模型:
# 如果找到的数据点在直线以上,则将数据点代入x2-x1*A-b计算出的结果大于0;
# 同理找到的数据点在直线以下,则将数据点代入x2-x1*A-b计算出的结果小于0。
# x1 - A*x2 + b
my_mult = tf.matmul(x2_data, A)
my_add = tf.add(my_mult, b)
my_output = tf.subtract(x1_data, my_add)

# 增加TensorFlow的sigmoid交叉熵损失函数(cross entropy)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output, labels=y_target)

# 声明优化器方法
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)

# 创建一个变量初始化操作
init = tf.global_variables_initializer()
sess.run(init)

# 运行迭代1000次
for i in range(1000):
  rand_index = np.random.choice(len(iris_2d), size=batch_size)
  # rand_x = np.transpose([iris_2d[rand_index]])
  # 传入三种数据:花瓣长度、花瓣宽度和目标变量
  rand_x = iris_2d[rand_index]
  rand_x1 = np.array([[x[0]] for x in rand_x])
  rand_x2 = np.array([[x[1]] for x in rand_x])
  #rand_y = np.transpose([binary_target[rand_index]])
  rand_y = np.array([[y] for y in binary_target[rand_index]])
  sess.run(train_step, feed_dict={x1_data: rand_x1, x2_data: rand_x2, y_target: rand_y})
  if (i+1)%200==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ', b = ' + str(sess.run(b)))


# 绘图
# 获取斜率/截距
# Pull out slope/intercept
[[slope]] = sess.run(A)
[[intercept]] = sess.run(b)

# 创建拟合线
x = np.linspace(0, 3, num=50)
ablineValues = []
for i in x:
 ablineValues.append(slope*i+intercept)

# 绘制拟合曲线
setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==1]
setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==1]
non_setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==0]
non_setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==0]
plt.plot(setosa_x, setosa_y, 'rx', ms=10, mew=2, label='setosa')
plt.plot(non_setosa_x, non_setosa_y, 'ro', label='Non-setosa')
plt.plot(x, ablineValues, 'b-')
plt.xlim([0.0, 2.7])
plt.ylim([0.0, 7.1])
plt.suptitle('Linear Separator For I.setosa', fontsize=20)
plt.xlabel('Petal Length')
plt.ylabel('Petal Width')
plt.legend(loc='lower right')
plt.show()

输出:

Step #200 A = [[ 8.70572948]], b = [[-3.46638322]]
Step #400 A = [[ 10.21302414]], b = [[-4.720438]]
Step #600 A = [[ 11.11844635]], b = [[-5.53361702]]
Step #800 A = [[ 11.86427212]], b = [[-6.0110755]]
Step #1000 A = [[ 12.49524498]], b = [[-6.29990339]]

TensorFlow实现创建分类器

 以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的批量远程管理和部署工具Fabric用法实例
Jan 23 Python
Python中functools模块的常用函数解析
Jun 30 Python
在pycharm中设置显示行数的方法
Jan 16 Python
python中@property和property函数常见使用方法示例
Oct 21 Python
tensorflow2.0与tensorflow1.0的性能区别介绍
Feb 07 Python
django model 条件过滤 queryset.filter(**condtions)用法详解
May 20 Python
Python爬虫HTPP请求方法有哪些
Jun 03 Python
python读写数据读写csv文件(pandas用法)
Dec 14 Python
python压包的概念及实例详解
Feb 17 Python
Python爬虫之爬取最新更新的小说网站
May 06 Python
利用Python判断整数是否是回文数的3种方法总结
Jul 07 Python
Python函数式编程中itertools模块详解
Sep 15 Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
Python yield与实现方法代码分析
Feb 06 #Python
Django中间件工作流程及写法实例代码
Feb 06 #Python
You might like
落伍首发 php+mysql 采用ajax技术的 省 市 地 3级联动无刷新菜单 源码
2006/12/16 PHP
PHP 设置MySQL连接字符集的方法
2011/01/02 PHP
PHP把空格、换行符、中文逗号等替换成英文逗号的正则表达式
2014/05/04 PHP
php正则表达式获取内容所有链接
2015/07/24 PHP
php实现json编码的方法
2015/07/30 PHP
程序员的表白神器“520”大声喊出来
2016/05/20 PHP
JS实多级联动下拉菜单类,简单实现省市区联动菜单!
2007/05/03 Javascript
javascript 变量作用域 代码分析
2009/06/26 Javascript
JavaScript学习笔记之创建对象
2016/03/25 Javascript
浅谈JavaScript中变量和函数声明的提升
2016/08/09 Javascript
浅谈jquery选择器 :first与:first-child的区别
2016/11/20 Javascript
JS实现物体带缓冲的间歇运动效果示例
2016/12/22 Javascript
利用PM2部署node.js项目的方法教程
2017/05/10 Javascript
JS实现按钮控制计时开始和停止功能
2017/07/27 Javascript
判断滚动条滑到底部触发事件(实例讲解)
2017/11/15 Javascript
vue中如何实现后台管理系统的权限控制的方法示例
2018/09/19 Javascript
vue-cli 3 全局过滤器的实例代码详解
2019/06/03 Javascript
Vue混入mixins滚动触底的方法
2019/11/22 Javascript
toString.call()通用的判断数据类型方法示例
2020/08/28 Javascript
vue 公共列表选择组件,引用Vant-UI的样式方式
2020/11/02 Javascript
python 正则表达式 概述及常用字符
2009/05/04 Python
详解Python的Django框架中inclusion_tag的使用
2015/07/21 Python
深入理解Python3中的http.client模块
2017/03/29 Python
python 寻找离散序列极值点的方法
2019/07/10 Python
python 设置xlabel,ylabel 坐标轴字体大小,字体类型
2019/07/23 Python
加拿大床上用品、家居装饰、厨房和浴室产品购物网站:Linen Chest
2018/06/05 全球购物
恐龙的灭绝教学反思
2014/02/12 职场文书
2014年党风廉政建设工作总结
2014/11/19 职场文书
健康状况证明书
2014/11/26 职场文书
2015年妇女工作总结
2015/05/14 职场文书
辞职申请书范本
2019/05/20 职场文书
Django展示可视化图表的多种方式
2021/04/08 Python
Java日常练习题,每天进步一点点(38)
2021/07/26 Java/Android
MySQL中order by的使用详情
2021/11/17 MySQL
python脚本框架webpy模板控制结构
2021/11/20 Python
使用Cargo工具高效创建Rust项目
2022/08/14 Javascript