tensorflow建立一个简单的神经网络的方法


Posted in Python onFebruary 10, 2018

本笔记目的是通过tensorflow实现一个两层的神经网络。目的是实现一个二次函数的拟合。

如何添加一层网络

代码如下:

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

注意该函数中是xW+b,而不是Wx+b。所以要注意乘法的顺序。x应该定义为[类别数量, 数据数量], W定义为[数据类别,类别数量]。

创建一些数据

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

numpy的linspace函数能够产生等差数列。start,stop决定等差数列的起止值。endpoint参数指定包不包括终点值。

numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)[source] 
Return evenly spaced numbers over a specified interval. 
Returns num evenly spaced samples, calculated over the interval [start, stop].

tensorflow建立一个简单的神经网络的方法

noise函数为添加噪声所用,这样二次函数的点不会与二次函数曲线完全重合。

numpy的newaxis可以新增一个维度而不需要重新创建相应的shape在赋值,非常方便,如上面的例子中就将x_data从一维变成了二维。

添加占位符,用作输入

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

添加隐藏层和输出层

# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

计算误差,并用梯度下降使得误差最小

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

完整代码如下:

from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
           reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# important step
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()

for i in range(1000):
  # training
  sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
  if i % 50 == 0:
    # to visualize the result and improvement
    try:
      ax.lines.remove(lines[0])
    except Exception:
      pass
    prediction_value = sess.run(prediction, feed_dict={xs: x_data})
    # plot the prediction
    lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
    plt.pause(0.1)

运行结果:

tensorflow建立一个简单的神经网络的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
举例讲解Django中数据模型访问外键值的方法
Jul 21 Python
Python程序中设置HTTP代理
Nov 06 Python
python基础教程之五种数据类型详解
Jan 12 Python
总结python实现父类调用两种方法的不同
Jan 15 Python
Numpy 将二维图像矩阵转换为一维向量的方法
Jun 05 Python
python之线程通过信号pyqtSignal刷新ui的方法
Jan 11 Python
Python使用sqlalchemy模块连接数据库操作示例
Mar 13 Python
Python后台开发Django会话控制的实现
Apr 15 Python
简单了解python中的与或非运算
Sep 18 Python
Pycharm+Python+PyQt5使用详解
Sep 25 Python
基于python实现FTP文件上传与下载操作(ftp&sftp协议)
Apr 01 Python
Win10下用Anaconda安装TensorFlow(图文教程)
Jun 18 Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
Python中协程用法代码详解
Feb 10 #Python
Python实现简单生成验证码功能【基于random模块】
Feb 10 #Python
You might like
PHP获取网卡地址的代码
2008/04/09 PHP
php adodb连接mssql解决乱码问题
2009/06/12 PHP
PHP url 加密解密函数代码
2011/08/26 PHP
ThinkPHP使用smarty模板引擎的方法
2014/07/01 PHP
PHP通过内置函数memory_get_usage()获取内存使用情况
2014/11/20 PHP
[原创]php正则删除img标签的方法示例
2017/05/27 PHP
PHP ob缓存以及ob函数原理实例解析
2020/11/13 PHP
JS实现仿微博可关闭弹出层效果
2015/09/21 Javascript
js验证真实姓名与身份证号是否匹配
2015/10/13 Javascript
js实现网页的两个input标签内的数值加减(示例代码)
2017/08/15 Javascript
浅谈Node Inspector 代理实现
2017/10/19 Javascript
JS/jQuery实现DIV延时几秒后消失或显示的方法
2018/02/12 jQuery
vue组件jsx语法的具体使用
2018/05/21 Javascript
vue.js使用watch监听路由变化的方法
2018/07/08 Javascript
基于jQuery拖拽事件的封装
2020/11/29 jQuery
Python黑帽编程 3.4 跨越VLAN详解
2016/09/28 Python
详解Python在七牛云平台的应用(一)
2017/12/05 Python
Python实现简单网页图片抓取完整代码实例
2017/12/15 Python
Python冲顶大会 快来答题!
2018/01/17 Python
Python/ArcPy遍历指定目录中的MDB文件方法
2018/10/27 Python
pytorch载入预训练模型后,实现训练指定层
2020/01/06 Python
如何在django中实现分页功能
2020/04/22 Python
class类在python中获取金融数据的实例方法
2020/12/10 Python
HTML5 Canvas图像模糊完美解决办法
2018/02/06 HTML / CSS
华为慧通笔试题
2016/04/22 面试题
Java中会存在内存泄漏吗,请简单描述
2016/12/22 面试题
我们在web应用开发过程中经常遇到输出某种编码的字符,如iso8859-1等,如何输出一个某种编码的字符串?
2014/03/30 面试题
工程开工庆典邀请函
2014/02/01 职场文书
大专毕业自我鉴定
2014/02/04 职场文书
三关爱志愿服务活动方案
2014/08/17 职场文书
2014年财政局工作总结
2014/12/09 职场文书
2015年安全生产工作总结范文
2015/04/02 职场文书
2015年医药代表工作总结
2015/04/25 职场文书
2016中学教师读书心得体会
2016/01/13 职场文书
原生JS中应该禁止出现的写法
2021/05/05 Javascript
MySQL外键约束(Foreign Key)案例详解
2022/06/28 MySQL