利用TensorFlow训练简单的二分类神经网络模型的方法


Posted in Python onMarch 05, 2018

利用TensorFlow实现《神经网络与机器学习》一书中4.7模式分类练习

具体问题是将如下图所示双月牙数据集分类。

利用TensorFlow训练简单的二分类神经网络模型的方法

使用到的工具:

python3.5    tensorflow1.2.1   numpy   matplotlib

1.产生双月环数据集

def produceData(r,w,d,num): 
  r1 = r-w/2 
  r2 = r+w/2 
  #上半圆 
  theta1 = np.random.uniform(0, np.pi ,num) 
  X_Col1 = np.random.uniform( r1*np.cos(theta1),r2*np.cos(theta1),num)[:, np.newaxis] 
  X_Row1 = np.random.uniform(r1*np.sin(theta1),r2*np.sin(theta1),num)[:, np.newaxis] 
  Y_label1 = np.ones(num) #类别标签为1 
  #下半圆 
  theta2 = np.random.uniform(-np.pi, 0 ,num) 
  X_Col2 = (np.random.uniform( r1*np.cos(theta2),r2*np.cos(theta2),num) + r)[:, np.newaxis] 
  X_Row2 = (np.random.uniform(r1 * np.sin(theta2), r2 * np.sin(theta2), num) -d)[:,np.newaxis] 
  Y_label2 = -np.ones(num) #类别标签为-1,注意:由于采取双曲正切函数作为激活函数,类别标签不能为0 
  #合并 
  X_Col = np.vstack((X_Col1, X_Col2)) 
  X_Row = np.vstack((X_Row1, X_Row2)) 
  X = np.hstack((X_Col, X_Row)) 
  Y_label = np.hstack((Y_label1,Y_label2)) 
  Y_label.shape = (num*2 , 1) 
  return X,Y_label

其中r为月环半径,w为月环宽度,d为上下月环距离(与书中一致)

2.利用TensorFlow搭建神经网络模型

2.1 神经网络层添加

def add_layer(layername,inputs, in_size, out_size, activation_function=None): 
  # add one more layer and return the output of this layer 
  with tf.variable_scope(layername,reuse=None): 
    Weights = tf.get_variable("weights",shape=[in_size, out_size], 
                 initializer=tf.truncated_normal_initializer(stddev=0.1)) 
    biases = tf.get_variable("biases", shape=[1, out_size], 
                 initializer=tf.truncated_normal_initializer(stddev=0.1)) 
   
  Wx_plus_b = tf.matmul(inputs, Weights) + biases 
  if activation_function is None: 
    outputs = Wx_plus_b 
  else: 
    outputs = activation_function(Wx_plus_b) 
  return outputs

2.2 利用tensorflow建立神经网络模型

输入层大小:2

隐藏层大小:20

输出层大小:1

激活函数:双曲正切函数

学习率:0.1(与书中略有不同)

(具体的搭建过程可参考莫烦的视频,链接就不附上了自行搜索吧......)

###define placeholder for inputs to network 
xs = tf.placeholder(tf.float32, [None, 2]) 
ys = tf.placeholder(tf.float32, [None, 1]) 
###添加隐藏层 
l1 = add_layer("layer1",xs, 2, 20, activation_function=tf.tanh) 
###添加输出层 
prediction = add_layer("layer2",l1, 20, 1, activation_function=tf.tanh) 
###MSE 均方误差 
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction), reduction_indices=[1])) 
###优化器选取 学习率设置 此处学习率置为0.1 
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) 
###tensorflow变量初始化,打开会话 
init = tf.global_variables_initializer()#tensorflow更新后初始化所有变量不再用tf.initialize_all_variables() 
sess = tf.Session() 
sess.run(init)

2.3 训练模型

###训练2000次 
for i in range(2000): 
  sess.run(train_step, feed_dict={xs: x_data, ys: y_label})

3.利用训练好的网络模型寻找分类决策边界

3.1 产生二维空间随机点

def produce_random_data(r,w,d,num): 
  X1 = np.random.uniform(-r-w/2,2*r+w/2, num) 
  X2 = np.random.uniform(-r - w / 2-d, r+w/2, num) 
  X = np.vstack((X1, X2)) 
  return X.transpose()

3.2 用训练好的模型采集决策边界附近的点

向网络输入一个二维空间随机点,计算输出值大于-0.5小于0.5即认为该点落在决策边界附近(双曲正切函数)

def collect_boundary_data(v_xs): 
  global prediction 
  X = np.empty([1,2]) 
  X = list() 
  for i in range(len(v_xs)): 
    x_input = v_xs[i] 
    x_input.shape = [1,2] 
    y_pre = sess.run(prediction, feed_dict={xs: x_input}) 
    if abs(y_pre - 0) < 0.5: 
      X.append(v_xs[i]) 
  return np.array(X)

3.3 用numpy工具将采集到的边界附近点拟合成决策边界曲线,用matplotlib.pyplot画图

###产生空间随机数据 
  X_NUM = produce_random_data(10, 6, -4, 5000) 
  ###边界数据采样 
  X_b = collect_boundary_data(X_NUM) 
  ###画出数据 
  fig = plt.figure() 
  ax = fig.add_subplot(1, 1, 1) 
  ###设置坐标轴名称 
  plt.xlabel('x1') 
  plt.ylabel('x2') 
  ax.scatter(x_data[:, 0], x_data[:, 1], marker='x') 
  ###用采样的边界数据拟合边界曲线 7次曲线最佳 
  z1 = np.polyfit(X_b[:, 0], X_b[:, 1], 7) 
  p1 = np.poly1d(z1) 
  x = X_b[:, 0] 
  x.sort() 
  yvals = p1(x) 
  plt.plot(x, yvals, 'r', label='boundray line') 
  plt.legend(loc=4) 
  #plt.ion() 
  plt.show()

4.效果

利用TensorFlow训练简单的二分类神经网络模型的方法

5.附上源码Github链接

https://github.com/Peakulorain/Practices.git 里的PatternClassification.py文件

另注:分类问题还是用softmax去做吧.....我只是用这做书上的练习而已。

(初学者水平有限,有问题请指出,各位大佬轻喷)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python getopt 参数处理小示例
Jun 09 Python
Python的装饰器用法学习笔记
Jun 24 Python
Python操作使用MySQL数据库的实例代码
May 25 Python
微信跳一跳小游戏python脚本
Jan 05 Python
关于python写入文件自动换行的问题
Jun 23 Python
python处理multipart/form-data的请求方法
Dec 26 Python
python3.7 sys模块的具体使用
Jul 22 Python
基于Python中的yield表达式介绍
Nov 19 Python
Python读取实时数据流示例
Dec 02 Python
Python 实现数组相减示例
Dec 27 Python
在TensorFlow中实现矩阵维度扩展
May 22 Python
Python设计密码强度校验程序
Jul 30 Python
python使用Pycharm创建一个Django项目
Mar 05 #Python
python爬虫基本知识
Mar 05 #Python
用tensorflow构建线性回归模型的示例代码
Mar 05 #Python
详解python实现线程安全的单例模式
Mar 05 #Python
分析python动态规划的递归、非递归实现
Mar 04 #Python
python3.x上post发送json数据
Mar 04 #Python
python数据封装json格式数据
Mar 04 #Python
You might like
3种平台下安装php4经验点滴
2006/10/09 PHP
php目录管理函数小结
2008/09/10 PHP
php5 pdo新改动加载注意事项
2008/09/11 PHP
PHP Ajax中文乱码问题解决方法
2009/02/27 PHP
PHP与Ajax相结合实现登录验证小Demo
2016/03/16 PHP
PHP排序算法之归并排序(Merging Sort)实例详解
2018/04/21 PHP
js获取图片长和宽度的代码
2009/11/24 Javascript
javascript 哈希表(hashtable)的简单实现
2010/01/20 Javascript
改写一个简单的菜单 弹性大小
2010/12/02 Javascript
仿jQuery的siblings效果的js代码
2011/08/09 Javascript
js中有关IE版本检测
2012/01/04 Javascript
NodeJS与Mysql的交互示例代码
2013/08/18 NodeJs
jquery实现页面关键词高亮显示的方法
2015/03/12 Javascript
Node.js中Request模块处理HTTP协议请求的基本使用教程
2016/03/31 Javascript
JS中取二维数组中最大值的方法汇总
2016/04/17 Javascript
JavaScript模板引擎Template.js使用详解
2016/12/15 Javascript
js实现无缝滚动图
2017/02/22 Javascript
关于JavaScript中的this指向问题总结篇
2017/07/23 Javascript
Bootstrap 树控件使用经验分享(图文解说)
2017/11/06 Javascript
JavaScript中EventLoop介绍
2018/01/22 Javascript
jQuery AJAX 方法success()后台传来的4种数据详解
2018/08/08 jQuery
Vue-router 切换组件页面时进入进出动画方法
2018/09/01 Javascript
js+html+css实现手动轮播和自动轮播
2020/12/30 Javascript
Django文件存储 默认存储系统解析
2019/08/02 Python
Python爬取365好书中小说代码实例
2020/02/28 Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
2020/05/26 Python
什么是继承
2013/12/07 面试题
副护士长竞聘演讲稿
2014/04/30 职场文书
法律顾问服务方案
2014/05/15 职场文书
体育节口号
2014/06/19 职场文书
爱护公共设施倡议书
2014/08/29 职场文书
群众路线查摆问题及整改措施
2014/10/10 职场文书
结婚堵门保证书
2015/05/08 职场文书
优秀共产党员事迹材料2016
2016/02/29 职场文书
应用最多的公文《通知》如何写?
2019/04/02 职场文书
vue.js 使用原生js实现轮播图
2022/04/26 Vue.js