tensorflow建立一个简单的神经网络的方法


Posted in Python onFebruary 10, 2018

本笔记目的是通过tensorflow实现一个两层的神经网络。目的是实现一个二次函数的拟合。

如何添加一层网络

代码如下:

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

注意该函数中是xW+b,而不是Wx+b。所以要注意乘法的顺序。x应该定义为[类别数量, 数据数量], W定义为[数据类别,类别数量]。

创建一些数据

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

numpy的linspace函数能够产生等差数列。start,stop决定等差数列的起止值。endpoint参数指定包不包括终点值。

numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)[source] 
Return evenly spaced numbers over a specified interval. 
Returns num evenly spaced samples, calculated over the interval [start, stop].

tensorflow建立一个简单的神经网络的方法

noise函数为添加噪声所用,这样二次函数的点不会与二次函数曲线完全重合。

numpy的newaxis可以新增一个维度而不需要重新创建相应的shape在赋值,非常方便,如上面的例子中就将x_data从一维变成了二维。

添加占位符,用作输入

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

添加隐藏层和输出层

# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

计算误差,并用梯度下降使得误差最小

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

完整代码如下:

from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
           reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# important step
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()

for i in range(1000):
  # training
  sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
  if i % 50 == 0:
    # to visualize the result and improvement
    try:
      ax.lines.remove(lines[0])
    except Exception:
      pass
    prediction_value = sess.run(prediction, feed_dict={xs: x_data})
    # plot the prediction
    lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
    plt.pause(0.1)

运行结果:

tensorflow建立一个简单的神经网络的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python冒泡排序算法的实现代码
Nov 21 Python
python开发中module模块用法实例分析
Nov 12 Python
python微信跳一跳系列之棋子定位颜色识别
Feb 26 Python
python如何统计序列中元素
Jul 31 Python
基于MTCNN/TensorFlow实现人脸检测
May 24 Python
对python中GUI,Label和Button的实例详解
Jun 27 Python
Python 运行.py文件和交互式运行代码的区别详解
Jul 02 Python
python调用webservice接口的实现
Jul 12 Python
基于python3 pyQt5 QtDesignner实现窗口化猜数字游戏功能
Jul 15 Python
python 多进程共享全局变量之Manager()详解
Aug 15 Python
基于python操作ES实例详解
Nov 16 Python
关于python3.7安装matplotlib始终无法成功的问题的解决
Jul 28 Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
Python中协程用法代码详解
Feb 10 #Python
Python实现简单生成验证码功能【基于random模块】
Feb 10 #Python
You might like
php常用ODBC函数集(详细)
2013/06/24 PHP
php字符编码转换之gb2312转为utf8
2013/10/28 PHP
PHP程序漏洞产生的原因分析与防范方法说明
2014/03/06 PHP
浅析PHP 中move_uploaded_file 上传中文文件名失败
2019/04/17 PHP
PHP优化之批量操作MySQL实例分析
2020/04/23 PHP
通过action传过来的值在option获取进行验证的方法
2013/11/14 Javascript
JS点击链接后慢慢展开隐藏着图片的方法
2015/02/17 Javascript
JavaScript图片轮播代码分享
2015/07/31 Javascript
JavaScript+CSS无限极分类效果完整实现方法
2015/12/22 Javascript
JS使用正则截取两个字符串之间的字符串实现方法详解
2017/01/06 Javascript
jQuery使用unlock.js插件实现滑动解锁
2017/04/04 jQuery
原生JS获取元素的位置与尺寸实现方法
2017/10/18 Javascript
Angular2.0/4.0 使用Echarts图表的示例代码
2017/12/07 Javascript
Node Puppeteer图像识别实现百度指数爬虫的示例
2018/02/22 Javascript
vue使用自定义icon图标的方法
2018/05/14 Javascript
微信小程序保存多张图片的实现方法
2019/03/05 Javascript
VUE前端从后台请求过来的数据进行转换数据结构操作
2020/11/11 Javascript
绘制微信小程序验证码功能的实例代码
2021/01/05 Javascript
Python中pygame安装方法图文详解
2015/11/11 Python
Django验证码的生成与使用示例
2017/05/20 Python
Python基于win32ui模块创建弹出式菜单示例
2018/05/09 Python
对pycharm代码整体左移和右移缩进快捷键的介绍
2018/07/16 Python
Python计算一个点到所有点的欧式距离实现方法
2019/07/04 Python
anaconda中更改python版本的方法步骤
2019/07/14 Python
Python操作Excel工作簿的示例代码(\*.xlsx)
2020/03/23 Python
详解CSS3阴影 box-shadow的使用和技巧总结
2016/12/03 HTML / CSS
Html5原生拖拽相关事件简介以及基础实现
2020/11/19 HTML / CSS
Raleigh兰令自行车美国官网:英国凤头牌自行车
2018/01/08 全球购物
Opodo英国旅游网站:预订廉价航班、酒店和汽车租赁
2018/07/14 全球购物
简述你对Statement,PreparedStatement,CallableStatement的理解
2013/03/25 面试题
青年文明号事迹材料
2014/01/18 职场文书
教师申诉制度
2014/01/29 职场文书
竞选学委演讲稿
2014/09/13 职场文书
校运会班级霸气口号
2015/12/24 职场文书
晶体管来复再生式二管收音机
2021/04/22 无线电
判断Python中的Nonetype类型
2021/05/25 Python