利用TensorFlow训练简单的二分类神经网络模型的方法


Posted in Python onMarch 05, 2018

利用TensorFlow实现《神经网络与机器学习》一书中4.7模式分类练习

具体问题是将如下图所示双月牙数据集分类。

利用TensorFlow训练简单的二分类神经网络模型的方法

使用到的工具:

python3.5    tensorflow1.2.1   numpy   matplotlib

1.产生双月环数据集

def produceData(r,w,d,num): 
  r1 = r-w/2 
  r2 = r+w/2 
  #上半圆 
  theta1 = np.random.uniform(0, np.pi ,num) 
  X_Col1 = np.random.uniform( r1*np.cos(theta1),r2*np.cos(theta1),num)[:, np.newaxis] 
  X_Row1 = np.random.uniform(r1*np.sin(theta1),r2*np.sin(theta1),num)[:, np.newaxis] 
  Y_label1 = np.ones(num) #类别标签为1 
  #下半圆 
  theta2 = np.random.uniform(-np.pi, 0 ,num) 
  X_Col2 = (np.random.uniform( r1*np.cos(theta2),r2*np.cos(theta2),num) + r)[:, np.newaxis] 
  X_Row2 = (np.random.uniform(r1 * np.sin(theta2), r2 * np.sin(theta2), num) -d)[:,np.newaxis] 
  Y_label2 = -np.ones(num) #类别标签为-1,注意:由于采取双曲正切函数作为激活函数,类别标签不能为0 
  #合并 
  X_Col = np.vstack((X_Col1, X_Col2)) 
  X_Row = np.vstack((X_Row1, X_Row2)) 
  X = np.hstack((X_Col, X_Row)) 
  Y_label = np.hstack((Y_label1,Y_label2)) 
  Y_label.shape = (num*2 , 1) 
  return X,Y_label

其中r为月环半径,w为月环宽度,d为上下月环距离(与书中一致)

2.利用TensorFlow搭建神经网络模型

2.1 神经网络层添加

def add_layer(layername,inputs, in_size, out_size, activation_function=None): 
  # add one more layer and return the output of this layer 
  with tf.variable_scope(layername,reuse=None): 
    Weights = tf.get_variable("weights",shape=[in_size, out_size], 
                 initializer=tf.truncated_normal_initializer(stddev=0.1)) 
    biases = tf.get_variable("biases", shape=[1, out_size], 
                 initializer=tf.truncated_normal_initializer(stddev=0.1)) 
   
  Wx_plus_b = tf.matmul(inputs, Weights) + biases 
  if activation_function is None: 
    outputs = Wx_plus_b 
  else: 
    outputs = activation_function(Wx_plus_b) 
  return outputs

2.2 利用tensorflow建立神经网络模型

输入层大小:2

隐藏层大小:20

输出层大小:1

激活函数:双曲正切函数

学习率:0.1(与书中略有不同)

(具体的搭建过程可参考莫烦的视频,链接就不附上了自行搜索吧......)

###define placeholder for inputs to network 
xs = tf.placeholder(tf.float32, [None, 2]) 
ys = tf.placeholder(tf.float32, [None, 1]) 
###添加隐藏层 
l1 = add_layer("layer1",xs, 2, 20, activation_function=tf.tanh) 
###添加输出层 
prediction = add_layer("layer2",l1, 20, 1, activation_function=tf.tanh) 
###MSE 均方误差 
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction), reduction_indices=[1])) 
###优化器选取 学习率设置 此处学习率置为0.1 
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) 
###tensorflow变量初始化,打开会话 
init = tf.global_variables_initializer()#tensorflow更新后初始化所有变量不再用tf.initialize_all_variables() 
sess = tf.Session() 
sess.run(init)

2.3 训练模型

###训练2000次 
for i in range(2000): 
  sess.run(train_step, feed_dict={xs: x_data, ys: y_label})

3.利用训练好的网络模型寻找分类决策边界

3.1 产生二维空间随机点

def produce_random_data(r,w,d,num): 
  X1 = np.random.uniform(-r-w/2,2*r+w/2, num) 
  X2 = np.random.uniform(-r - w / 2-d, r+w/2, num) 
  X = np.vstack((X1, X2)) 
  return X.transpose()

3.2 用训练好的模型采集决策边界附近的点

向网络输入一个二维空间随机点,计算输出值大于-0.5小于0.5即认为该点落在决策边界附近(双曲正切函数)

def collect_boundary_data(v_xs): 
  global prediction 
  X = np.empty([1,2]) 
  X = list() 
  for i in range(len(v_xs)): 
    x_input = v_xs[i] 
    x_input.shape = [1,2] 
    y_pre = sess.run(prediction, feed_dict={xs: x_input}) 
    if abs(y_pre - 0) < 0.5: 
      X.append(v_xs[i]) 
  return np.array(X)

3.3 用numpy工具将采集到的边界附近点拟合成决策边界曲线,用matplotlib.pyplot画图

###产生空间随机数据 
  X_NUM = produce_random_data(10, 6, -4, 5000) 
  ###边界数据采样 
  X_b = collect_boundary_data(X_NUM) 
  ###画出数据 
  fig = plt.figure() 
  ax = fig.add_subplot(1, 1, 1) 
  ###设置坐标轴名称 
  plt.xlabel('x1') 
  plt.ylabel('x2') 
  ax.scatter(x_data[:, 0], x_data[:, 1], marker='x') 
  ###用采样的边界数据拟合边界曲线 7次曲线最佳 
  z1 = np.polyfit(X_b[:, 0], X_b[:, 1], 7) 
  p1 = np.poly1d(z1) 
  x = X_b[:, 0] 
  x.sort() 
  yvals = p1(x) 
  plt.plot(x, yvals, 'r', label='boundray line') 
  plt.legend(loc=4) 
  #plt.ion() 
  plt.show()

4.效果

利用TensorFlow训练简单的二分类神经网络模型的方法

5.附上源码Github链接

https://github.com/Peakulorain/Practices.git 里的PatternClassification.py文件

另注:分类问题还是用softmax去做吧.....我只是用这做书上的练习而已。

(初学者水平有限,有问题请指出,各位大佬轻喷)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python的Supervisor进行进程监控以及自动启动
May 29 Python
python实现给字典添加条目的方法
Sep 25 Python
python求列表交集的方法汇总
Nov 10 Python
Python迭代器和生成器介绍
Mar 06 Python
python Django批量导入数据
Mar 25 Python
python实现实时监控文件的方法
Aug 26 Python
Python cookbook(数据结构与算法)根据字段将记录分组操作示例
Mar 19 Python
python DataFrame获取行数、列数、索引及第几行第几列的值方法
Apr 08 Python
python将字符串转变成dict格式的实现
Nov 18 Python
python实现图像外边界跟踪操作
Jul 13 Python
Django框架安装及项目创建过程解析
Sep 14 Python
Python万能模板案例之matplotlib绘制直方图的基本配置
Apr 13 Python
python使用Pycharm创建一个Django项目
Mar 05 #Python
python爬虫基本知识
Mar 05 #Python
用tensorflow构建线性回归模型的示例代码
Mar 05 #Python
详解python实现线程安全的单例模式
Mar 05 #Python
分析python动态规划的递归、非递归实现
Mar 04 #Python
python3.x上post发送json数据
Mar 04 #Python
python数据封装json格式数据
Mar 04 #Python
You might like
PHP分页详细讲解(有实例)
2013/10/30 PHP
PHP图片等比例缩放生成缩略图函数分享
2014/06/10 PHP
完善CodeIgniter在IDE中代码提示功能的方法
2014/07/19 PHP
10个简化PHP开发的工具
2014/12/25 PHP
PHP中实现Bloom Filter算法
2015/03/30 PHP
PHP读取Excel内的图片(phpspreadsheet和PHPExcel扩展库)
2019/11/19 PHP
JQUERY1.6 使用方法四 检测浏览器
2011/11/23 Javascript
基于jQuery的计算文本框字数的代码
2012/06/06 Javascript
ajax的hide隐藏问题解决方法
2012/12/11 Javascript
javascript中的undefined和not defined区别示例介绍
2014/02/26 Javascript
深入理解JavaScript系列(44):设计模式之桥接模式详解
2015/03/04 Javascript
由浅入深讲解Javascript继承机制与simple-inheritance源码分析
2015/12/13 Javascript
jQuery插件Easyui设置datagrid的pageNumber导致两次请求问题的解决方法
2016/08/06 Javascript
javascript 正则表达式去空行方法
2017/01/24 Javascript
微信小程序表单验证功能完整实例
2017/12/01 Javascript
分享Angular http interceptors 拦截器使用(推荐)
2019/11/10 Javascript
node.js中对Event Loop事件循环的理解与应用实例分析
2020/02/14 Javascript
[01:17]辉夜杯战队访谈宣传片—EHOME
2015/12/25 DOTA
Python实现端口复用实例代码
2014/07/03 Python
详解Python中用于计算指数的exp()方法
2015/05/14 Python
python 执行shell命令并将结果保存的实例
2018/05/11 Python
基于Pandas读取csv文件Error的总结
2018/06/15 Python
Python面向对象之继承和多态用法分析
2019/06/08 Python
Flask框架搭建虚拟环境的步骤分析
2019/12/21 Python
python使用QQ邮箱实现自动发送邮件
2020/06/22 Python
使用Python中tkinter库简单gui界面制作及打包成exe的操作方法(二)
2020/10/12 Python
CSS实现限制字数功能当对象内文本溢出时显示省略标记
2014/08/20 HTML / CSS
html5 canvas里绘制椭圆并保持线条粗细均匀的技巧
2013/03/25 HTML / CSS
HTML5制作表格样式
2016/11/15 HTML / CSS
国际性能运动服装品牌:Dare 2b
2018/07/27 全球购物
VLAN和VPN有什么区别?分别实现在OSI的第几层?
2014/12/23 面试题
管理失职检讨书
2014/02/12 职场文书
童年读书笔记
2015/06/26 职场文书
浅谈Java父子类加载顺序
2021/08/04 Java/Android
java如何实现socket连接方法封装
2021/09/25 Java/Android
记一次Mysql不走日期字段索引的原因小结
2021/10/24 MySQL