tensorflow建立一个简单的神经网络的方法


Posted in Python onFebruary 10, 2018

本笔记目的是通过tensorflow实现一个两层的神经网络。目的是实现一个二次函数的拟合。

如何添加一层网络

代码如下:

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

注意该函数中是xW+b,而不是Wx+b。所以要注意乘法的顺序。x应该定义为[类别数量, 数据数量], W定义为[数据类别,类别数量]。

创建一些数据

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

numpy的linspace函数能够产生等差数列。start,stop决定等差数列的起止值。endpoint参数指定包不包括终点值。

numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)[source] 
Return evenly spaced numbers over a specified interval. 
Returns num evenly spaced samples, calculated over the interval [start, stop].

tensorflow建立一个简单的神经网络的方法

noise函数为添加噪声所用,这样二次函数的点不会与二次函数曲线完全重合。

numpy的newaxis可以新增一个维度而不需要重新创建相应的shape在赋值,非常方便,如上面的例子中就将x_data从一维变成了二维。

添加占位符,用作输入

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

添加隐藏层和输出层

# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

计算误差,并用梯度下降使得误差最小

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

完整代码如下:

from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
           reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# important step
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()

for i in range(1000):
  # training
  sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
  if i % 50 == 0:
    # to visualize the result and improvement
    try:
      ax.lines.remove(lines[0])
    except Exception:
      pass
    prediction_value = sess.run(prediction, feed_dict={xs: x_data})
    # plot the prediction
    lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
    plt.pause(0.1)

运行结果:

tensorflow建立一个简单的神经网络的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python删除nginx缓存文件示例(python文件操作)
Mar 26 Python
跟老齐学Python之Import 模块
Oct 13 Python
各个系统下的Python解释器相关安装方法
Oct 12 Python
python+django加载静态网页模板解析
Dec 12 Python
python 批量添加的button 使用同一点击事件的方法
Jul 17 Python
python中sort和sorted排序的实例方法
Aug 26 Python
详解python中eval函数的作用
Oct 22 Python
python的pyecharts绘制各种图表详细(附代码)
Nov 11 Python
PyQT5 emit 和 connect的用法详解
Dec 13 Python
关于python pycharm中输出的内容不全的解决办法
Jan 10 Python
pytorch 使用加载训练好的模型做inference
Feb 20 Python
python中def是做什么的
Jun 10 Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
Python中协程用法代码详解
Feb 10 #Python
Python实现简单生成验证码功能【基于random模块】
Feb 10 #Python
You might like
PHP中trait使用方法详细介绍
2017/05/21 PHP
CentOS7系统搭建LAMP及更新PHP版本操作详解
2020/03/26 PHP
javascript 单例/单体模式(Singleton)
2011/04/07 Javascript
关于js日期转化为毫秒数“节省20%的效率和和节省9个字符“问题
2012/03/01 Javascript
javascript对talbe进行动态添加、删除、验证实现代码
2012/03/29 Javascript
javascript模拟地球旋转效果代码实例
2013/12/02 Javascript
JavaScript编程的10个实用小技巧
2014/04/18 Javascript
node.js入门教程
2014/06/01 Javascript
AngularJS入门教程(二):AngularJS模板
2014/12/06 Javascript
深入探寻seajs的模块化与加载方式
2015/04/14 Javascript
js实现拖拽效果(构造函数)
2015/12/14 Javascript
jQuery 3 中的新增功能汇总介绍
2016/06/12 Javascript
Bootstrap基本模板的使用和理解1
2016/12/14 Javascript
php register_shutdown_function函数详解
2017/07/23 Javascript
史上最全JavaScript常用的简写技巧(推荐)
2017/08/17 Javascript
JavaScript中var、let、const区别浅析
2018/06/24 Javascript
使用jQuery动态设置单选框的选中效果
2018/12/06 jQuery
vue打包之后生成一个配置文件修改接口的方法
2018/12/09 Javascript
javascript中的offsetWidth、clientWidth、innerWidth及相关属性方法
2020/05/14 Javascript
vue的hash值原理也是table切换实例代码
2020/12/14 Vue.js
Vue 组件注册全解析
2020/12/17 Vue.js
[36:33]Ti4 循环赛第四日 附加赛NEWBEE vs Mouz
2014/07/13 DOTA
Python os模块介绍
2014/11/30 Python
python使用SMTP发送qq或sina邮件
2017/10/21 Python
python取代netcat过程分析
2018/02/10 Python
PyQt5每天必学之布局管理
2018/04/19 Python
python多进程读图提取特征存npy
2019/05/21 Python
Pycharm Plugins加载失败问题解决方案
2020/11/28 Python
POP文化和音乐灵感的时尚:Hot Topic
2019/06/19 全球购物
护理工作感言
2014/01/16 职场文书
婚庆司仪主持词
2014/03/15 职场文书
共产党员公开承诺书范文
2014/03/28 职场文书
物业公司的岗位任命书
2014/06/06 职场文书
学生上课迟到检讨书
2015/01/01 职场文书
协议书格式模板
2016/03/24 职场文书
Mysql排序的特性详情
2021/11/01 MySQL