tensorflow建立一个简单的神经网络的方法


Posted in Python onFebruary 10, 2018

本笔记目的是通过tensorflow实现一个两层的神经网络。目的是实现一个二次函数的拟合。

如何添加一层网络

代码如下:

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

注意该函数中是xW+b,而不是Wx+b。所以要注意乘法的顺序。x应该定义为[类别数量, 数据数量], W定义为[数据类别,类别数量]。

创建一些数据

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

numpy的linspace函数能够产生等差数列。start,stop决定等差数列的起止值。endpoint参数指定包不包括终点值。

numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)[source] 
Return evenly spaced numbers over a specified interval. 
Returns num evenly spaced samples, calculated over the interval [start, stop].

tensorflow建立一个简单的神经网络的方法

noise函数为添加噪声所用,这样二次函数的点不会与二次函数曲线完全重合。

numpy的newaxis可以新增一个维度而不需要重新创建相应的shape在赋值,非常方便,如上面的例子中就将x_data从一维变成了二维。

添加占位符,用作输入

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

添加隐藏层和输出层

# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

计算误差,并用梯度下降使得误差最小

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

完整代码如下:

from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
           reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# important step
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()

for i in range(1000):
  # training
  sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
  if i % 50 == 0:
    # to visualize the result and improvement
    try:
      ax.lines.remove(lines[0])
    except Exception:
      pass
    prediction_value = sess.run(prediction, feed_dict={xs: x_data})
    # plot the prediction
    lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
    plt.pause(0.1)

运行结果:

tensorflow建立一个简单的神经网络的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中的join()函数的用法
Apr 07 Python
在Python中使用swapCase()方法转换大小写的教程
May 20 Python
Python中字符串的修改及传参详解
Nov 30 Python
Python机器学习库scikit-learn安装与基本使用教程
Jun 25 Python
python 3.3 下载固定链接文件并保存的方法
Dec 18 Python
Python告诉你木马程序的键盘记录原理
Feb 02 Python
如何使用pyinstaller打包32位的exe程序
May 26 Python
python的pytest框架之命令行参数详解(下)
Jun 27 Python
python Django中models进行模糊查询的示例
Jul 18 Python
调用其他python脚本文件里面的类和方法过程解析
Nov 15 Python
python代码能做成软件吗
Jul 24 Python
selenium如何定位span元素的实现
Jan 13 Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
Python中协程用法代码详解
Feb 10 #Python
Python实现简单生成验证码功能【基于random模块】
Feb 10 #Python
You might like
PHP+MYSQL会员系统的登陆即权限判断实现代码
2011/09/23 PHP
深入理解ob_flush和flush的区别(ob_flush()与flush()使用方法)
2013/02/06 PHP
php加密算法之实现可逆加密算法和解密分享
2014/01/21 PHP
本地计算机无法启动Apache故障处理
2014/08/08 PHP
phpStorm+XDebug+chrome 配置详解
2019/04/01 PHP
Gambit vs ForZe BO3 第三场 2.13
2021/03/10 DOTA
jquery实现网站超链接和图片提示效果
2013/03/21 Javascript
查找Oracle高消耗语句的方法
2014/03/22 Javascript
js树插件zTree获取所有选中节点数据的方法
2015/01/28 Javascript
JavaScript 七大技巧(二)
2015/12/13 Javascript
全面解析Bootstrap布局组件应用
2016/02/22 Javascript
浅谈Node.js轻量级Web框架Express4.x使用指南
2017/05/03 Javascript
Vue利用路由钩子token过期后跳转到登录页的实例
2017/10/26 Javascript
webpack4.x下babel的安装、配置及使用详解
2019/03/07 Javascript
Python 的 with 语句详解
2014/06/13 Python
利用Python中的输入和输出功能进行读取和写入的教程
2015/04/14 Python
Windows下为Python安装Matplotlib模块
2015/11/06 Python
Python中强大的命令行库click入门教程
2016/12/26 Python
利用Python在一个文件的头部插入数据的实例
2018/05/02 Python
详解基于django实现的webssh简单例子
2018/07/17 Python
Python3爬虫之自动查询天气并实现语音播报
2019/02/21 Python
Python字符串中删除特定字符的方法
2020/01/15 Python
Python连接Oracle之环境配置、实例代码及报错解决方法详解
2020/02/11 Python
Django ORM 查询表中某列字段值的方法
2020/04/30 Python
Python Selenium异常处理的实例分析
2021/02/28 Python
HTML高亮关键字的实现代码
2018/10/22 HTML / CSS
西班牙在线宠物食品和配件商店:bitiba
2019/10/11 全球购物
Static Nested Class 和 Inner Class的不同
2013/11/28 面试题
JSP&Servlet技术面试题
2015/05/21 面试题
人事部主管岗位职责
2013/12/26 职场文书
史上最牛的辞职信
2015/02/28 职场文书
跳高加油稿
2015/07/21 职场文书
2016年大学生社区服务活动总结
2016/04/06 职场文书
pandas提升计算效率的一些方法汇总
2021/05/30 Python
MySQL中优化SQL语句的方法(show status、explain分析服务器状态信息)
2022/04/09 MySQL
Golang流模式之grpc的四种数据流
2022/04/13 Golang