利用TensorFlow训练简单的二分类神经网络模型的方法


Posted in Python onMarch 05, 2018

利用TensorFlow实现《神经网络与机器学习》一书中4.7模式分类练习

具体问题是将如下图所示双月牙数据集分类。

利用TensorFlow训练简单的二分类神经网络模型的方法

使用到的工具:

python3.5    tensorflow1.2.1   numpy   matplotlib

1.产生双月环数据集

def produceData(r,w,d,num): 
  r1 = r-w/2 
  r2 = r+w/2 
  #上半圆 
  theta1 = np.random.uniform(0, np.pi ,num) 
  X_Col1 = np.random.uniform( r1*np.cos(theta1),r2*np.cos(theta1),num)[:, np.newaxis] 
  X_Row1 = np.random.uniform(r1*np.sin(theta1),r2*np.sin(theta1),num)[:, np.newaxis] 
  Y_label1 = np.ones(num) #类别标签为1 
  #下半圆 
  theta2 = np.random.uniform(-np.pi, 0 ,num) 
  X_Col2 = (np.random.uniform( r1*np.cos(theta2),r2*np.cos(theta2),num) + r)[:, np.newaxis] 
  X_Row2 = (np.random.uniform(r1 * np.sin(theta2), r2 * np.sin(theta2), num) -d)[:,np.newaxis] 
  Y_label2 = -np.ones(num) #类别标签为-1,注意:由于采取双曲正切函数作为激活函数,类别标签不能为0 
  #合并 
  X_Col = np.vstack((X_Col1, X_Col2)) 
  X_Row = np.vstack((X_Row1, X_Row2)) 
  X = np.hstack((X_Col, X_Row)) 
  Y_label = np.hstack((Y_label1,Y_label2)) 
  Y_label.shape = (num*2 , 1) 
  return X,Y_label

其中r为月环半径,w为月环宽度,d为上下月环距离(与书中一致)

2.利用TensorFlow搭建神经网络模型

2.1 神经网络层添加

def add_layer(layername,inputs, in_size, out_size, activation_function=None): 
  # add one more layer and return the output of this layer 
  with tf.variable_scope(layername,reuse=None): 
    Weights = tf.get_variable("weights",shape=[in_size, out_size], 
                 initializer=tf.truncated_normal_initializer(stddev=0.1)) 
    biases = tf.get_variable("biases", shape=[1, out_size], 
                 initializer=tf.truncated_normal_initializer(stddev=0.1)) 
   
  Wx_plus_b = tf.matmul(inputs, Weights) + biases 
  if activation_function is None: 
    outputs = Wx_plus_b 
  else: 
    outputs = activation_function(Wx_plus_b) 
  return outputs

2.2 利用tensorflow建立神经网络模型

输入层大小:2

隐藏层大小:20

输出层大小:1

激活函数:双曲正切函数

学习率:0.1(与书中略有不同)

(具体的搭建过程可参考莫烦的视频,链接就不附上了自行搜索吧......)

###define placeholder for inputs to network 
xs = tf.placeholder(tf.float32, [None, 2]) 
ys = tf.placeholder(tf.float32, [None, 1]) 
###添加隐藏层 
l1 = add_layer("layer1",xs, 2, 20, activation_function=tf.tanh) 
###添加输出层 
prediction = add_layer("layer2",l1, 20, 1, activation_function=tf.tanh) 
###MSE 均方误差 
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction), reduction_indices=[1])) 
###优化器选取 学习率设置 此处学习率置为0.1 
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) 
###tensorflow变量初始化,打开会话 
init = tf.global_variables_initializer()#tensorflow更新后初始化所有变量不再用tf.initialize_all_variables() 
sess = tf.Session() 
sess.run(init)

2.3 训练模型

###训练2000次 
for i in range(2000): 
  sess.run(train_step, feed_dict={xs: x_data, ys: y_label})

3.利用训练好的网络模型寻找分类决策边界

3.1 产生二维空间随机点

def produce_random_data(r,w,d,num): 
  X1 = np.random.uniform(-r-w/2,2*r+w/2, num) 
  X2 = np.random.uniform(-r - w / 2-d, r+w/2, num) 
  X = np.vstack((X1, X2)) 
  return X.transpose()

3.2 用训练好的模型采集决策边界附近的点

向网络输入一个二维空间随机点,计算输出值大于-0.5小于0.5即认为该点落在决策边界附近(双曲正切函数)

def collect_boundary_data(v_xs): 
  global prediction 
  X = np.empty([1,2]) 
  X = list() 
  for i in range(len(v_xs)): 
    x_input = v_xs[i] 
    x_input.shape = [1,2] 
    y_pre = sess.run(prediction, feed_dict={xs: x_input}) 
    if abs(y_pre - 0) < 0.5: 
      X.append(v_xs[i]) 
  return np.array(X)

3.3 用numpy工具将采集到的边界附近点拟合成决策边界曲线,用matplotlib.pyplot画图

###产生空间随机数据 
  X_NUM = produce_random_data(10, 6, -4, 5000) 
  ###边界数据采样 
  X_b = collect_boundary_data(X_NUM) 
  ###画出数据 
  fig = plt.figure() 
  ax = fig.add_subplot(1, 1, 1) 
  ###设置坐标轴名称 
  plt.xlabel('x1') 
  plt.ylabel('x2') 
  ax.scatter(x_data[:, 0], x_data[:, 1], marker='x') 
  ###用采样的边界数据拟合边界曲线 7次曲线最佳 
  z1 = np.polyfit(X_b[:, 0], X_b[:, 1], 7) 
  p1 = np.poly1d(z1) 
  x = X_b[:, 0] 
  x.sort() 
  yvals = p1(x) 
  plt.plot(x, yvals, 'r', label='boundray line') 
  plt.legend(loc=4) 
  #plt.ion() 
  plt.show()

4.效果

利用TensorFlow训练简单的二分类神经网络模型的方法

5.附上源码Github链接

https://github.com/Peakulorain/Practices.git 里的PatternClassification.py文件

另注:分类问题还是用softmax去做吧.....我只是用这做书上的练习而已。

(初学者水平有限,有问题请指出,各位大佬轻喷)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
对python requests发送json格式数据的实例详解
Dec 19 Python
在python中将字符串转为json对象并取值的方法
Dec 31 Python
python多任务及返回值的处理方法
Jan 22 Python
Django基础知识 URL路由系统详解
Jul 18 Python
python通过opencv实现图片裁剪原理解析
Jan 19 Python
基于Python获取照片的GPS位置信息
Jan 20 Python
Anaconda3+tensorflow2.0.0+PyCharm安装与环境搭建(图文)
Feb 18 Python
DataFrame 数据合并实现(merge,join,concat)
Jun 14 Python
python opencv pytesseract 验证码识别的实现
Aug 28 Python
Python3使用 GitLab API 进行批量合并分支
Oct 15 Python
快速创建python 虚拟环境
Nov 28 Python
python 如何用urllib与服务端交互(发送和接收数据)
Mar 04 Python
python使用Pycharm创建一个Django项目
Mar 05 #Python
python爬虫基本知识
Mar 05 #Python
用tensorflow构建线性回归模型的示例代码
Mar 05 #Python
详解python实现线程安全的单例模式
Mar 05 #Python
分析python动态规划的递归、非递归实现
Mar 04 #Python
python3.x上post发送json数据
Mar 04 #Python
python数据封装json格式数据
Mar 04 #Python
You might like
PHP 将图片按创建时间进行分类存储的实现代码
2010/01/05 PHP
PHP判断图片格式的七种方法小结
2013/06/03 PHP
php判断类是否存在函数class_exists用法分析
2014/11/14 PHP
PHP使用PDO连接ACCESS数据库
2015/03/05 PHP
php安装swoole扩展的方法
2015/03/19 PHP
PHP+Mysql+jQuery实现发布微博程序 php篇
2015/10/15 PHP
利用Homestead快速运行一个Laravel项目的方法详解
2017/11/14 PHP
Laravel中validation验证 返回中文提示 全局设置的方法
2019/09/29 PHP
javascript字典探测用户名工具
2006/10/05 Javascript
EasyUI的treegrid组件动态加载数据问题的解决办法
2011/12/11 Javascript
js中单引号与双引号冲突问题解决方法
2013/10/04 Javascript
jqgrid 编辑添加功能详细解析
2013/11/08 Javascript
jquery+css实现的红色线条横向二级菜单效果
2015/08/22 Javascript
jquery图片轮播特效代码分享
2020/04/20 Javascript
AngularJS入门教程之与服务器(Ajax)交互操作示例【附完整demo源码下载】
2016/11/02 Javascript
原生js编写2048小游戏
2017/03/17 Javascript
Angular4绑定html内容出现警告的处理方法
2017/11/03 Javascript
JavaScript体验异步更好的解决办法
2018/01/08 Javascript
vue todo-list组件发布到npm上的方法
2018/04/04 Javascript
vue插件draggable实现拖拽移动图片顺序
2018/12/01 Javascript
Node.js 路由的实现方法
2019/06/05 Javascript
在vue中使用axios实现post方式获取二进制流下载文件(实例代码)
2019/12/16 Javascript
python访问纯真IP数据库的代码
2011/05/19 Python
python的常见命令注入威胁
2013/02/18 Python
Python使用os模块和fileinput模块来操作文件目录
2016/01/19 Python
Linux下python3.6.1环境配置教程
2018/09/26 Python
python 重命名轴索引的方法
2018/11/10 Python
Python寻找两个有序数组的中位数实例详解
2018/12/05 Python
python web框架Flask实现图形验证码及验证码的动态刷新实例
2019/10/14 Python
wxpython布局的实现方法
2019/11/01 Python
带你彻底搞懂python操作mysql数据库(cursor游标讲解)
2020/01/06 Python
Python多进程编程常用方法解析
2020/03/26 Python
加拿大女鞋品牌:ALDO
2016/11/13 全球购物
英国卫浴商店:Ergonomic Design
2019/09/22 全球购物
批评与自我批评范文
2014/10/15 职场文书
解决jupyter notebook启动后没有token的坑
2021/04/24 Python