TensorFlow实现创建分类器


Posted in Python onFebruary 06, 2018

本文实例为大家分享了TensorFlow实现创建分类器的具体代码,供大家参考,具体内容如下

创建一个iris数据集的分类器。

加载样本数据集,实现一个简单的二值分类器来预测一朵花是否为山鸢尾。iris数据集有三类花,但这里仅预测是否是山鸢尾。导入iris数据集和工具库,相应地对原数据集进行转换。

# Combining Everything Together
#----------------------------------
# This file will perform binary classification on the
# iris dataset. We will only predict if a flower is
# I.setosa or not.
#
# We will create a simple binary classifier by creating a line
# and running everything through a sigmoid to get a binary predictor.
# The two features we will use are pedal length and pedal width.
#
# We will use batch training, but this can be easily
# adapted to stochastic training.

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 导入iris数据集
# 根据目标数据是否为山鸢尾将其转换成1或者0。
# 由于iris数据集将山鸢尾标记为0,我们将其从0置为1,同时把其他物种标记为0。
# 本次训练只使用两种特征:花瓣长度和花瓣宽度,这两个特征在x-value的第三列和第四列
# iris.target = {0, 1, 2}, where '0' is setosa
# iris.data ~ [sepal.width, sepal.length, pedal.width, pedal.length]
iris = datasets.load_iris()
binary_target = np.array([1. if x==0 else 0. for x in iris.target])
iris_2d = np.array([[x[2], x[3]] for x in iris.data])

# 声明批量训练大小
batch_size = 20

# 初始化计算图
sess = tf.Session()

# 声明数据占位符
x1_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
x2_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 声明模型变量
# Create variables A and b (0 = x1 - A*x2 + b)
A = tf.Variable(tf.random_normal(shape=[1, 1]))
b = tf.Variable(tf.random_normal(shape=[1, 1]))

# 定义线性模型:
# 如果找到的数据点在直线以上,则将数据点代入x2-x1*A-b计算出的结果大于0;
# 同理找到的数据点在直线以下,则将数据点代入x2-x1*A-b计算出的结果小于0。
# x1 - A*x2 + b
my_mult = tf.matmul(x2_data, A)
my_add = tf.add(my_mult, b)
my_output = tf.subtract(x1_data, my_add)

# 增加TensorFlow的sigmoid交叉熵损失函数(cross entropy)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output, labels=y_target)

# 声明优化器方法
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)

# 创建一个变量初始化操作
init = tf.global_variables_initializer()
sess.run(init)

# 运行迭代1000次
for i in range(1000):
  rand_index = np.random.choice(len(iris_2d), size=batch_size)
  # rand_x = np.transpose([iris_2d[rand_index]])
  # 传入三种数据:花瓣长度、花瓣宽度和目标变量
  rand_x = iris_2d[rand_index]
  rand_x1 = np.array([[x[0]] for x in rand_x])
  rand_x2 = np.array([[x[1]] for x in rand_x])
  #rand_y = np.transpose([binary_target[rand_index]])
  rand_y = np.array([[y] for y in binary_target[rand_index]])
  sess.run(train_step, feed_dict={x1_data: rand_x1, x2_data: rand_x2, y_target: rand_y})
  if (i+1)%200==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ', b = ' + str(sess.run(b)))


# 绘图
# 获取斜率/截距
# Pull out slope/intercept
[[slope]] = sess.run(A)
[[intercept]] = sess.run(b)

# 创建拟合线
x = np.linspace(0, 3, num=50)
ablineValues = []
for i in x:
 ablineValues.append(slope*i+intercept)

# 绘制拟合曲线
setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==1]
setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==1]
non_setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==0]
non_setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==0]
plt.plot(setosa_x, setosa_y, 'rx', ms=10, mew=2, label='setosa')
plt.plot(non_setosa_x, non_setosa_y, 'ro', label='Non-setosa')
plt.plot(x, ablineValues, 'b-')
plt.xlim([0.0, 2.7])
plt.ylim([0.0, 7.1])
plt.suptitle('Linear Separator For I.setosa', fontsize=20)
plt.xlabel('Petal Length')
plt.ylabel('Petal Width')
plt.legend(loc='lower right')
plt.show()

输出:

Step #200 A = [[ 8.70572948]], b = [[-3.46638322]]
Step #400 A = [[ 10.21302414]], b = [[-4.720438]]
Step #600 A = [[ 11.11844635]], b = [[-5.53361702]]
Step #800 A = [[ 11.86427212]], b = [[-6.0110755]]
Step #1000 A = [[ 12.49524498]], b = [[-6.29990339]]

TensorFlow实现创建分类器

 以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用cx_freeze把python打包exe示例
Jan 24 Python
pyqt4教程之实现windows窗口小示例分享
Mar 07 Python
Python使用MD5加密字符串示例
Aug 22 Python
Python中的两个内置模块介绍
Apr 05 Python
python统计文本文件内单词数量的方法
May 30 Python
python实现文本去重且不打乱原本顺序
Jan 26 Python
Python绘制股票移动均线的实例
Aug 24 Python
Ranorex通过Python将报告发送到邮箱的方法
Jan 12 Python
python GUI库图形界面开发之PyQt5美化窗体与控件(异形窗体)实例
Feb 25 Python
python工具快速为音视频自动生成字幕(使用说明)
Jan 27 Python
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 Python
Python Pandas 删除列操作
Mar 16 Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
Python yield与实现方法代码分析
Feb 06 #Python
Django中间件工作流程及写法实例代码
Feb 06 #Python
You might like
php中static静态变量的使用方法详解
2010/06/04 PHP
php代码运行时间查看类代码分享
2011/08/06 PHP
在Windows系统下使用PHP生成Word文档的教程
2015/07/03 PHP
让你的博文自动带上缩址的实现代码,方便发到微博客上
2010/12/28 Javascript
js 去掉空格实例 Trim() LTrim() RTrim()
2014/01/07 Javascript
jQuery .tmpl() 用法示例介绍
2014/08/21 Javascript
jQuery创建DOM元素实例解析
2015/01/19 Javascript
深入理解JavaScript系列(39):设计模式之适配器模式详解
2015/03/04 Javascript
javascript的理解及经典案例分析
2016/05/20 Javascript
Javascript之面向对象--方法
2016/12/02 Javascript
canvas 画布在主流浏览器中的尺寸限制详细介绍
2016/12/15 Javascript
js数组方法reduce经典用法代码分享
2018/01/07 Javascript
vue+axios+promise实际开发用法详解
2018/10/15 Javascript
Layui实现数据表格默认全部显示(不要分页)
2019/10/26 Javascript
autojs 蚂蚁森林能量自动拾取即给指定好友浇水的实现方法
2020/05/03 Javascript
[02:27]DOTA2英雄基础教程 莱恩
2014/01/17 DOTA
python实现每次处理一个字符的三种方法
2014/10/09 Python
Python实现的文本编辑器功能示例
2017/06/30 Python
matplotlib 输出保存指定尺寸的图片方法
2018/05/24 Python
python得到一个excel的全部sheet标签值方法
2018/12/10 Python
python提取具有某种特定字符串的行数据方法
2018/12/11 Python
用Python解决x的n次方问题
2019/02/08 Python
Python 函数返回值的示例代码
2019/03/11 Python
Python爬虫 scrapy框架爬取某招聘网存入mongodb解析
2019/07/31 Python
全面总结使用CSS实现水平垂直居中效果的方法
2016/03/10 HTML / CSS
基于 HTML5 WebGL 实现的医疗物流系统
2019/10/08 HTML / CSS
Lancer Skincare官方网站:抗衰老皮肤护理
2020/11/20 全球购物
介绍一下EJB的体系结构
2012/08/01 面试题
毕业生就业自荐信
2013/12/04 职场文书
贷款担保书范文
2014/05/13 职场文书
家庭财产分割协议范文
2014/11/24 职场文书
2015年评职称工作总结范文
2015/04/20 职场文书
史上最牛辞职信
2015/05/13 职场文书
同步小康驻村工作简报
2015/07/20 职场文书
高中语文教材(文学文化常识大全一)
2019/08/13 职场文书
MySQL日期时间函数知识汇总
2022/03/17 MySQL