Python实现的递归神经网络简单示例


Posted in Python onAugust 11, 2017

本文实例讲述了Python实现的递归神经网络。分享给大家供大家参考,具体如下:

# Recurrent Neural Networks
import copy, numpy as np
np.random.seed(0)
# compute sigmoid nonlinearity
def sigmoid(x):
  output = 1/(1+np.exp(-x))
  return output
# convert output of sigmoid function to its derivative
def sigmoid_output_to_derivative(output):
  return output*(1-output)
# training dataset generation
int2binary = {}
binary_dim = 8
largest_number = pow(2,binary_dim)
binary = np.unpackbits(
  np.array([range(largest_number)],dtype=np.uint8).T,axis=1)
for i in range(largest_number):
  int2binary[i] = binary[i]
# input variables
alpha = 0.1
input_dim = 2
hidden_dim = 16
output_dim = 1
# initialize neural network weights
synapse_0 = 2*np.random.random((input_dim,hidden_dim)) - 1
synapse_1 = 2*np.random.random((hidden_dim,output_dim)) - 1
synapse_h = 2*np.random.random((hidden_dim,hidden_dim)) - 1
synapse_0_update = np.zeros_like(synapse_0)
synapse_1_update = np.zeros_like(synapse_1)
synapse_h_update = np.zeros_like(synapse_h)
# training logic
for j in range(10000):
  # generate a simple addition problem (a + b = c)
  a_int = np.random.randint(largest_number/2) # int version
  a = int2binary[a_int] # binary encoding
  b_int = np.random.randint(largest_number/2) # int version
  b = int2binary[b_int] # binary encoding
  # true answer
  c_int = a_int + b_int
  c = int2binary[c_int]
  # where we'll store our best guess (binary encoded)
  d = np.zeros_like(c)
  overallError = 0
  layer_2_deltas = list()
  layer_1_values = list()
  layer_1_values.append(np.zeros(hidden_dim))
  # moving along the positions in the binary encoding
  for position in range(binary_dim):
    # generate input and output
    X = np.array([[a[binary_dim - position - 1],b[binary_dim - position - 1]]])
    y = np.array([[c[binary_dim - position - 1]]]).T
    # hidden layer (input ~+ prev_hidden)
    layer_1 = sigmoid(np.dot(X,synapse_0) + np.dot(layer_1_values[-1],synapse_h))
    # output layer (new binary representation)
    layer_2 = sigmoid(np.dot(layer_1,synapse_1))
    # did we miss?... if so, by how much?
    layer_2_error = y - layer_2
    layer_2_deltas.append((layer_2_error)*sigmoid_output_to_derivative(layer_2))
    overallError += np.abs(layer_2_error[0])
    # decode estimate so we can print(it out)
    d[binary_dim - position - 1] = np.round(layer_2[0][0])
    # store hidden layer so we can use it in the next timestep
    layer_1_values.append(copy.deepcopy(layer_1))
  future_layer_1_delta = np.zeros(hidden_dim)
  for position in range(binary_dim):
    X = np.array([[a[position],b[position]]])
    layer_1 = layer_1_values[-position-1]
    prev_layer_1 = layer_1_values[-position-2]
    # error at output layer
    layer_2_delta = layer_2_deltas[-position-1]
    # error at hidden layer
    layer_1_delta = (future_layer_1_delta.dot(synapse_h.T) + layer_2_delta.dot(synapse_1.T)) * sigmoid_output_to_derivative(layer_1)
    # let's update all our weights so we can try again
    synapse_1_update += np.atleast_2d(layer_1).T.dot(layer_2_delta)
    synapse_h_update += np.atleast_2d(prev_layer_1).T.dot(layer_1_delta)
    synapse_0_update += X.T.dot(layer_1_delta)
    future_layer_1_delta = layer_1_delta
  synapse_0 += synapse_0_update * alpha
  synapse_1 += synapse_1_update * alpha
  synapse_h += synapse_h_update * alpha
  synapse_0_update *= 0
  synapse_1_update *= 0
  synapse_h_update *= 0
  # print(out progress)
  if j % 1000 == 0:
    print("Error:" + str(overallError))
    print("Pred:" + str(d))
    print("True:" + str(c))
    out = 0
    for index,x in enumerate(reversed(d)):
      out += x*pow(2,index)
    print(str(a_int) + " + " + str(b_int) + " = " + str(out))
    print("------------")

运行输出:

Error:[ 3.45638663]
Pred:[0 0 0 0 0 0 0 1]
True:[0 1 0 0 0 1 0 1]
9 + 60 = 1
------------
Error:[ 3.63389116]
Pred:[1 1 1 1 1 1 1 1]
True:[0 0 1 1 1 1 1 1]
28 + 35 = 255
------------
Error:[ 3.91366595]
Pred:[0 1 0 0 1 0 0 0]
True:[1 0 1 0 0 0 0 0]
116 + 44 = 72
------------
Error:[ 3.72191702]
Pred:[1 1 0 1 1 1 1 1]
True:[0 1 0 0 1 1 0 1]
4 + 73 = 223
------------
Error:[ 3.5852713]
Pred:[0 0 0 0 1 0 0 0]
True:[0 1 0 1 0 0 1 0]
71 + 11 = 8
------------
Error:[ 2.53352328]
Pred:[1 0 1 0 0 0 1 0]
True:[1 1 0 0 0 0 1 0]
81 + 113 = 162
------------
Error:[ 0.57691441]
Pred:[0 1 0 1 0 0 0 1]
True:[0 1 0 1 0 0 0 1]
81 + 0 = 81
------------
Error:[ 1.42589952]
Pred:[1 0 0 0 0 0 0 1]
True:[1 0 0 0 0 0 0 1]
4 + 125 = 129
------------
Error:[ 0.47477457]
Pred:[0 0 1 1 1 0 0 0]
True:[0 0 1 1 1 0 0 0]
39 + 17 = 56
------------
Error:[ 0.21595037]
Pred:[0 0 0 0 1 1 1 0]
True:[0 0 0 0 1 1 1 0]
11 + 3 = 14
------------

英文原文:https://iamtrask.github.io/2015/11/15/anyone-can-code-lstm/

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python统计一个文本中重复行数的方法
Nov 19 Python
使用IPython来操作Docker容器的入门指引
Apr 08 Python
详解Python中的条件判断语句
May 14 Python
python解决方案:WindowsError: [Error 2]
Aug 28 Python
详谈在flask中使用jsonify和json.dumps的区别
Mar 26 Python
matplotlib subplots 调整子图间矩的实例
May 25 Python
python实现猜单词小游戏
May 22 Python
举例讲解Python常用模块
Mar 08 Python
通过实例解析Python调用json模块
Dec 11 Python
django在开发中取消外键约束的实现
May 20 Python
python 使用tkinter与messagebox写界面和弹窗
Mar 20 Python
python区块链持久化和命令行接口实现简版
May 25 Python
Python调用系统底层API播放wav文件的方法
Aug 11 #Python
Django 导出 Excel 代码的实例详解
Aug 11 #Python
python技能之数据导出excel的实例代码
Aug 11 #Python
利用标准库fractions模块让Python支持分数类型的方法详解
Aug 11 #Python
Python对字符串实现去重操作的方法示例
Aug 11 #Python
python中模块查找的原理与方法详解
Aug 11 #Python
python利用lxml读写xml格式的文件
Aug 10 #Python
You might like
PHP中两个float(浮点数)比较实例分析
2015/09/27 PHP
PHP中获取文件创建日期、修改日期、访问时间的方法
2016/11/05 PHP
PHP使用PDO创建MySQL数据库、表及插入多条数据操作示例
2019/05/30 PHP
关于js中alert弹出窗口文本换行问题简单详细说明
2012/12/11 Javascript
img onload事件绑定各浏览器均可执行
2012/12/19 Javascript
仿百度输入框智能提示的js代码
2013/08/22 Javascript
javascript相关事件的几个概念
2015/05/21 Javascript
windows下安装nodejs及框架express
2015/08/07 NodeJs
图片旋转、鼠标滚轮缩放、镜像、切换图片js代码
2020/12/13 Javascript
javascript求日期差的方法
2016/03/02 Javascript
jQuery on()方法绑定动态元素的点击事件实例代码浅析
2016/06/16 Javascript
AngularJS基础 ng-options 指令详解
2016/08/02 Javascript
jQuery实现的图片轮播效果完整示例
2016/09/12 Javascript
Vue中使用vux的配置详解
2017/05/05 Javascript
D3.js进阶系列之CSV表格文件的读取详解
2017/06/06 Javascript
Vue $emit $refs子父组件间方法的调用实例
2018/09/12 Javascript
jQuery.parseJSON()函数详解
2019/02/28 jQuery
vue从零实现一个消息通知组件的方法详解
2020/03/16 Javascript
Python读写Excel文件方法介绍
2014/11/22 Python
如何准确判断请求是搜索引擎爬虫(蜘蛛)发出的请求
2015/10/13 Python
详解JavaScript编程中的window与window.screen对象
2015/10/26 Python
在Pycharm中执行scrapy命令的方法
2019/01/16 Python
Python中print函数简单使用总结
2019/08/05 Python
CSS3 真的会替代 SCSS 吗
2021/03/09 HTML / CSS
法国时尚童装网站:Melijoe
2016/08/10 全球购物
Linux文件系统类型
2012/09/16 面试题
同志主要表现材料
2014/08/21 职场文书
学生自我鉴定格式及范文
2014/09/16 职场文书
2014年学校禁毒工作总结
2014/12/23 职场文书
离婚案件被告代理词
2015/05/23 职场文书
2015年学校精神文明工作总结
2015/05/27 职场文书
单位更名证明
2015/06/18 职场文书
校运会广播稿
2015/08/19 职场文书
python process模块的使用简介
2021/05/14 Python
Ajax请求超时与网络异常处理图文详解
2021/05/23 Javascript
详解Flutter网络请求Dio库的使用及封装
2022/04/14 Java/Android