利用TensorFlow训练简单的二分类神经网络模型的方法


Posted in Python onMarch 05, 2018

利用TensorFlow实现《神经网络与机器学习》一书中4.7模式分类练习

具体问题是将如下图所示双月牙数据集分类。

利用TensorFlow训练简单的二分类神经网络模型的方法

使用到的工具:

python3.5    tensorflow1.2.1   numpy   matplotlib

1.产生双月环数据集

def produceData(r,w,d,num): 
  r1 = r-w/2 
  r2 = r+w/2 
  #上半圆 
  theta1 = np.random.uniform(0, np.pi ,num) 
  X_Col1 = np.random.uniform( r1*np.cos(theta1),r2*np.cos(theta1),num)[:, np.newaxis] 
  X_Row1 = np.random.uniform(r1*np.sin(theta1),r2*np.sin(theta1),num)[:, np.newaxis] 
  Y_label1 = np.ones(num) #类别标签为1 
  #下半圆 
  theta2 = np.random.uniform(-np.pi, 0 ,num) 
  X_Col2 = (np.random.uniform( r1*np.cos(theta2),r2*np.cos(theta2),num) + r)[:, np.newaxis] 
  X_Row2 = (np.random.uniform(r1 * np.sin(theta2), r2 * np.sin(theta2), num) -d)[:,np.newaxis] 
  Y_label2 = -np.ones(num) #类别标签为-1,注意:由于采取双曲正切函数作为激活函数,类别标签不能为0 
  #合并 
  X_Col = np.vstack((X_Col1, X_Col2)) 
  X_Row = np.vstack((X_Row1, X_Row2)) 
  X = np.hstack((X_Col, X_Row)) 
  Y_label = np.hstack((Y_label1,Y_label2)) 
  Y_label.shape = (num*2 , 1) 
  return X,Y_label

其中r为月环半径,w为月环宽度,d为上下月环距离(与书中一致)

2.利用TensorFlow搭建神经网络模型

2.1 神经网络层添加

def add_layer(layername,inputs, in_size, out_size, activation_function=None): 
  # add one more layer and return the output of this layer 
  with tf.variable_scope(layername,reuse=None): 
    Weights = tf.get_variable("weights",shape=[in_size, out_size], 
                 initializer=tf.truncated_normal_initializer(stddev=0.1)) 
    biases = tf.get_variable("biases", shape=[1, out_size], 
                 initializer=tf.truncated_normal_initializer(stddev=0.1)) 
   
  Wx_plus_b = tf.matmul(inputs, Weights) + biases 
  if activation_function is None: 
    outputs = Wx_plus_b 
  else: 
    outputs = activation_function(Wx_plus_b) 
  return outputs

2.2 利用tensorflow建立神经网络模型

输入层大小:2

隐藏层大小:20

输出层大小:1

激活函数:双曲正切函数

学习率:0.1(与书中略有不同)

(具体的搭建过程可参考莫烦的视频,链接就不附上了自行搜索吧......)

###define placeholder for inputs to network 
xs = tf.placeholder(tf.float32, [None, 2]) 
ys = tf.placeholder(tf.float32, [None, 1]) 
###添加隐藏层 
l1 = add_layer("layer1",xs, 2, 20, activation_function=tf.tanh) 
###添加输出层 
prediction = add_layer("layer2",l1, 20, 1, activation_function=tf.tanh) 
###MSE 均方误差 
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction), reduction_indices=[1])) 
###优化器选取 学习率设置 此处学习率置为0.1 
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) 
###tensorflow变量初始化,打开会话 
init = tf.global_variables_initializer()#tensorflow更新后初始化所有变量不再用tf.initialize_all_variables() 
sess = tf.Session() 
sess.run(init)

2.3 训练模型

###训练2000次 
for i in range(2000): 
  sess.run(train_step, feed_dict={xs: x_data, ys: y_label})

3.利用训练好的网络模型寻找分类决策边界

3.1 产生二维空间随机点

def produce_random_data(r,w,d,num): 
  X1 = np.random.uniform(-r-w/2,2*r+w/2, num) 
  X2 = np.random.uniform(-r - w / 2-d, r+w/2, num) 
  X = np.vstack((X1, X2)) 
  return X.transpose()

3.2 用训练好的模型采集决策边界附近的点

向网络输入一个二维空间随机点,计算输出值大于-0.5小于0.5即认为该点落在决策边界附近(双曲正切函数)

def collect_boundary_data(v_xs): 
  global prediction 
  X = np.empty([1,2]) 
  X = list() 
  for i in range(len(v_xs)): 
    x_input = v_xs[i] 
    x_input.shape = [1,2] 
    y_pre = sess.run(prediction, feed_dict={xs: x_input}) 
    if abs(y_pre - 0) < 0.5: 
      X.append(v_xs[i]) 
  return np.array(X)

3.3 用numpy工具将采集到的边界附近点拟合成决策边界曲线,用matplotlib.pyplot画图

###产生空间随机数据 
  X_NUM = produce_random_data(10, 6, -4, 5000) 
  ###边界数据采样 
  X_b = collect_boundary_data(X_NUM) 
  ###画出数据 
  fig = plt.figure() 
  ax = fig.add_subplot(1, 1, 1) 
  ###设置坐标轴名称 
  plt.xlabel('x1') 
  plt.ylabel('x2') 
  ax.scatter(x_data[:, 0], x_data[:, 1], marker='x') 
  ###用采样的边界数据拟合边界曲线 7次曲线最佳 
  z1 = np.polyfit(X_b[:, 0], X_b[:, 1], 7) 
  p1 = np.poly1d(z1) 
  x = X_b[:, 0] 
  x.sort() 
  yvals = p1(x) 
  plt.plot(x, yvals, 'r', label='boundray line') 
  plt.legend(loc=4) 
  #plt.ion() 
  plt.show()

4.效果

利用TensorFlow训练简单的二分类神经网络模型的方法

5.附上源码Github链接

https://github.com/Peakulorain/Practices.git 里的PatternClassification.py文件

另注:分类问题还是用softmax去做吧.....我只是用这做书上的练习而已。

(初学者水平有限,有问题请指出,各位大佬轻喷)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python自定义主从分布式架构实例分析
Sep 19 Python
Python 实现链表实例代码
Apr 07 Python
Mac 上切换Python多版本
Jun 17 Python
Python网络爬虫中的同步与异步示例详解
Feb 03 Python
Python3中关于cookie的创建与保存
Oct 21 Python
python3 自动识别usb连接状态,即对usb重连的判断方法
Jul 03 Python
PyQt 图解Qt Designer工具的使用方法
Aug 06 Python
Python (Win)readline和tab补全的安装方法
Aug 27 Python
python打印直角三角形与等腰三角形实例代码
Oct 20 Python
python 给图像添加透明度(alpha通道)
Apr 09 Python
OpenCV灰度化之后图片为绿色的解决
Dec 01 Python
Python torch.flatten()函数案例详解
Aug 30 Python
python使用Pycharm创建一个Django项目
Mar 05 #Python
python爬虫基本知识
Mar 05 #Python
用tensorflow构建线性回归模型的示例代码
Mar 05 #Python
详解python实现线程安全的单例模式
Mar 05 #Python
分析python动态规划的递归、非递归实现
Mar 04 #Python
python3.x上post发送json数据
Mar 04 #Python
python数据封装json格式数据
Mar 04 #Python
You might like
在WordPress中实现发送http请求的相关函数解析
2015/12/29 PHP
PHP中函数gzuncompress无法使用的解决方法
2017/03/02 PHP
php 使用expat方式解析xml文件操作示例
2019/11/26 PHP
jQuery 1.0.4 - New Wave Javascript(js源文件)
2007/01/15 Javascript
JavaScript CSS菜单功能 改进版
2008/12/20 Javascript
javascript+xml实现简单图片轮换(只支持IE)
2012/12/23 Javascript
JS禁用浏览器退格键实现思路及代码
2013/10/29 Javascript
现如今最流行的JavaScript代码规范
2014/03/08 Javascript
JavaScript实现的图像模糊算法代码分享
2014/04/22 Javascript
node.js中的fs.fchownSync方法使用说明
2014/12/16 Javascript
DOM操作一些常用的属性汇总
2015/03/13 Javascript
推荐10 款 SVG 动画的 JavaScript 库
2015/03/24 Javascript
JavaScript构建自己的对象示例
2016/11/29 Javascript
解析Vue2.0双向绑定实现原理
2017/02/23 Javascript
jstree单选功能的实现方法
2017/06/07 Javascript
Vue学习笔记进阶篇之函数化组件解析
2017/07/21 Javascript
通过button将form表单的数据提交到action层的实例
2017/09/08 Javascript
JS库之Highlight.js的用法详解
2017/09/13 Javascript
vue中使用sessionStorage记住密码功能
2018/07/24 Javascript
react基本安装与测试示例
2020/04/27 Javascript
[01:00:52]2018DOTA2亚洲邀请赛 4.4 淘汰赛 EG vs LGD 第一场
2018/04/05 DOTA
Python的Django中将文件上传至七牛云存储的代码分享
2016/06/03 Python
Python字典及字典基本操作方法详解
2018/01/30 Python
selenium+python 去除启动的黑色cmd窗口方法
2018/05/22 Python
记录Python脚本的运行日志的方法
2019/06/05 Python
python中时间、日期、时间戳的转换的实现方法
2019/07/06 Python
Python +Selenium解决图片验证码登录或注册问题(推荐)
2020/02/09 Python
详解Python的三种拷贝方式
2020/02/11 Python
Python3标准库之threading进程中管理并发操作方法
2020/03/30 Python
针对HTML5的Web Worker使用攻略
2015/07/12 HTML / CSS
英国蛋糕装饰用品一站式商店:Craft Company
2019/03/18 全球购物
简述Linux文件系统通过i节点把文件的逻辑结构和物理结构转换的工作过程
2012/04/17 面试题
中级会计职业生涯规划书
2014/03/01 职场文书
奠基仪式策划方案
2014/05/15 职场文书
酒店辞职书怎么写
2015/02/26 职场文书
2019单位介绍信怎么写
2019/06/24 职场文书