利用TensorFlow训练简单的二分类神经网络模型的方法


Posted in Python onMarch 05, 2018

利用TensorFlow实现《神经网络与机器学习》一书中4.7模式分类练习

具体问题是将如下图所示双月牙数据集分类。

利用TensorFlow训练简单的二分类神经网络模型的方法

使用到的工具:

python3.5    tensorflow1.2.1   numpy   matplotlib

1.产生双月环数据集

def produceData(r,w,d,num): 
  r1 = r-w/2 
  r2 = r+w/2 
  #上半圆 
  theta1 = np.random.uniform(0, np.pi ,num) 
  X_Col1 = np.random.uniform( r1*np.cos(theta1),r2*np.cos(theta1),num)[:, np.newaxis] 
  X_Row1 = np.random.uniform(r1*np.sin(theta1),r2*np.sin(theta1),num)[:, np.newaxis] 
  Y_label1 = np.ones(num) #类别标签为1 
  #下半圆 
  theta2 = np.random.uniform(-np.pi, 0 ,num) 
  X_Col2 = (np.random.uniform( r1*np.cos(theta2),r2*np.cos(theta2),num) + r)[:, np.newaxis] 
  X_Row2 = (np.random.uniform(r1 * np.sin(theta2), r2 * np.sin(theta2), num) -d)[:,np.newaxis] 
  Y_label2 = -np.ones(num) #类别标签为-1,注意:由于采取双曲正切函数作为激活函数,类别标签不能为0 
  #合并 
  X_Col = np.vstack((X_Col1, X_Col2)) 
  X_Row = np.vstack((X_Row1, X_Row2)) 
  X = np.hstack((X_Col, X_Row)) 
  Y_label = np.hstack((Y_label1,Y_label2)) 
  Y_label.shape = (num*2 , 1) 
  return X,Y_label

其中r为月环半径,w为月环宽度,d为上下月环距离(与书中一致)

2.利用TensorFlow搭建神经网络模型

2.1 神经网络层添加

def add_layer(layername,inputs, in_size, out_size, activation_function=None): 
  # add one more layer and return the output of this layer 
  with tf.variable_scope(layername,reuse=None): 
    Weights = tf.get_variable("weights",shape=[in_size, out_size], 
                 initializer=tf.truncated_normal_initializer(stddev=0.1)) 
    biases = tf.get_variable("biases", shape=[1, out_size], 
                 initializer=tf.truncated_normal_initializer(stddev=0.1)) 
   
  Wx_plus_b = tf.matmul(inputs, Weights) + biases 
  if activation_function is None: 
    outputs = Wx_plus_b 
  else: 
    outputs = activation_function(Wx_plus_b) 
  return outputs

2.2 利用tensorflow建立神经网络模型

输入层大小:2

隐藏层大小:20

输出层大小:1

激活函数:双曲正切函数

学习率:0.1(与书中略有不同)

(具体的搭建过程可参考莫烦的视频,链接就不附上了自行搜索吧......)

###define placeholder for inputs to network 
xs = tf.placeholder(tf.float32, [None, 2]) 
ys = tf.placeholder(tf.float32, [None, 1]) 
###添加隐藏层 
l1 = add_layer("layer1",xs, 2, 20, activation_function=tf.tanh) 
###添加输出层 
prediction = add_layer("layer2",l1, 20, 1, activation_function=tf.tanh) 
###MSE 均方误差 
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction), reduction_indices=[1])) 
###优化器选取 学习率设置 此处学习率置为0.1 
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) 
###tensorflow变量初始化,打开会话 
init = tf.global_variables_initializer()#tensorflow更新后初始化所有变量不再用tf.initialize_all_variables() 
sess = tf.Session() 
sess.run(init)

2.3 训练模型

###训练2000次 
for i in range(2000): 
  sess.run(train_step, feed_dict={xs: x_data, ys: y_label})

3.利用训练好的网络模型寻找分类决策边界

3.1 产生二维空间随机点

def produce_random_data(r,w,d,num): 
  X1 = np.random.uniform(-r-w/2,2*r+w/2, num) 
  X2 = np.random.uniform(-r - w / 2-d, r+w/2, num) 
  X = np.vstack((X1, X2)) 
  return X.transpose()

3.2 用训练好的模型采集决策边界附近的点

向网络输入一个二维空间随机点,计算输出值大于-0.5小于0.5即认为该点落在决策边界附近(双曲正切函数)

def collect_boundary_data(v_xs): 
  global prediction 
  X = np.empty([1,2]) 
  X = list() 
  for i in range(len(v_xs)): 
    x_input = v_xs[i] 
    x_input.shape = [1,2] 
    y_pre = sess.run(prediction, feed_dict={xs: x_input}) 
    if abs(y_pre - 0) < 0.5: 
      X.append(v_xs[i]) 
  return np.array(X)

3.3 用numpy工具将采集到的边界附近点拟合成决策边界曲线,用matplotlib.pyplot画图

###产生空间随机数据 
  X_NUM = produce_random_data(10, 6, -4, 5000) 
  ###边界数据采样 
  X_b = collect_boundary_data(X_NUM) 
  ###画出数据 
  fig = plt.figure() 
  ax = fig.add_subplot(1, 1, 1) 
  ###设置坐标轴名称 
  plt.xlabel('x1') 
  plt.ylabel('x2') 
  ax.scatter(x_data[:, 0], x_data[:, 1], marker='x') 
  ###用采样的边界数据拟合边界曲线 7次曲线最佳 
  z1 = np.polyfit(X_b[:, 0], X_b[:, 1], 7) 
  p1 = np.poly1d(z1) 
  x = X_b[:, 0] 
  x.sort() 
  yvals = p1(x) 
  plt.plot(x, yvals, 'r', label='boundray line') 
  plt.legend(loc=4) 
  #plt.ion() 
  plt.show()

4.效果

利用TensorFlow训练简单的二分类神经网络模型的方法

5.附上源码Github链接

https://github.com/Peakulorain/Practices.git 里的PatternClassification.py文件

另注:分类问题还是用softmax去做吧.....我只是用这做书上的练习而已。

(初学者水平有限,有问题请指出,各位大佬轻喷)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python连接MySQL、MongoDB、Redis、memcache等数据库的方法
Nov 15 Python
Python实现设置windows桌面壁纸代码分享
Mar 28 Python
win8下python3.4安装和环境配置图文教程
Jul 31 Python
django使用LDAP验证的方法示例
Dec 10 Python
对python中Librosa的mfcc步骤详解
Jan 09 Python
python3 打印输出字典中特定的某个key的方法示例
Jul 06 Python
解决python中导入win32com.client出错的问题
Jul 26 Python
python3 实现口罩抽签的功能
Mar 11 Python
django 实现手动存储文件到model的FileField
Mar 30 Python
关于python 的legend图例,参数使用说明
Apr 17 Python
Python基础之Socket通信原理
Apr 22 Python
Python 中的Sympy详细使用
Aug 07 Python
python使用Pycharm创建一个Django项目
Mar 05 #Python
python爬虫基本知识
Mar 05 #Python
用tensorflow构建线性回归模型的示例代码
Mar 05 #Python
详解python实现线程安全的单例模式
Mar 05 #Python
分析python动态规划的递归、非递归实现
Mar 04 #Python
python3.x上post发送json数据
Mar 04 #Python
python数据封装json格式数据
Mar 04 #Python
You might like
PHP+MYSQL 出现乱码的解决方法
2008/08/08 PHP
浅谈本地WAMP环境的搭建
2015/05/13 PHP
FCK调用方法..
2006/12/21 Javascript
javascript错误的认识不用关心内存管理
2012/12/15 Javascript
轻松创建nodejs服务器(5):事件处理程序
2014/12/18 NodeJs
JS实现从表格中动态删除指定行的方法
2015/03/31 Javascript
使用BootStrap建立响应式网页——通栏轮播图(carousel)
2016/12/21 Javascript
原生js实现网页顶部自动下拉/收缩广告效果
2017/01/20 Javascript
一篇看懂vuejs的状态管理神器 vuex状态管理模式
2017/04/20 Javascript
详解jQuery中的easyui
2018/09/02 jQuery
jquery3和layui冲突导致使用layui.layer.full弹出全屏iframe窗口时高度152px问题
2019/05/12 jQuery
微信小程序登录时如何获取input框中的内容
2019/12/04 Javascript
python正则表达式re模块详细介绍
2014/05/29 Python
Python实现将xml导入至excel
2015/11/20 Python
python爬虫框架scrapy实战之爬取京东商城进阶篇
2017/04/24 Python
Python 安装setuptools和pip工具操作方法(必看)
2017/05/22 Python
Python pymongo模块用法示例
2018/03/31 Python
解读python logging模块的使用方法
2018/04/17 Python
Python玩转加密的技巧【推荐】
2019/05/13 Python
Python实现决策树并且使用Graphviz可视化的例子
2019/08/09 Python
Python代码生成视频的缩略图的实例讲解
2019/12/22 Python
Python定义一个函数的方法
2020/06/15 Python
python中plt.imshow与cv2.imshow显示颜色问题
2020/07/16 Python
利用HTML5 Canvas制作一个简单的打飞机游戏
2015/05/11 HTML / CSS
英国第一蛋白粉品牌:Myprotein
2016/09/14 全球购物
Contém1g官网:巴西彩妆品牌
2020/01/17 全球购物
GWT的应用有哪两种部署模式
2012/12/21 面试题
机械制造与自动化应届生求职信
2013/11/16 职场文书
银行办理业务介绍信
2014/01/18 职场文书
大学生新学期计划书
2014/04/28 职场文书
儿童生日会策划方案
2014/05/15 职场文书
关于运动会的口号
2014/06/07 职场文书
经典毕业生求职信
2014/07/12 职场文书
发展党员工作情况汇报
2014/10/28 职场文书
回复函格式及范文
2015/07/14 职场文书
sql注入报错之注入原理实例解析
2022/06/10 MySQL