TensorFlow实现创建分类器


Posted in Python onFebruary 06, 2018

本文实例为大家分享了TensorFlow实现创建分类器的具体代码,供大家参考,具体内容如下

创建一个iris数据集的分类器。

加载样本数据集,实现一个简单的二值分类器来预测一朵花是否为山鸢尾。iris数据集有三类花,但这里仅预测是否是山鸢尾。导入iris数据集和工具库,相应地对原数据集进行转换。

# Combining Everything Together
#----------------------------------
# This file will perform binary classification on the
# iris dataset. We will only predict if a flower is
# I.setosa or not.
#
# We will create a simple binary classifier by creating a line
# and running everything through a sigmoid to get a binary predictor.
# The two features we will use are pedal length and pedal width.
#
# We will use batch training, but this can be easily
# adapted to stochastic training.

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 导入iris数据集
# 根据目标数据是否为山鸢尾将其转换成1或者0。
# 由于iris数据集将山鸢尾标记为0,我们将其从0置为1,同时把其他物种标记为0。
# 本次训练只使用两种特征:花瓣长度和花瓣宽度,这两个特征在x-value的第三列和第四列
# iris.target = {0, 1, 2}, where '0' is setosa
# iris.data ~ [sepal.width, sepal.length, pedal.width, pedal.length]
iris = datasets.load_iris()
binary_target = np.array([1. if x==0 else 0. for x in iris.target])
iris_2d = np.array([[x[2], x[3]] for x in iris.data])

# 声明批量训练大小
batch_size = 20

# 初始化计算图
sess = tf.Session()

# 声明数据占位符
x1_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
x2_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 声明模型变量
# Create variables A and b (0 = x1 - A*x2 + b)
A = tf.Variable(tf.random_normal(shape=[1, 1]))
b = tf.Variable(tf.random_normal(shape=[1, 1]))

# 定义线性模型:
# 如果找到的数据点在直线以上,则将数据点代入x2-x1*A-b计算出的结果大于0;
# 同理找到的数据点在直线以下,则将数据点代入x2-x1*A-b计算出的结果小于0。
# x1 - A*x2 + b
my_mult = tf.matmul(x2_data, A)
my_add = tf.add(my_mult, b)
my_output = tf.subtract(x1_data, my_add)

# 增加TensorFlow的sigmoid交叉熵损失函数(cross entropy)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output, labels=y_target)

# 声明优化器方法
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)

# 创建一个变量初始化操作
init = tf.global_variables_initializer()
sess.run(init)

# 运行迭代1000次
for i in range(1000):
  rand_index = np.random.choice(len(iris_2d), size=batch_size)
  # rand_x = np.transpose([iris_2d[rand_index]])
  # 传入三种数据:花瓣长度、花瓣宽度和目标变量
  rand_x = iris_2d[rand_index]
  rand_x1 = np.array([[x[0]] for x in rand_x])
  rand_x2 = np.array([[x[1]] for x in rand_x])
  #rand_y = np.transpose([binary_target[rand_index]])
  rand_y = np.array([[y] for y in binary_target[rand_index]])
  sess.run(train_step, feed_dict={x1_data: rand_x1, x2_data: rand_x2, y_target: rand_y})
  if (i+1)%200==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ', b = ' + str(sess.run(b)))


# 绘图
# 获取斜率/截距
# Pull out slope/intercept
[[slope]] = sess.run(A)
[[intercept]] = sess.run(b)

# 创建拟合线
x = np.linspace(0, 3, num=50)
ablineValues = []
for i in x:
 ablineValues.append(slope*i+intercept)

# 绘制拟合曲线
setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==1]
setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==1]
non_setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==0]
non_setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==0]
plt.plot(setosa_x, setosa_y, 'rx', ms=10, mew=2, label='setosa')
plt.plot(non_setosa_x, non_setosa_y, 'ro', label='Non-setosa')
plt.plot(x, ablineValues, 'b-')
plt.xlim([0.0, 2.7])
plt.ylim([0.0, 7.1])
plt.suptitle('Linear Separator For I.setosa', fontsize=20)
plt.xlabel('Petal Length')
plt.ylabel('Petal Width')
plt.legend(loc='lower right')
plt.show()

输出:

Step #200 A = [[ 8.70572948]], b = [[-3.46638322]]
Step #400 A = [[ 10.21302414]], b = [[-4.720438]]
Step #600 A = [[ 11.11844635]], b = [[-5.53361702]]
Step #800 A = [[ 11.86427212]], b = [[-6.0110755]]
Step #1000 A = [[ 12.49524498]], b = [[-6.29990339]]

TensorFlow实现创建分类器

 以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的函数的一些高阶特性
Apr 27 Python
python使用MySQLdb访问mysql数据库的方法
Aug 03 Python
使用Python的Flask框架表单插件Flask-WTF实现Web登录验证
Jul 12 Python
特征脸(Eigenface)理论基础之PCA主成分分析法
Mar 13 Python
对pytorch网络层结构的数组化详解
Dec 08 Python
运用PyTorch动手搭建一个共享单车预测器
Aug 06 Python
Python爬虫库BeautifulSoup获取对象(标签)名,属性,内容,注释
Jan 25 Python
Python文字截图识别OCR工具实例解析
Mar 05 Python
Python pip install如何修改默认下载路径
Apr 29 Python
windows支持哪个版本的python
Jul 03 Python
Django-imagekit的使用详解
Jul 06 Python
详解python爬取弹幕与数据分析
Nov 14 Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
Python yield与实现方法代码分析
Feb 06 #Python
Django中间件工作流程及写法实例代码
Feb 06 #Python
You might like
ThinkPHP使用心得分享-ThinkPHP + Ajax 实现2级联动下拉菜单
2014/05/15 PHP
PHP使用glob函数遍历目录或文件夹的方法
2014/12/16 PHP
YII Framework框架教程之安全方案详解
2016/03/14 PHP
PHP实现bitmap位图排序与求交集的方法
2016/07/28 PHP
Zend Framework上传文件重命名的实现方法
2016/11/25 PHP
php表单习惯用的正则表达式
2017/10/11 PHP
php连接sftp的作用以及实例代码
2019/09/23 PHP
Javascript 页面模板化很多人没有使用过的方法
2012/06/05 Javascript
jQuery 仿百度输入标签插件附效果图
2014/07/04 Javascript
JavaScript中的公有、私有、特权和静态成员用法分析
2014/11/20 Javascript
js实现具有高亮显示效果的多级菜单代码
2015/09/01 Javascript
微信小程序中setInterval的使用方法
2017/09/29 Javascript
JS实现元素上下左右移动效果
2017/10/18 Javascript
在Vue组件中获取全局的点击事件方法
2018/09/06 Javascript
vue实现PC端分辨率适配操作
2020/08/03 Javascript
详解JS深拷贝与浅拷贝
2020/08/04 Javascript
Python程序员面试题 你必须提前准备!
2018/01/16 Python
Python实现的读取/更改/写入xml文件操作示例
2018/08/30 Python
numpy中的ndarray方法和属性详解
2019/05/27 Python
python tkinter实现屏保程序
2019/07/30 Python
pytorch 归一化与反归一化实例
2019/12/31 Python
python中对二维列表中一维列表的调用方法
2020/06/07 Python
Python collections.deque双边队列原理详解
2020/10/05 Python
美国时尚女装在线:Missguided
2016/12/03 全球购物
EVE LOM英国官网:全世界最好的洁面膏
2017/10/30 全球购物
泰国第一的化妆品网站:Konvy
2018/02/25 全球购物
Beach Bunny Swimwear官网:设计师泳装和性感比基尼
2019/03/13 全球购物
造价工程师个人求职信
2013/09/21 职场文书
门诊手术室工作制度
2014/01/30 职场文书
激情洋溢的毕业生就业求职信
2014/03/15 职场文书
经济担保书范文
2014/04/02 职场文书
实习生岗位职责
2014/04/12 职场文书
责任书范本
2014/08/25 职场文书
客户经理岗位职责大全
2015/04/09 职场文书
Python爬取某拍短视频
2021/06/11 Python
Java实现学生管理系统(IO版)
2022/02/24 Java/Android