利用TensorFlow训练简单的二分类神经网络模型的方法


Posted in Python onMarch 05, 2018

利用TensorFlow实现《神经网络与机器学习》一书中4.7模式分类练习

具体问题是将如下图所示双月牙数据集分类。

利用TensorFlow训练简单的二分类神经网络模型的方法

使用到的工具:

python3.5    tensorflow1.2.1   numpy   matplotlib

1.产生双月环数据集

def produceData(r,w,d,num): 
  r1 = r-w/2 
  r2 = r+w/2 
  #上半圆 
  theta1 = np.random.uniform(0, np.pi ,num) 
  X_Col1 = np.random.uniform( r1*np.cos(theta1),r2*np.cos(theta1),num)[:, np.newaxis] 
  X_Row1 = np.random.uniform(r1*np.sin(theta1),r2*np.sin(theta1),num)[:, np.newaxis] 
  Y_label1 = np.ones(num) #类别标签为1 
  #下半圆 
  theta2 = np.random.uniform(-np.pi, 0 ,num) 
  X_Col2 = (np.random.uniform( r1*np.cos(theta2),r2*np.cos(theta2),num) + r)[:, np.newaxis] 
  X_Row2 = (np.random.uniform(r1 * np.sin(theta2), r2 * np.sin(theta2), num) -d)[:,np.newaxis] 
  Y_label2 = -np.ones(num) #类别标签为-1,注意:由于采取双曲正切函数作为激活函数,类别标签不能为0 
  #合并 
  X_Col = np.vstack((X_Col1, X_Col2)) 
  X_Row = np.vstack((X_Row1, X_Row2)) 
  X = np.hstack((X_Col, X_Row)) 
  Y_label = np.hstack((Y_label1,Y_label2)) 
  Y_label.shape = (num*2 , 1) 
  return X,Y_label

其中r为月环半径,w为月环宽度,d为上下月环距离(与书中一致)

2.利用TensorFlow搭建神经网络模型

2.1 神经网络层添加

def add_layer(layername,inputs, in_size, out_size, activation_function=None): 
  # add one more layer and return the output of this layer 
  with tf.variable_scope(layername,reuse=None): 
    Weights = tf.get_variable("weights",shape=[in_size, out_size], 
                 initializer=tf.truncated_normal_initializer(stddev=0.1)) 
    biases = tf.get_variable("biases", shape=[1, out_size], 
                 initializer=tf.truncated_normal_initializer(stddev=0.1)) 
   
  Wx_plus_b = tf.matmul(inputs, Weights) + biases 
  if activation_function is None: 
    outputs = Wx_plus_b 
  else: 
    outputs = activation_function(Wx_plus_b) 
  return outputs

2.2 利用tensorflow建立神经网络模型

输入层大小:2

隐藏层大小:20

输出层大小:1

激活函数:双曲正切函数

学习率:0.1(与书中略有不同)

(具体的搭建过程可参考莫烦的视频,链接就不附上了自行搜索吧......)

###define placeholder for inputs to network 
xs = tf.placeholder(tf.float32, [None, 2]) 
ys = tf.placeholder(tf.float32, [None, 1]) 
###添加隐藏层 
l1 = add_layer("layer1",xs, 2, 20, activation_function=tf.tanh) 
###添加输出层 
prediction = add_layer("layer2",l1, 20, 1, activation_function=tf.tanh) 
###MSE 均方误差 
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction), reduction_indices=[1])) 
###优化器选取 学习率设置 此处学习率置为0.1 
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) 
###tensorflow变量初始化,打开会话 
init = tf.global_variables_initializer()#tensorflow更新后初始化所有变量不再用tf.initialize_all_variables() 
sess = tf.Session() 
sess.run(init)

2.3 训练模型

###训练2000次 
for i in range(2000): 
  sess.run(train_step, feed_dict={xs: x_data, ys: y_label})

3.利用训练好的网络模型寻找分类决策边界

3.1 产生二维空间随机点

def produce_random_data(r,w,d,num): 
  X1 = np.random.uniform(-r-w/2,2*r+w/2, num) 
  X2 = np.random.uniform(-r - w / 2-d, r+w/2, num) 
  X = np.vstack((X1, X2)) 
  return X.transpose()

3.2 用训练好的模型采集决策边界附近的点

向网络输入一个二维空间随机点,计算输出值大于-0.5小于0.5即认为该点落在决策边界附近(双曲正切函数)

def collect_boundary_data(v_xs): 
  global prediction 
  X = np.empty([1,2]) 
  X = list() 
  for i in range(len(v_xs)): 
    x_input = v_xs[i] 
    x_input.shape = [1,2] 
    y_pre = sess.run(prediction, feed_dict={xs: x_input}) 
    if abs(y_pre - 0) < 0.5: 
      X.append(v_xs[i]) 
  return np.array(X)

3.3 用numpy工具将采集到的边界附近点拟合成决策边界曲线,用matplotlib.pyplot画图

###产生空间随机数据 
  X_NUM = produce_random_data(10, 6, -4, 5000) 
  ###边界数据采样 
  X_b = collect_boundary_data(X_NUM) 
  ###画出数据 
  fig = plt.figure() 
  ax = fig.add_subplot(1, 1, 1) 
  ###设置坐标轴名称 
  plt.xlabel('x1') 
  plt.ylabel('x2') 
  ax.scatter(x_data[:, 0], x_data[:, 1], marker='x') 
  ###用采样的边界数据拟合边界曲线 7次曲线最佳 
  z1 = np.polyfit(X_b[:, 0], X_b[:, 1], 7) 
  p1 = np.poly1d(z1) 
  x = X_b[:, 0] 
  x.sort() 
  yvals = p1(x) 
  plt.plot(x, yvals, 'r', label='boundray line') 
  plt.legend(loc=4) 
  #plt.ion() 
  plt.show()

4.效果

利用TensorFlow训练简单的二分类神经网络模型的方法

5.附上源码Github链接

https://github.com/Peakulorain/Practices.git 里的PatternClassification.py文件

另注:分类问题还是用softmax去做吧.....我只是用这做书上的练习而已。

(初学者水平有限,有问题请指出,各位大佬轻喷)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python脚本实现Web漏洞扫描工具
Oct 25 Python
OpenCV2.3.1+Python2.7.3+Numpy等的配置解析
Jan 05 Python
Python KMeans聚类问题分析
Feb 23 Python
python距离测量的方法
Mar 06 Python
Python 机器学习库 NumPy入门教程
Apr 19 Python
Python Requests模拟登录实现图书馆座位自动预约
Apr 27 Python
对python中for、if、while的区别与比较方法
Jun 25 Python
python设置值及NaN值处理方法
Jul 03 Python
python3 动态模块导入与全局变量使用实例
Dec 22 Python
Ranorex通过Python将报告发送到邮箱的方法
Jan 12 Python
Python接口测试结果集实现封装比较
May 01 Python
使用已经得到的keras模型识别自己手写的数字方式
Jun 29 Python
python使用Pycharm创建一个Django项目
Mar 05 #Python
python爬虫基本知识
Mar 05 #Python
用tensorflow构建线性回归模型的示例代码
Mar 05 #Python
详解python实现线程安全的单例模式
Mar 05 #Python
分析python动态规划的递归、非递归实现
Mar 04 #Python
python3.x上post发送json数据
Mar 04 #Python
python数据封装json格式数据
Mar 04 #Python
You might like
php打造属于自己的MVC框架
2012/03/07 PHP
非常实用的php弹出错误警告函数扩展性强
2014/01/17 PHP
ThinkPHP模板中判断volist循环的最后一条记录的验证方法
2014/07/01 PHP
CI框架中数据库操作函数$this-&gt;db-&gt;where()相关用法总结
2016/05/17 PHP
Yii使用smsto短信接口的函数demo示例
2016/07/13 PHP
PHP实现四种基础排序算法的运行时间比较(推荐)
2016/08/11 PHP
使用ucenter实现多站点同步登录的讲解
2019/03/21 PHP
Laravel修改验证提示信息为中文的示例
2019/10/23 PHP
关于javascript function对象那些迷惑分析
2011/10/24 Javascript
JavaScript的模块化:封装(闭包),继承(原型) 介绍
2013/07/22 Javascript
动态加载iframe时get请求传递中文参数乱码解决方法
2014/05/07 Javascript
JQuery显示隐藏页面元素的方法总结
2015/04/16 Javascript
jQuery插件bgStretcher.js实现全屏背景特效
2015/06/05 Javascript
详解获取jq ul第一个li定位的四种解决方案
2016/11/23 Javascript
解析JavaScript模仿块级作用域
2016/12/29 Javascript
基于js 本地存储(详解)
2017/08/16 Javascript
利用canvas中toDataURL()将图片转为dataURL(base64)的方法详解
2017/11/20 Javascript
浅析从vue源码看观察者模式
2018/01/29 Javascript
Vue实现PopupWindow组件详解
2018/04/28 Javascript
vue组件表单数据回显验证及提交的实例代码
2018/08/30 Javascript
JQuery判断radio单选框是否选中并获取值的方法
2019/01/17 jQuery
利用js-cookie实现前端设置缓存数据定时失效
2019/06/18 Javascript
微信小程序之导航滑块视图容器功能的实现代码(简单两步)
2020/06/19 Javascript
python使用socket向客户端发送数据的方法
2015/04/29 Python
python绘制无向图度分布曲线示例
2019/11/22 Python
解决django无法访问本地static文件(js,css,img)网页里js,cs都加载不了
2020/04/07 Python
Python基于DB-API操作MySQL数据库过程解析
2020/04/23 Python
Selenium获取登录Cookies并添加Cookies自动登录的方法
2020/12/04 Python
澳大利亚快时尚鞋类市场:Billini
2018/05/20 全球购物
建筑设计所实习生自我鉴定
2013/09/25 职场文书
军训心得体会
2013/12/31 职场文书
教师节感恩老师演讲稿
2014/08/28 职场文书
银行领导班子四风对照检查材料
2014/09/27 职场文书
刑事附带民事诉讼答辩状
2015/05/22 职场文书
浅谈如何写好演讲稿?
2019/06/12 职场文书
详解Go与PHP的语法对比
2021/05/29 PHP