关于tensorflow的几种参数初始化方法小结


Posted in Python onJanuary 04, 2020

在tensorflow中,经常会遇到参数初始化问题,比如在训练自己的词向量时,需要对原始的embeddigs矩阵进行初始化,更一般的,在全连接神经网络中,每层的权值w也需要进行初始化。

tensorlfow中应该有一下几种初始化方法

1. tf.constant_initializer() 常数初始化
2. tf.ones_initializer() 全1初始化
3. tf.zeros_initializer() 全0初始化
4. tf.random_uniform_initializer() 均匀分布初始化
5. tf.random_normal_initializer() 正态分布初始化
6. tf.truncated_normal_initializer() 截断正态分布初始化
7. tf.uniform_unit_scaling_initializer() 这种方法输入方差是常数
8. tf.variance_scaling_initializer() 自适应初始化
9. tf.orthogonal_initializer() 生成正交矩阵

具体的

1、tf.constant_initializer(),它的简写是tf.Constant()

#coding:utf-8
import numpy as np 
import tensorflow as tf 
train_inputs = [[1,2],[1,4],[3,2]]
with tf.variable_scope("embedding-layer"):
  val = np.array([[1,2,3,4,5,6,7],[1,3,4,5,2,1,9],[0,12,3,4,5,7,8],[2,3,5,5,6,8,9],[3,1,6,1,2,3,5]])
  const_init = tf.constant_initializer(val)
  embeddings = tf.get_variable("embed",shape=[5,7],dtype=tf.float32,initializer=const_init)
  embed = tf.nn.embedding_lookup(embeddings, train_inputs)             #在embedding中查找train_input所对应的表示
  print("embed",embed)
  sum_embed = tf.reduce_mean(embed,1)
initall = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(initall)
  print(sess.run(embed))
  print(sess.run(tf.shape(embed)))
  print(sess.run(sum_embed))

4、random_uniform_initializer = RandomUniform()

可简写为tf.RandomUniform()

生成均匀分布的随机数,参数有四个(minval=0, maxval=None, seed=None, dtype=dtypes.float32),分别用于指定最小值,最大值,随机数种子和类型。

6、tf.truncated_normal_initializer()

可简写tf.TruncatedNormal()

生成截断正态分布的随机数,这个初始化方法在tf中用得比较多。

它有四个参数(mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32),分别用于指定均值、标准差、随机数种子和随机数的数据类型,一般只需要设置stddev这一个参数就可以了。

8、tf.variance_scaling_initializer()

可简写为tf.VarianceScaling()

参数为(scale=1.0,mode="fan_in",distribution="normal",seed=None,dtype=dtypes.float32)

scale: 缩放尺度(正浮点数)

mode: "fan_in", "fan_out", "fan_avg"中的一个,用于计算标准差stddev的值。

distribution:分布类型,"normal"或“uniform"中的一个。

当 distribution="normal" 的时候,生成truncated normal distribution(截断正态分布) 的随机数,其中stddev = sqrt(scale / n) ,n的计算与mode参数有关。

如果mode = "fan_in", n为输入单元的结点数;

如果mode = "fan_out",n为输出单元的结点数;

如果mode = "fan_avg",n为输入和输出单元结点数的平均值。

当distribution="uniform”的时候 ,生成均匀分布的随机数,假设分布区间为[-limit, limit],则 limit = sqrt(3 * scale / n)

以上这篇关于tensorflow的几种参数初始化方法小结就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用scrapy解析js示例
Jan 23 Python
python获取一组数据里最大值max函数用法实例
May 26 Python
使用Python读写及压缩和解压缩文件的示例
Jul 08 Python
python实现一个简单的并查集的示例代码
Mar 19 Python
浅谈django rest jwt vue 跨域问题
Oct 26 Python
PyQt弹出式对话框的常用方法及标准按钮类型
Feb 27 Python
Python实现一个数组除以一个数的例子
Jul 20 Python
基于Python安装pyecharts所遇的问题及解决方法
Aug 12 Python
Python tkinter常用操作代码实例
Jan 03 Python
基于python 凸包问题的解决
Apr 16 Python
Jupyter Notebook添加代码自动补全功能的实现
Jan 07 Python
Pytorch 实现变量类型转换
May 17 Python
基于TensorFlow常量、序列以及随机值生成实例
Jan 04 #Python
Tensorflow 实现分批量读取数据
Jan 04 #Python
Tensorflow的常用矩阵生成方式
Jan 04 #Python
Tensorflow读取并输出已保存模型的权重数值方式
Jan 04 #Python
tensorflow实现打印ckpt模型保存下的变量名称及变量值
Jan 04 #Python
tensorflow 获取所有variable或tensor的name示例
Jan 04 #Python
tensorflow没有output结点,存储成pb文件的例子
Jan 04 #Python
You might like
php采集速度探究总结(原创)
2008/04/18 PHP
PHP开发中常见的安全问题详解和解决方法(如Sql注入、CSRF、Xss、CC等)
2014/04/21 PHP
php批量删除超链接的实现方法
2015/10/19 PHP
win10环境PHP 7 安装配置【教程】
2016/05/09 PHP
用于节点操作的API,颠覆原生操作HTML DOM节点的API
2010/12/11 Javascript
javascript中判断一个值是否在数组中并没有直接使用
2012/12/17 Javascript
jQuery中$.fn的用法示例介绍
2013/11/05 Javascript
httpclient模拟登陆具体实现(使用js设置cookie)
2013/12/11 Javascript
js清理Word格式示例代码
2014/02/13 Javascript
借助javascript代码判断网页是静态还是伪静态
2014/05/05 Javascript
jQuery实现倒计时按钮功能代码分享
2014/09/03 Javascript
javascript实现简单的全选和反选功能
2016/01/05 Javascript
jQuery选择器之子元素选择器详解
2017/09/18 jQuery
原生JS封装animate运动框架的实例
2017/10/12 Javascript
js实现复制功能(多种方法集合)
2018/01/06 Javascript
Angular数据绑定机制原理
2018/04/17 Javascript
JS实现520 表白简单代码
2018/05/21 Javascript
JSONP解决JS跨域问题的实现
2020/05/25 Javascript
python使用内存zipfile对象在内存中打包文件示例
2014/04/30 Python
Python装饰器的函数式编程详解
2015/02/27 Python
Python数据结构之翻转链表
2017/02/25 Python
Python 操作MySQL详解及实例
2017/04/30 Python
详谈python http长连接客户端
2017/06/12 Python
用python一行代码得到数组中某个元素的个数方法
2019/01/28 Python
python 使用elasticsearch 实现翻页的三种方式
2020/07/31 Python
Python pip install之SSL异常处理操作
2020/09/03 Python
一篇文章教你用python画动态爱心表白
2020/11/22 Python
SmartBuyGlasses丹麦:网上购买名牌太阳镜、眼镜和隐形眼镜
2016/10/01 全球购物
Audible英国:有声读物,30天免费试用
2019/10/16 全球购物
企业为何需要商业计划书
2013/12/26 职场文书
母亲80寿诞答谢词
2014/01/16 职场文书
实习指导教师评语
2014/12/30 职场文书
辞职信格式模板
2015/02/27 职场文书
2019公司借款合同范本2篇!
2019/07/24 职场文书
PO模式在selenium自动化测试框架的优势
2022/03/20 Python
使用Redis实现分布式锁的方法
2022/06/16 Redis