tensorflow之自定义神经网络层实例


Posted in Python onFebruary 07, 2020

如下所示:

import tensorflow as tf
tfe = tf.contrib.eager

tf.enable_eager_execution()

大多数情况下,在为机器学习模型编写代码时,您希望在比单个操作和单个变量操作更高的抽象级别上操作。

1.关于图层的一些有用操作

许多机器学习模型可以表达为相对简单的图层的组合和堆叠,TensorFlow提供了一组许多常用图层,以及您从头开始或作为组合创建自己的应用程序特定图层的简单方法。TensorFlow在tf.keras包中包含完整的Keras API,而Keras层在构建自己的模型时非常有用。

#在tf.keras.layers包中,图层是对象。要构造一个图层,只需构造一个对象。大多数层将输出维度/通道的数量作为第一个参数。
layer=tf.keras.layers.Dense(100)
#输入维度的数量通常是不必要的,因为它可以在第一次使用图层时推断出来,但如果您想手动指定它,则可以提供它,这在某些复杂模型中很有用。
layer=tf.keras.layers.Dense(10,input_shape=(None,5))
#调用层
layer(tf.zeros([10,5]))
 

#图层有许多有用的方法。例如,您可以通过调用layer.variables来检查图层中的所有变量。在这种情况下,完全连接的层将具有权重和偏差的变量。
variable=layer.variables
# variable[0]
layer.kernel.numpy()
layer.bias

2.自定义图层

实现自己的层的最佳方法是扩展tf.keras.Layer类并实现:

__init__,您可以在其中执行所有与输入无关的初始化

build方法,您知道输入张量的形状,并可以进行其余的初始化

call方法,在这里进行正向传播计算

请注意,您不必等到调用build来创建变量,您也可以在__init__中创建它们。但是,在build中创建它们的优点是它可以根据图层将要操作的输入的形状启用后期变量创建。另一方面,在__init__中创建变量意味着需要明确指定创建变量所需的形状。

class MyDenseLayer(tf.keras.layers.Layer):
 def __init__(self, num_outputs):
  super(MyDenseLayer, self).__init__()
  self.num_outputs = num_outputs
  
 def build(self, input_shape):
  self.kernel = self.add_variable("kernel", 
                  shape=[input_shape[-1].value, 
                      self.num_outputs])
  
 def call(self, input):
  return tf.matmul(input, self.kernel)
 
layer = MyDenseLayer(10)
print(layer(tf.zeros([10, 5])))
print(layer.variables)

3.搭建网络结构

机器学习模型中许多有趣的图层是通过组合现有层来实现的。例如,resnet中的每个residual块是卷积,批量标准化等的组合。

创建包含其他图层的类似图层的东西时使用的主类是tf.keras.Model。实现一个是通过继承自tf.keras.Model完成的。

class ResnetIdentityBlock(tf.keras.Model):
 def __init__(self, kernel_size, filters):
  super(ResnetIdentityBlock, self).__init__(name='')
  filters1, filters2, filters3 = filters
 
  self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))
  self.bn2a = tf.keras.layers.BatchNormalization()
 
  self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')
  self.bn2b = tf.keras.layers.BatchNormalization()
 
  self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))
  self.bn2c = tf.keras.layers.BatchNormalization()
 
 def call(self, input_tensor, training=False):
  x = self.conv2a(input_tensor)
  x = self.bn2a(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2b(x)
  x = self.bn2b(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2c(x)
  x = self.bn2c(x, training=training)
 
  x += input_tensor
  return tf.nn.relu(x)
 
  
block = ResnetIdentityBlock(1, [1, 2, 3])
print(block(tf.zeros([1, 2, 3, 3])))
print([x.name for x in block.variables])

以上这篇tensorflow之自定义神经网络层实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python cookielib 登录人人网的实现代码
Dec 19 Python
利用Python实现图书超期提醒
Aug 02 Python
Python获取当前路径实现代码
May 08 Python
Python使用requests及BeautifulSoup构建爬虫实例代码
Jan 24 Python
python发送邮件脚本
May 22 Python
Python API自动化框架总结
Nov 12 Python
python对Excel的读取的示例代码
Feb 14 Python
python 已知一个字符,在一个list中找出近似值或相似值实现模糊匹配
Feb 29 Python
Python基于当前时间批量创建文件
May 07 Python
Python基于template实现字符串替换
Nov 27 Python
PyTorch 中的傅里叶卷积实现示例
Dec 11 Python
python palywright库基本使用
Jan 21 Python
在tensorflow中设置使用某一块GPU、多GPU、CPU的操作
Feb 07 #Python
谈一谈数组拼接tf.concat()和np.concatenate()的区别
Feb 07 #Python
python文件和文件夹复制函数
Feb 07 #Python
tf.concat中axis的含义与使用详解
Feb 07 #Python
浅谈tensorflow 中tf.concat()的使用
Feb 07 #Python
Python for循环通过序列索引迭代过程解析
Feb 07 #Python
python中with用法讲解
Feb 07 #Python
You might like
浅谈PHP与C#的值类型指向区别的详解
2013/05/21 PHP
php Imagick获取图片RGB颜色值
2014/07/28 PHP
PHP实现HTTP断点续传的方法
2015/06/17 PHP
自适应图片大小的弹出窗口
2006/07/27 Javascript
两个DIV等高的JS的实现代码
2007/12/23 Javascript
jquery重复提交请求的原因浅析
2014/05/23 Javascript
Node.js与PHP、Python的字符处理性能对比
2014/07/06 Javascript
Node.js中使用事件发射器模式实现事件绑定详解
2014/08/15 Javascript
JavaScript动态数量的文件上传控件
2016/11/18 Javascript
bootstrap警告框使用方法解析
2017/01/13 Javascript
Vue实现数字输入框中分割手机号码的示例
2017/10/10 Javascript
JavaScript使用math.js进行精确计算操作示例
2018/06/19 Javascript
浅谈angular表单提交中ng-submit的默认使用方法
2018/09/30 Javascript
element-ui组件table实现自定义筛选功能的示例代码
2019/03/15 Javascript
vscode配置vue下的es6规范自动格式化详解
2019/03/20 Javascript
Vue实现渲染数据后控制滚动条位置(推荐)
2019/12/09 Javascript
纯js实现无缝滚动功能代码实例
2020/02/21 Javascript
[02:32]DOTA2英雄基础教程 美杜莎
2014/01/07 DOTA
分析并输出Python代码依赖的库的实现代码
2015/08/09 Python
Python入门之三角函数tan()函数实例详解
2017/11/08 Python
pandas.dataframe按行索引表达式选取方法
2018/10/30 Python
python+opencv实现阈值分割
2018/12/26 Python
基于wxPython的GUI实现输入对话框(2)
2019/02/27 Python
python str字符串转uuid实例
2020/03/03 Python
python GUI库图形界面开发之PyQt5简单绘图板实例与代码分析
2020/03/08 Python
python文件读取失败怎么处理
2020/06/23 Python
使用Python画了一棵圣诞树的实例代码
2020/11/27 Python
CSS3实现图片抽屉式效果的示例代码
2019/11/06 HTML / CSS
波兰最大的宠物用品网上商店:FERA.PL
2019/08/11 全球购物
美国艺术和工艺品商店:Hobby Lobby
2020/12/09 全球购物
预防艾滋病宣传标语
2014/06/25 职场文书
烈士陵园观后感
2015/06/08 职场文书
董事会决议范本
2015/07/01 职场文书
2015年政治教研组工作总结
2015/07/22 职场文书
一波干货,会议主持词开场白范文
2019/05/06 职场文书
关于使用Redisson订阅数问题
2022/01/18 Redis