numpy实现RNN原理实现


Posted in Python onMarch 02, 2021

首先说明代码只是帮助理解,并未写出梯度下降部分,默认参数已经被固定,不影响理解。代码主要实现RNN原理,只使用numpy库,不可用于GPU加速。

import numpy as np


class Rnn():

  def __init__(self, input_size, hidden_size, num_layers, bidirectional=False):
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.bidirectional = bidirectional

  def feed(self, x):
    '''

    :param x: [seq, batch_size, embedding]
    :return: out, hidden
    '''

    # x.shape [sep, batch, feature]
    # hidden.shape [hidden_size, batch]
    # Whh0.shape [hidden_size, hidden_size] Wih0.shape [hidden_size, feature]
    # Whh1.shape [hidden_size, hidden_size] Wih1.size [hidden_size, hidden_size]

    out = []
    x, hidden = np.array(x), [np.zeros((self.hidden_size, x.shape[1])) for i in range(self.num_layers)]
    Wih = [np.random.random((self.hidden_size, self.hidden_size)) for i in range(1, self.num_layers)]
    Wih.insert(0, np.random.random((self.hidden_size, x.shape[2])))
    Whh = [np.random.random((self.hidden_size, self.hidden_size)) for i in range(self.num_layers)]

    time = x.shape[0]
    for i in range(time):
      hidden[0] = np.tanh((np.dot(Wih[0], np.transpose(x[i, ...], (1, 0))) +
               np.dot(Whh[0], hidden[0])
               ))

      for i in range(1, self.num_layers):
        hidden[i] = np.tanh((np.dot(Wih[i], hidden[i-1]) +
                   np.dot(Whh[i], hidden[i])
                   ))

      out.append(hidden[self.num_layers-1])

    return np.array(out), np.array(hidden)


def sigmoid(x):
  return 1.0/(1.0 + 1.0/np.exp(x))


if __name__ == '__main__':
  rnn = Rnn(1, 5, 4)
  input = np.random.random((6, 2, 1))
  out, h = rnn.feed(input)
  print(f'seq is {input.shape[0]}, batch_size is {input.shape[1]} ', 'out.shape ', out.shape, ' h.shape ', h.shape)
  # print(sigmoid(np.random.random((2, 3))))
  #
  # element-wise multiplication
  # print(np.array([1, 2])*np.array([2, 1]))

到此这篇关于numpy实现RNN原理实现的文章就介绍到这了,更多相关numpy实现RNN内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python安装Scrapy图文教程
Aug 14 Python
python批量替换多文件字符串问题详解
Apr 22 Python
python自动重试第三方包retrying模块的方法
Apr 24 Python
Django rest framework实现分页的示例
May 24 Python
django+echart绘制曲线图的方法示例
Nov 26 Python
Python 获取中文字拼音首个字母的方法
Nov 28 Python
Python3实现的判断回文链表算法示例
Mar 08 Python
Django使用unittest模块进行单元测试过程解析
Aug 02 Python
如何在mac环境中用python处理protobuf
Dec 25 Python
为什么说python更适合树莓派编程
Jul 20 Python
解决python3中os.popen()出错的问题
Nov 19 Python
python如何在word中存储本地图片
Apr 07 Python
解决tensorflow模型压缩的问题_踩坑无数,总算搞定
Mar 02 #Python
python Protobuf定义消息类型知识点讲解
Mar 02 #Python
Django项目在pycharm新建的步骤方法
Mar 02 #Python
基于注解实现 SpringBoot 接口防刷的方法
Mar 02 #Python
python Autopep8实现按PEP8风格自动排版Python代码
Mar 02 #Python
pycharm配置安装autopep8自动规范代码的实现
Mar 02 #Python
Python实现我的世界小游戏源代码
Mar 02 #Python
You might like
php学习之 数组声明
2011/06/09 PHP
php中模拟POST传递数据的两种方法分享
2011/09/16 PHP
php eval函数用法总结
2012/10/31 PHP
解析PHP中VC6 X86和VC9 X86的区别及 Non Thread Safe的意思
2013/06/28 PHP
解密ThinkPHP3.1.2版本之模块和操作映射
2014/06/19 PHP
PHP进阶学习之依赖注入与Ioc容器详解
2019/06/19 PHP
js 全兼容可高亮二级缓冲折叠菜单
2010/06/04 Javascript
基于jquery的图片的切换(以数字的形式)
2011/02/14 Javascript
javascript模版引擎-tmpl的bug修复与性能优化分析
2011/10/23 Javascript
Moment.js 不容错过的超棒Javascript日期处理类库
2012/04/15 Javascript
jquery全选/全不选/反选另一种实现方法(配合原生js)
2013/04/07 Javascript
MyEclipse取消验证Js的两种方法
2013/11/14 Javascript
nodejs中使用多线程编程的方法实例
2015/03/24 NodeJs
JavaScript使用shift方法移除素组第一个元素实例分析
2015/04/06 Javascript
详细分析JavaScript函数定义
2015/07/16 Javascript
jquery实现LED广告牌旋转系统图片切换效果代码分享
2015/08/26 Javascript
九种原生js动画效果
2015/11/11 Javascript
微信小程序 富文本转文本实例详解
2016/10/24 Javascript
webpack实现热加载自动刷新的方法
2017/07/30 Javascript
Echarts基本用法_动力节点Java学院整理
2017/08/11 Javascript
集合Bootstrap自定义confirm提示效果
2017/09/19 Javascript
vue计算属性和监听器实例解析
2018/05/10 Javascript
使用p5.js临摹动态图片
2019/11/04 Javascript
vue父子模板传值问题解决方法案例分析
2020/02/26 Javascript
JavaScript数组类型Array相关的属性与方法详解
2020/09/08 Javascript
vue.js实现点击图标放大离开时缩小的代码
2021/01/27 Vue.js
python中cPickle用法例子分享
2014/01/03 Python
运用TensorFlow进行简单实现线性回归、梯度下降示例
2018/03/05 Python
python数据处理 根据颜色对图片进行分类的方法
2018/12/08 Python
Python 实现OpenCV格式和PIL.Image格式互转
2020/01/09 Python
详解pyqt5的UI中嵌入matplotlib图形并实时刷新(挖坑和填坑)
2020/08/07 Python
美国益智玩具购物网站:Fat Brain Toys
2017/11/03 全球购物
ParcelABC西班牙:包裹运送和快递服务
2019/12/24 全球购物
社区班子对照检查材料
2014/08/27 职场文书
2015教师年度思想工作总结
2015/04/30 职场文书
关于ObjectUtils.isEmpty() 和 null 的区别
2022/02/28 Java/Android