tensorflow建立一个简单的神经网络的方法


Posted in Python onFebruary 10, 2018

本笔记目的是通过tensorflow实现一个两层的神经网络。目的是实现一个二次函数的拟合。

如何添加一层网络

代码如下:

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

注意该函数中是xW+b,而不是Wx+b。所以要注意乘法的顺序。x应该定义为[类别数量, 数据数量], W定义为[数据类别,类别数量]。

创建一些数据

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

numpy的linspace函数能够产生等差数列。start,stop决定等差数列的起止值。endpoint参数指定包不包括终点值。

numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)[source] 
Return evenly spaced numbers over a specified interval. 
Returns num evenly spaced samples, calculated over the interval [start, stop].

tensorflow建立一个简单的神经网络的方法

noise函数为添加噪声所用,这样二次函数的点不会与二次函数曲线完全重合。

numpy的newaxis可以新增一个维度而不需要重新创建相应的shape在赋值,非常方便,如上面的例子中就将x_data从一维变成了二维。

添加占位符,用作输入

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

添加隐藏层和输出层

# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

计算误差,并用梯度下降使得误差最小

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

完整代码如下:

from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
           reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# important step
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()

for i in range(1000):
  # training
  sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
  if i % 50 == 0:
    # to visualize the result and improvement
    try:
      ax.lines.remove(lines[0])
    except Exception:
      pass
    prediction_value = sess.run(prediction, feed_dict={xs: x_data})
    # plot the prediction
    lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
    plt.pause(0.1)

运行结果:

tensorflow建立一个简单的神经网络的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现同时兼容老版和新版Socket协议的一个简单WebSocket服务器
Jun 04 Python
常见的在Python中实现单例模式的三种方法
Apr 08 Python
Python中super的用法实例
May 28 Python
Python黑帽编程 3.4 跨越VLAN详解
Sep 28 Python
Python 在字符串中加入变量的实例讲解
May 02 Python
Python中一些不为人知的基础技巧总结
May 19 Python
Python字符串内置函数功能与用法总结
Apr 16 Python
python实现对象列表根据某个属性排序的方法详解
Jun 11 Python
解决pyinstaller打包发布后的exe文件打开控制台闪退的问题
Jun 21 Python
Python ADF 单位根检验 如何查看结果的实现
Jun 03 Python
Python 实现简单的客户端认证
Jul 29 Python
浅谈matplotlib默认字体设置探索
Feb 03 Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
Python中协程用法代码详解
Feb 10 #Python
Python实现简单生成验证码功能【基于random模块】
Feb 10 #Python
You might like
几种显示数据的方法的比较
2006/10/09 PHP
一键删除顽固的空文件夹 软件下载
2007/01/26 PHP
队列在编程中的实际应用(php)
2010/09/04 PHP
php代码收集表单内容并写入文件的代码
2012/01/29 PHP
PHP读取RSS(Feed)简单实例
2014/06/12 PHP
Laravel Memcached缓存驱动的配置与应用方法分析
2016/10/08 PHP
jquery 回车事件实现代码
2011/08/23 Javascript
jQuery点击后一组图片左右滑动的实现代码
2012/08/16 Javascript
JS弹出窗口代码大全(详细整理)
2012/12/21 Javascript
js 获取时间间隔实现代码
2014/05/12 Javascript
浅谈 jQuery 事件源码定位问题
2014/06/18 Javascript
js使用for循环及if语句判断多个一样的name
2014/09/09 Javascript
一款基jquery超炫的动画导航菜单可响应单击事件
2014/11/02 Javascript
浅谈利用JavaScript进行的DDoS攻击原理与防御
2015/06/04 Javascript
基于insertBefore制作简单的循环插空效果
2015/09/21 Javascript
jQuery实现带玻璃流光质感的手风琴特效
2015/11/20 Javascript
JS仿京东移动端手指拨动切换轮播图效果
2020/04/10 Javascript
angular2 组件之间通过service互相传递的实例
2018/09/30 Javascript
elementUI vue this.$confirm 和el-dialog 弹出框 移动 示例demo
2019/07/03 Javascript
selenium 反爬虫之跳过淘宝滑块验证功能的实现代码
2020/08/27 Javascript
pandas 对group进行聚合的例子
2019/12/27 Python
python实现用户名密码校验
2020/03/18 Python
Python常用扩展插件使用教程解析
2020/11/02 Python
PyCharm Community安装与配置的详细教程
2020/11/24 Python
python asyncio 协程库的使用
2021/01/21 Python
巴西婴儿用品商店:Bebe Store
2017/11/23 全球购物
汉森批发:Hansen Wholesale
2018/05/24 全球购物
什么是ESB?请介绍一下ESB?
2015/05/27 面试题
教师研修随笔感言
2014/01/23 职场文书
教师工作自我鉴定范文
2014/09/14 职场文书
2015年税务稽查工作总结
2015/05/26 职场文书
2016党员三严三实心得体会
2016/01/15 职场文书
oracle连接ODBC sqlserver数据源的详细步骤
2021/07/25 Oracle
2021好看的国漫排行榜前十名 《完美世界》上榜,《元龙》排名第一
2022/03/18 国漫
开发者首先否认《遗弃》被取消的传言
2022/04/11 其他游戏
el-table-column 内容不自动换行的解决方法
2022/08/14 Vue.js