tensorflow建立一个简单的神经网络的方法


Posted in Python onFebruary 10, 2018

本笔记目的是通过tensorflow实现一个两层的神经网络。目的是实现一个二次函数的拟合。

如何添加一层网络

代码如下:

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

注意该函数中是xW+b,而不是Wx+b。所以要注意乘法的顺序。x应该定义为[类别数量, 数据数量], W定义为[数据类别,类别数量]。

创建一些数据

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

numpy的linspace函数能够产生等差数列。start,stop决定等差数列的起止值。endpoint参数指定包不包括终点值。

numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)[source] 
Return evenly spaced numbers over a specified interval. 
Returns num evenly spaced samples, calculated over the interval [start, stop].

tensorflow建立一个简单的神经网络的方法

noise函数为添加噪声所用,这样二次函数的点不会与二次函数曲线完全重合。

numpy的newaxis可以新增一个维度而不需要重新创建相应的shape在赋值,非常方便,如上面的例子中就将x_data从一维变成了二维。

添加占位符,用作输入

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

添加隐藏层和输出层

# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

计算误差,并用梯度下降使得误差最小

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

完整代码如下:

from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
           reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# important step
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()

for i in range(1000):
  # training
  sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
  if i % 50 == 0:
    # to visualize the result and improvement
    try:
      ax.lines.remove(lines[0])
    except Exception:
      pass
    prediction_value = sess.run(prediction, feed_dict={xs: x_data})
    # plot the prediction
    lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
    plt.pause(0.1)

运行结果:

tensorflow建立一个简单的神经网络的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现简单socket通信的方法
Apr 19 Python
使用Python从有道词典网页获取单词翻译
Jul 03 Python
关于Python中空格字符串处理的技巧总结
Aug 10 Python
python实现冒泡排序算法的两种方法
Mar 10 Python
python 读取DICOM头文件的实例
May 07 Python
Python3多进程 multiprocessing 模块实例详解
Jun 11 Python
centos6.8安装python3.7无法import _ssl的解决方法
Sep 17 Python
解决Pycharm出现的部分快捷键无效问题
Oct 22 Python
python 获取url中的参数列表实例
Dec 18 Python
解决pycharm remote deployment 配置的问题
Jun 27 Python
解决TensorFlow模型恢复报错的问题
Feb 06 Python
Python函数默认参数常见问题及解决方案
Mar 26 Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
Python中协程用法代码详解
Feb 10 #Python
Python实现简单生成验证码功能【基于random模块】
Feb 10 #Python
You might like
神族 PROTOSS 概述
2020/03/14 星际争霸
ftp类(example.php)
2006/10/09 PHP
解析thinkphp import 文件内容变量失效的问题
2013/06/20 PHP
PHP获取MySql新增记录ID值的3种方法
2014/06/24 PHP
php提交过来的数据生成为txt文件
2016/04/28 PHP
PHP函数按引用传递参数及函数可选参数用法示例
2018/06/04 PHP
jQuery中add实现同时选择两个id对象
2010/10/22 Javascript
js弹出框轻量级插件jquery.boxy使用介绍
2013/01/15 Javascript
在ASP.NET中使用JavaScript脚本的方法
2013/11/12 Javascript
2014 年最热门的21款JavaScript框架推荐
2014/12/25 Javascript
jQuery each函数源码分析
2016/05/25 Javascript
JS匹配日期和时间的正则表达式示例
2017/05/12 Javascript
Js中async/await的执行顺序详解
2017/09/22 Javascript
node puppeteer(headless chrome)实现网站登录
2018/05/09 Javascript
如何用webpack4带你实现一个vue的打包的项目
2018/06/20 Javascript
axios异步提交表单数据的几种方法
2019/08/11 Javascript
vue遍历对象中的数组取值示例
2019/11/07 Javascript
详解django.contirb.auth-认证
2018/07/16 Python
python 字典中取值的两种方法小结
2018/08/02 Python
python实现烟花小程序
2019/01/30 Python
5款Python程序员高频使用开发工具推荐
2019/04/10 Python
python接口自动化测试之接口数据依赖的实现方法
2019/04/26 Python
python 中值滤波,椒盐去噪,图片增强实例
2019/12/18 Python
加拿大健康、婴儿和美容产品在线购物:Well.ca
2016/11/30 全球购物
Nordgreen手表德国官方网站:丹麦极简主义手表
2019/10/31 全球购物
拉飞逸官网:Lafayette 148 New York
2020/07/15 全球购物
销售业务实习自我鉴定
2013/09/23 职场文书
出国留学介绍信
2014/01/13 职场文书
小学生秋游活动方案
2014/02/23 职场文书
判缓刑人员个人思想汇报
2014/10/10 职场文书
2015新学期校长寄语(3篇)
2015/03/25 职场文书
Go语言带缓冲的通道实现
2021/04/26 Golang
Python实现老照片修复之上色小技巧
2021/10/16 Python
vue如何实现关闭对话框后刷新列表
2022/04/08 Vue.js
python区块链实现简版工作量证明
2022/05/25 Python
服务器SVN搭建图文安装过程
2022/06/21 Servers