numpy实现RNN原理实现


Posted in Python onMarch 02, 2021

首先说明代码只是帮助理解,并未写出梯度下降部分,默认参数已经被固定,不影响理解。代码主要实现RNN原理,只使用numpy库,不可用于GPU加速。

import numpy as np


class Rnn():

  def __init__(self, input_size, hidden_size, num_layers, bidirectional=False):
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.bidirectional = bidirectional

  def feed(self, x):
    '''

    :param x: [seq, batch_size, embedding]
    :return: out, hidden
    '''

    # x.shape [sep, batch, feature]
    # hidden.shape [hidden_size, batch]
    # Whh0.shape [hidden_size, hidden_size] Wih0.shape [hidden_size, feature]
    # Whh1.shape [hidden_size, hidden_size] Wih1.size [hidden_size, hidden_size]

    out = []
    x, hidden = np.array(x), [np.zeros((self.hidden_size, x.shape[1])) for i in range(self.num_layers)]
    Wih = [np.random.random((self.hidden_size, self.hidden_size)) for i in range(1, self.num_layers)]
    Wih.insert(0, np.random.random((self.hidden_size, x.shape[2])))
    Whh = [np.random.random((self.hidden_size, self.hidden_size)) for i in range(self.num_layers)]

    time = x.shape[0]
    for i in range(time):
      hidden[0] = np.tanh((np.dot(Wih[0], np.transpose(x[i, ...], (1, 0))) +
               np.dot(Whh[0], hidden[0])
               ))

      for i in range(1, self.num_layers):
        hidden[i] = np.tanh((np.dot(Wih[i], hidden[i-1]) +
                   np.dot(Whh[i], hidden[i])
                   ))

      out.append(hidden[self.num_layers-1])

    return np.array(out), np.array(hidden)


def sigmoid(x):
  return 1.0/(1.0 + 1.0/np.exp(x))


if __name__ == '__main__':
  rnn = Rnn(1, 5, 4)
  input = np.random.random((6, 2, 1))
  out, h = rnn.feed(input)
  print(f'seq is {input.shape[0]}, batch_size is {input.shape[1]} ', 'out.shape ', out.shape, ' h.shape ', h.shape)
  # print(sigmoid(np.random.random((2, 3))))
  #
  # element-wise multiplication
  # print(np.array([1, 2])*np.array([2, 1]))

到此这篇关于numpy实现RNN原理实现的文章就介绍到这了,更多相关numpy实现RNN内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
跟老齐学Python之让人欢喜让人忧的迭代
Oct 02 Python
linux 下实现python多版本安装实践
Nov 18 Python
在Python的Django框架中实现Hacker News的一些功能
Apr 17 Python
python 将print输出的内容保存到txt文件中
Jul 17 Python
对Python+opencv将图片生成视频的实例详解
Jan 08 Python
Python中的集合介绍
Jan 28 Python
spark dataframe 将一列展开,把该列所有值都变成新列的方法
Jan 29 Python
利用python和百度地图API实现数据地图标注的方法
May 13 Python
django页面跳转问题及注意事项
Jul 18 Python
python matplotlib 画dataframe的时间序列图实例
Nov 20 Python
python读取hdfs并返回dataframe教程
Jun 05 Python
超级实用的8个Python列表技巧
Aug 24 Python
解决tensorflow模型压缩的问题_踩坑无数,总算搞定
Mar 02 #Python
python Protobuf定义消息类型知识点讲解
Mar 02 #Python
Django项目在pycharm新建的步骤方法
Mar 02 #Python
基于注解实现 SpringBoot 接口防刷的方法
Mar 02 #Python
python Autopep8实现按PEP8风格自动排版Python代码
Mar 02 #Python
pycharm配置安装autopep8自动规范代码的实现
Mar 02 #Python
Python实现我的世界小游戏源代码
Mar 02 #Python
You might like
PHP json格式和js json格式 js跨域调用实现代码
2012/09/08 PHP
php采用file_get_contents代替使用curl实例
2014/11/07 PHP
php array_map使用自定义的函数处理数组中的每个值
2016/10/26 PHP
javascript 清空form表单中某种元素的值
2009/12/26 Javascript
js定时器的使用(实例讲解)
2014/01/06 Javascript
javascript 获取网页标题代码实例
2014/01/22 Javascript
7个JS基础知识总结
2014/03/05 Javascript
介绍一个简单的JavaScript类框架
2015/06/24 Javascript
node.js中格式化数字增加千位符的几种方法
2015/07/03 Javascript
详解AngularJS中自定义过滤器
2015/12/28 Javascript
浅析BootStrap模态框的使用(经典)
2016/04/29 Javascript
Vue表单实例代码
2016/09/05 Javascript
Seajs是什么及sea.js 由来,特点以及优势
2016/10/13 Javascript
React组件的三种写法总结
2017/01/12 Javascript
javascript实现多张图片左右无缝滚动效果
2017/03/22 Javascript
JS数组搜索之折半搜索实现方法分析
2017/03/27 Javascript
如何选择jQuery版本 1.x? 2.x? 3.x?
2017/04/01 jQuery
Angular 4依赖注入学习教程之组件服务注入(二)
2017/06/04 Javascript
纯JS实现简单的日历
2017/06/26 Javascript
微信小程序实现保存图片到相册功能
2018/11/30 Javascript
JS二级菜单不同实现方法分析【4种方法】
2018/12/21 Javascript
JS使用队列对数组排列,基数排序算法示例
2019/03/02 Javascript
如何优雅地在Node应用中进行错误异常处理
2019/11/25 Javascript
vue 解决uglifyjs-webpack-plugin打包出现报错的问题
2020/08/04 Javascript
[16:19]教你分分钟做大人——风暴之灵
2015/03/11 DOTA
Python实现PS滤镜特效Marble Filter玻璃条纹扭曲效果示例
2018/01/29 Python
Python多线程多进程实例对比解析
2020/03/12 Python
python学习笔记之多进程
2020/08/06 Python
python中@property的作用和getter setter的解释
2020/12/22 Python
python软件测试Jmeter性能测试JDBC Request(结合数据库)的使用详解
2021/01/26 Python
css3 实现滚动条美化效果的实例代码
2021/01/06 HTML / CSS
高中综合实践活动总结
2014/07/07 职场文书
班子成员四风问题自我剖析材料
2014/09/29 职场文书
婚礼答谢词范文
2015/09/29 职场文书
学习商务礼仪心得体会
2016/01/22 职场文书
python语言中pandas字符串分割str.split()函数
2022/08/05 Python