tensorflow建立一个简单的神经网络的方法


Posted in Python onFebruary 10, 2018

本笔记目的是通过tensorflow实现一个两层的神经网络。目的是实现一个二次函数的拟合。

如何添加一层网络

代码如下:

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

注意该函数中是xW+b,而不是Wx+b。所以要注意乘法的顺序。x应该定义为[类别数量, 数据数量], W定义为[数据类别,类别数量]。

创建一些数据

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

numpy的linspace函数能够产生等差数列。start,stop决定等差数列的起止值。endpoint参数指定包不包括终点值。

numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)[source] 
Return evenly spaced numbers over a specified interval. 
Returns num evenly spaced samples, calculated over the interval [start, stop].

tensorflow建立一个简单的神经网络的方法

noise函数为添加噪声所用,这样二次函数的点不会与二次函数曲线完全重合。

numpy的newaxis可以新增一个维度而不需要重新创建相应的shape在赋值,非常方便,如上面的例子中就将x_data从一维变成了二维。

添加占位符,用作输入

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

添加隐藏层和输出层

# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

计算误差,并用梯度下降使得误差最小

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

完整代码如下:

from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
           reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# important step
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()

for i in range(1000):
  # training
  sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
  if i % 50 == 0:
    # to visualize the result and improvement
    try:
      ax.lines.remove(lines[0])
    except Exception:
      pass
    prediction_value = sess.run(prediction, feed_dict={xs: x_data})
    # plot the prediction
    lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
    plt.pause(0.1)

运行结果:

tensorflow建立一个简单的神经网络的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
从零学python系列之浅谈pickle模块封装和拆封数据对象的方法
May 23 Python
Python学习笔记整理3之输入输出、python eval函数
Dec 14 Python
python虚拟环境virtualenv的安装与使用
Sep 21 Python
wxPython之解决闪烁的问题
Jan 15 Python
详谈python在windows中的文件路径问题
Apr 28 Python
python得到qq句柄,并显示在前台的方法
Oct 14 Python
详解python3安装pillow后报错没有pillow模块以及没有PIL模块问题解决
Apr 17 Python
在linux下实现 python 监控usb设备信号
Jul 03 Python
详解PyTorch手写数字识别(MNIST数据集)
Aug 16 Python
Tensorflow之梯度裁剪的实现示例
Mar 08 Python
解决Ubuntu18中的pycharm不能调用tensorflow-gpu的问题
Sep 17 Python
如何使用PyCharm引入需要使用的包的方法
Sep 22 Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
Python中协程用法代码详解
Feb 10 #Python
Python实现简单生成验证码功能【基于random模块】
Feb 10 #Python
You might like
PHP提示Deprecated: mysql_connect(): The mysql extension is deprecated的解决方法
2014/08/28 PHP
php操作csv文件代码实例汇总
2014/09/22 PHP
php实现网站文件批量压缩下载功能
2015/10/28 PHP
如何直接访问php实例对象中的private属性详解
2017/10/12 PHP
JavaScript实现禁止后退的方法
2006/12/27 Javascript
javascript中运用闭包和自执行函数解决大量的全局变量问题
2010/12/30 Javascript
jQuery实现页面滚动时动态加载内容的方法
2015/03/20 Javascript
学习JavaScript设计模式之迭代器模式
2016/01/19 Javascript
Javascript 跨域知识详细介绍
2016/10/30 Javascript
Vue.js 2.0窥探之Virtual DOM到底是什么?
2017/02/10 Javascript
js获取一组日期中最近连续的天数
2017/05/25 Javascript
JavaScript重复元素处理方法分析【统计个数、计算、去重复等】
2017/12/14 Javascript
Node.js使用cookie保持登录的方法
2018/05/11 Javascript
JS编写兼容IE6,7,8浏览器无缝自动轮播
2018/10/12 Javascript
vue实现简单瀑布流布局
2020/05/28 Javascript
openlayers实现地图测距测面
2020/09/25 Javascript
[36:33]完美世界DOTA2联赛循环赛 Matador vs Forest 第一场 11.06
2020/11/06 DOTA
Python实现在线程里运行scrapy的方法
2015/04/07 Python
Python打印scrapy蜘蛛抓取树结构的方法
2015/04/08 Python
python写入xml文件的方法
2015/05/08 Python
python itchat实现微信好友头像拼接图的示例代码
2017/08/14 Python
在python中对变量判断是否为None的三种方法总结
2019/01/23 Python
python3实现mysql导出excel的方法
2019/07/31 Python
wxPython电子表格功能wx.grid实例教程
2019/11/19 Python
调用HTML5的Canvas API绘制图形的快速入门指南
2016/06/17 HTML / CSS
GLAMGLOW格莱魅美国官网:美国知名的面膜品牌
2016/12/31 全球购物
武汉瑞得软件笔试题
2015/10/27 面试题
幼儿教师师德承诺书
2014/05/23 职场文书
爱的奉献演讲稿
2014/09/10 职场文书
2015年助理政工师工作总结
2015/05/26 职场文书
美丽心灵观后感
2015/06/01 职场文书
检举信的写法
2019/04/10 职场文书
python 如何在 Matplotlib 中绘制垂直线
2021/04/02 Python
基于JavaScript实现年月日三级联动
2021/06/22 Javascript
Python requests用法和django后台处理详解
2022/03/19 Python
Python循环之while无限迭代
2022/04/30 Python