tensorflow建立一个简单的神经网络的方法


Posted in Python onFebruary 10, 2018

本笔记目的是通过tensorflow实现一个两层的神经网络。目的是实现一个二次函数的拟合。

如何添加一层网络

代码如下:

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

注意该函数中是xW+b,而不是Wx+b。所以要注意乘法的顺序。x应该定义为[类别数量, 数据数量], W定义为[数据类别,类别数量]。

创建一些数据

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

numpy的linspace函数能够产生等差数列。start,stop决定等差数列的起止值。endpoint参数指定包不包括终点值。

numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)[source] 
Return evenly spaced numbers over a specified interval. 
Returns num evenly spaced samples, calculated over the interval [start, stop].

tensorflow建立一个简单的神经网络的方法

noise函数为添加噪声所用,这样二次函数的点不会与二次函数曲线完全重合。

numpy的newaxis可以新增一个维度而不需要重新创建相应的shape在赋值,非常方便,如上面的例子中就将x_data从一维变成了二维。

添加占位符,用作输入

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

添加隐藏层和输出层

# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

计算误差,并用梯度下降使得误差最小

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

完整代码如下:

from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
           reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# important step
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()

for i in range(1000):
  # training
  sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
  if i % 50 == 0:
    # to visualize the result and improvement
    try:
      ax.lines.remove(lines[0])
    except Exception:
      pass
    prediction_value = sess.run(prediction, feed_dict={xs: x_data})
    # plot the prediction
    lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
    plt.pause(0.1)

运行结果:

tensorflow建立一个简单的神经网络的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用wxpython实现的一个简单图片浏览器实例
Jul 10 Python
Python读取环境变量的方法和自定义类分享
Nov 22 Python
简化Python的Django框架代码的一些示例
Apr 20 Python
window下eclipse安装python插件教程
Apr 24 Python
python 生成器协程运算实例
Sep 04 Python
Linux下远程连接Jupyter+pyspark部署教程
Jun 21 Python
解决pycharm remote deployment 配置的问题
Jun 27 Python
python gdal安装与简单使用
Aug 01 Python
python爬虫 Pyppeteer使用方法解析
Sep 28 Python
python 5个顶级异步框架推荐
Sep 09 Python
OpenCV-Python实现轮廓拟合
Jun 08 Python
Pytorch可视化的几种实现方法
Jun 10 Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
Python中协程用法代码详解
Feb 10 #Python
Python实现简单生成验证码功能【基于random模块】
Feb 10 #Python
You might like
PHP set_error_handler()函数使用详解(示例)
2013/11/12 PHP
Yii2 输出xml格式数据的方法
2016/05/03 PHP
修改yii2.0用户登录使用的user表为其它的表实现方法(推荐)
2017/08/01 PHP
翻译整理的jQuery使用查询手册
2007/03/07 Javascript
IE8 引入跨站数据获取功能说明
2008/07/22 Javascript
锋利的jQuery 第三章章节总结的例子
2010/03/23 Javascript
从零开始学习jQuery (四) jQuery中操作元素的属性与样式
2011/02/23 Javascript
js 中的switch表达式使用示例
2020/06/03 Javascript
JavaScript中的prototype.bind()方法介绍
2014/04/04 Javascript
基于javascript如何传递特殊字符
2015/11/30 Javascript
js实现图片轮播效果
2015/12/19 Javascript
jQuery EasyUI 入门必看
2016/06/03 Javascript
运用js教你轻松制作html音乐播放器
2020/04/17 Javascript
Bootstrap源码解读表单(2)
2016/12/22 Javascript
Vue2.0表单校验组件vee-validate的使用详解
2017/05/02 Javascript
JS原生带小白点轮播图实例讲解
2017/07/22 Javascript
自定义类似于jQuery UI Selectable 的Vue指令v-selectable
2017/08/23 jQuery
Bootstrap4 gulp 配置详解
2019/01/06 Javascript
webpack 代码分离优化快速指北
2019/05/18 Javascript
聊聊Vue 中 title 的动态修改问题
2019/06/11 Javascript
Element Cascader 级联选择器的使用示例
2020/07/27 Javascript
用Python编写一个国际象棋AI程序
2014/11/28 Python
python删除指定类型(或非指定)的文件实例详解
2015/07/06 Python
在arcgis使用python脚本进行字段计算时是如何解决中文问题的
2015/10/18 Python
Python如何爬取实时变化的WebSocket数据的方法
2019/03/09 Python
Python shutil模块用法实例分析
2019/10/02 Python
python 双循环遍历list 变量判断代码
2020/05/04 Python
将keras的h5模型转换为tensorflow的pb模型操作
2020/05/25 Python
基于TensorFlow的CNN实现Mnist手写数字识别
2020/06/17 Python
美国运动鞋和运动服零售商:Footaction
2017/04/07 全球购物
世界上最受欢迎的花店:1-800-Flowers.com
2020/06/01 全球购物
工程质量月活动方案
2014/02/19 职场文书
湖南省党的群众路线教育实践活动总结会议新闻稿
2014/10/21 职场文书
2014小学教师年度考核工作总结
2014/12/03 职场文书
部门经理迟到检讨书
2015/02/16 职场文书
党支部工作总结2015
2015/04/01 职场文书