tensorflow建立一个简单的神经网络的方法


Posted in Python onFebruary 10, 2018

本笔记目的是通过tensorflow实现一个两层的神经网络。目的是实现一个二次函数的拟合。

如何添加一层网络

代码如下:

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

注意该函数中是xW+b,而不是Wx+b。所以要注意乘法的顺序。x应该定义为[类别数量, 数据数量], W定义为[数据类别,类别数量]。

创建一些数据

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

numpy的linspace函数能够产生等差数列。start,stop决定等差数列的起止值。endpoint参数指定包不包括终点值。

numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)[source] 
Return evenly spaced numbers over a specified interval. 
Returns num evenly spaced samples, calculated over the interval [start, stop].

tensorflow建立一个简单的神经网络的方法

noise函数为添加噪声所用,这样二次函数的点不会与二次函数曲线完全重合。

numpy的newaxis可以新增一个维度而不需要重新创建相应的shape在赋值,非常方便,如上面的例子中就将x_data从一维变成了二维。

添加占位符,用作输入

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

添加隐藏层和输出层

# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

计算误差,并用梯度下降使得误差最小

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

完整代码如下:

from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
           reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# important step
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()

for i in range(1000):
  # training
  sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
  if i % 50 == 0:
    # to visualize the result and improvement
    try:
      ax.lines.remove(lines[0])
    except Exception:
      pass
    prediction_value = sess.run(prediction, feed_dict={xs: x_data})
    # plot the prediction
    lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
    plt.pause(0.1)

运行结果:

tensorflow建立一个简单的神经网络的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 基础学习第二弹 类属性和实例属性
Aug 27 Python
Linux下通过python访问MySQL、Oracle、SQL Server数据库的方法
Apr 23 Python
python解决Fedora解压zip时中文乱码的方法
Sep 18 Python
Python中实现最小二乘法思路及实现代码
Jan 04 Python
python+tkinter编写电脑桌面放大镜程序实例代码
Jan 16 Python
基于Django URL传参 FORM表单传数据 get post的用法实例
May 28 Python
Python操作mongodb的9个步骤
Jun 04 Python
Python读取Pickle文件信息并计算与当前时间间隔的方法分析
Jan 30 Python
python实现的多任务版udp聊天器功能案例
Nov 13 Python
python实现五子棋游戏(pygame版)
Jan 19 Python
基于Python第三方插件实现西游记章节标注汉语拼音的方法
May 22 Python
python中列表的含义及用法
May 26 Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
Python中协程用法代码详解
Feb 10 #Python
Python实现简单生成验证码功能【基于random模块】
Feb 10 #Python
You might like
php强制用户转向www域名的方法
2015/06/19 PHP
PHP静态延迟绑定和普通静态效率的对比
2017/10/20 PHP
Date对象格式化函数代码
2010/07/17 Javascript
jquery中$each()方法的使用指南
2015/04/30 Javascript
使用AngularJS实现可伸缩的页面切换的方法
2015/06/19 Javascript
Bootstrap每天必学之导航条
2015/11/27 Javascript
js操作数组函数实例小结
2015/12/10 Javascript
AngularJS 最常用的功能汇总
2016/02/17 Javascript
Vue-resource实现ajax请求和跨域请求示例
2017/02/23 Javascript
js实现3D图片环展示效果
2017/03/09 Javascript
javascript实现多张图片左右无缝滚动效果
2017/03/22 Javascript
10道典型的JavaScript面试题
2017/03/22 Javascript
微信小程序开发之map地图实现教程
2017/06/08 Javascript
在一般处理程序(ashx)中弹出js提示语
2017/08/16 Javascript
vue项目中在外部js文件中直接调用vue实例的方法比如说this
2019/04/28 Javascript
JS随机密码生成算法
2019/09/23 Javascript
Vue执行方法,方法获取data值,设置data值,方法传值操作
2020/08/05 Javascript
[01:08:29]DOTA2-DPC中国联赛定级赛 RNG vs Aster BO3第一场 1月9日
2021/03/11 DOTA
Python中常见的数据类型小结
2015/08/29 Python
好用的Python编辑器WingIDE的使用经验总结
2016/08/31 Python
Python实现查找匹配项作处理后再替换回去的方法
2017/06/10 Python
python2.7安装图文教程
2018/03/13 Python
python+opencv识别图片中的圆形
2020/03/25 Python
python requests 库请求带有文件参数的接口实例
2019/01/03 Python
基于torch.where和布尔索引的速度比较
2020/01/02 Python
用 Python 制作地球仪的方法
2020/04/24 Python
利用Pycharm + Django搭建一个简单Python Web项目的步骤
2020/10/22 Python
python中round函数保留两位小数的方法
2020/12/04 Python
网络艺术零售业的先驱者:artrepublic
2017/09/26 全球购物
英国领先的男装设计师服装独立零售商:Repertoire Fashion
2020/10/19 全球购物
三年级评语大全
2014/04/23 职场文书
网站创业计划书
2014/04/30 职场文书
2014年客户经理工作总结
2014/11/20 职场文书
2015年个人现实表现材料
2014/12/10 职场文书
Java基础-封装和继承
2021/07/02 Java/Android
CSS+HTML 实现顶部导航栏功能
2021/08/30 HTML / CSS