TensorFlow实现创建分类器


Posted in Python onFebruary 06, 2018

本文实例为大家分享了TensorFlow实现创建分类器的具体代码,供大家参考,具体内容如下

创建一个iris数据集的分类器。

加载样本数据集,实现一个简单的二值分类器来预测一朵花是否为山鸢尾。iris数据集有三类花,但这里仅预测是否是山鸢尾。导入iris数据集和工具库,相应地对原数据集进行转换。

# Combining Everything Together
#----------------------------------
# This file will perform binary classification on the
# iris dataset. We will only predict if a flower is
# I.setosa or not.
#
# We will create a simple binary classifier by creating a line
# and running everything through a sigmoid to get a binary predictor.
# The two features we will use are pedal length and pedal width.
#
# We will use batch training, but this can be easily
# adapted to stochastic training.

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 导入iris数据集
# 根据目标数据是否为山鸢尾将其转换成1或者0。
# 由于iris数据集将山鸢尾标记为0,我们将其从0置为1,同时把其他物种标记为0。
# 本次训练只使用两种特征:花瓣长度和花瓣宽度,这两个特征在x-value的第三列和第四列
# iris.target = {0, 1, 2}, where '0' is setosa
# iris.data ~ [sepal.width, sepal.length, pedal.width, pedal.length]
iris = datasets.load_iris()
binary_target = np.array([1. if x==0 else 0. for x in iris.target])
iris_2d = np.array([[x[2], x[3]] for x in iris.data])

# 声明批量训练大小
batch_size = 20

# 初始化计算图
sess = tf.Session()

# 声明数据占位符
x1_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
x2_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 声明模型变量
# Create variables A and b (0 = x1 - A*x2 + b)
A = tf.Variable(tf.random_normal(shape=[1, 1]))
b = tf.Variable(tf.random_normal(shape=[1, 1]))

# 定义线性模型:
# 如果找到的数据点在直线以上,则将数据点代入x2-x1*A-b计算出的结果大于0;
# 同理找到的数据点在直线以下,则将数据点代入x2-x1*A-b计算出的结果小于0。
# x1 - A*x2 + b
my_mult = tf.matmul(x2_data, A)
my_add = tf.add(my_mult, b)
my_output = tf.subtract(x1_data, my_add)

# 增加TensorFlow的sigmoid交叉熵损失函数(cross entropy)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output, labels=y_target)

# 声明优化器方法
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)

# 创建一个变量初始化操作
init = tf.global_variables_initializer()
sess.run(init)

# 运行迭代1000次
for i in range(1000):
  rand_index = np.random.choice(len(iris_2d), size=batch_size)
  # rand_x = np.transpose([iris_2d[rand_index]])
  # 传入三种数据:花瓣长度、花瓣宽度和目标变量
  rand_x = iris_2d[rand_index]
  rand_x1 = np.array([[x[0]] for x in rand_x])
  rand_x2 = np.array([[x[1]] for x in rand_x])
  #rand_y = np.transpose([binary_target[rand_index]])
  rand_y = np.array([[y] for y in binary_target[rand_index]])
  sess.run(train_step, feed_dict={x1_data: rand_x1, x2_data: rand_x2, y_target: rand_y})
  if (i+1)%200==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ', b = ' + str(sess.run(b)))


# 绘图
# 获取斜率/截距
# Pull out slope/intercept
[[slope]] = sess.run(A)
[[intercept]] = sess.run(b)

# 创建拟合线
x = np.linspace(0, 3, num=50)
ablineValues = []
for i in x:
 ablineValues.append(slope*i+intercept)

# 绘制拟合曲线
setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==1]
setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==1]
non_setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==0]
non_setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==0]
plt.plot(setosa_x, setosa_y, 'rx', ms=10, mew=2, label='setosa')
plt.plot(non_setosa_x, non_setosa_y, 'ro', label='Non-setosa')
plt.plot(x, ablineValues, 'b-')
plt.xlim([0.0, 2.7])
plt.ylim([0.0, 7.1])
plt.suptitle('Linear Separator For I.setosa', fontsize=20)
plt.xlabel('Petal Length')
plt.ylabel('Petal Width')
plt.legend(loc='lower right')
plt.show()

输出:

Step #200 A = [[ 8.70572948]], b = [[-3.46638322]]
Step #400 A = [[ 10.21302414]], b = [[-4.720438]]
Step #600 A = [[ 11.11844635]], b = [[-5.53361702]]
Step #800 A = [[ 11.86427212]], b = [[-6.0110755]]
Step #1000 A = [[ 12.49524498]], b = [[-6.29990339]]

TensorFlow实现创建分类器

 以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之玩转字符串(2)
Sep 14 Python
浅谈function(函数)中的动态参数
Apr 30 Python
轻量级的Web框架Flask 中模块化应用的实现
Sep 11 Python
Python数据结构与算法之链表定义与用法实例详解【单链表、循环链表】
Sep 28 Python
Python 将pdf转成图片的方法
Apr 23 Python
详解Python3除法之真除法、截断除法和下取整对比
May 23 Python
Pandas0.25来了千万别错过这10大好用的新功能
Aug 07 Python
python 数据生成excel导出(xlwt,wlsxwrite)代码实例
Aug 23 Python
PyQt5 如何让界面和逻辑分离的方法
Mar 24 Python
关于Python解包知识点总结
May 05 Python
python3.6使用SMTP协议发送邮件
May 20 Python
使用pth文件添加Python环境变量方式
May 26 Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
Python yield与实现方法代码分析
Feb 06 #Python
Django中间件工作流程及写法实例代码
Feb 06 #Python
You might like
将兴奋、喜悦和坎加斯带到戴安娜:亚马逊公主
2020/03/03 欧美动漫
香妃
2021/03/03 冲泡冲煮
如何取得中文字符串中出现次数最多的子串
2013/08/08 PHP
php中利用str_pad函数生成数字递增形式的产品编号
2013/09/30 PHP
php中explode的负数limit用法分析
2015/02/27 PHP
举例详解PHP脚本的测试方法
2015/08/05 PHP
laravel orm 关联条件查询代码
2019/10/21 PHP
Jsonp 跨域的原理以及Jquery的解决方案
2010/05/18 Javascript
简单实用的js调试logger组件实现代码
2010/11/20 Javascript
jQuery计算textarea中文字数(剩余个数)的小程序
2013/11/28 Javascript
JavaScript实现级联菜单的方法
2015/06/29 Javascript
jQuery实现自定义右键菜单的树状菜单效果
2015/09/02 Javascript
JavaScript位移运算符(无符号) >>> 三个大于号 的使用方法详解
2016/03/31 Javascript
图解prototype、proto和constructor的三角关系
2016/07/31 Javascript
自己封装的一个原生JS拖动方法(推荐)
2016/11/22 Javascript
bootstrap输入框组使用方法
2017/02/07 Javascript
jQuery插件FusionCharts绘制的2D条状图效果【附demo源码】
2017/05/13 jQuery
微信小程序 websocket 实现SpringMVC+Spring+Mybatis
2017/08/04 Javascript
npm全局模块卸载及默认安装目录修改方法
2018/05/15 Javascript
vue3.0 CLI - 2.6 - 组件的复用入门教程
2018/09/14 Javascript
详解从Django Rest Framework响应中删除空字段
2019/01/11 Python
判断python对象是否可调用的三种方式及其区别详解
2019/01/31 Python
解决Python3.8用pip安装turtle-0.0.2出现错误问题
2020/02/11 Python
浅谈Django QuerySet对象(模型.objects)的常用方法
2020/03/28 Python
关于PyCharm安装后修改路径名称使其可重新打开的问题
2020/10/20 Python
宝拉珍选美国官网:Paula’s Choice美国
2018/01/07 全球购物
车间副主任岗位职责
2013/12/24 职场文书
诉前财产保全担保书
2014/05/20 职场文书
师德模范事迹材料
2014/06/03 职场文书
骆驼祥子读书笔记
2015/06/26 职场文书
幼儿园科学课教学反思
2016/03/03 职场文书
浅析Django接口版本控制
2021/06/26 Python
Spring Boot 实现敏感词及特殊字符过滤处理
2021/06/29 Java/Android
一文带你探究MySQL中的NULL
2021/11/11 MySQL
MySQL慢查询优化解决问题
2022/03/17 MySQL
《吸血鬼:避世 血猎》官宣4.27发售 系列首款大逃杀
2022/04/03 其他游戏