tensorflow之自定义神经网络层实例


Posted in Python onFebruary 07, 2020

如下所示:

import tensorflow as tf
tfe = tf.contrib.eager

tf.enable_eager_execution()

大多数情况下,在为机器学习模型编写代码时,您希望在比单个操作和单个变量操作更高的抽象级别上操作。

1.关于图层的一些有用操作

许多机器学习模型可以表达为相对简单的图层的组合和堆叠,TensorFlow提供了一组许多常用图层,以及您从头开始或作为组合创建自己的应用程序特定图层的简单方法。TensorFlow在tf.keras包中包含完整的Keras API,而Keras层在构建自己的模型时非常有用。

#在tf.keras.layers包中,图层是对象。要构造一个图层,只需构造一个对象。大多数层将输出维度/通道的数量作为第一个参数。
layer=tf.keras.layers.Dense(100)
#输入维度的数量通常是不必要的,因为它可以在第一次使用图层时推断出来,但如果您想手动指定它,则可以提供它,这在某些复杂模型中很有用。
layer=tf.keras.layers.Dense(10,input_shape=(None,5))
#调用层
layer(tf.zeros([10,5]))
 

#图层有许多有用的方法。例如,您可以通过调用layer.variables来检查图层中的所有变量。在这种情况下,完全连接的层将具有权重和偏差的变量。
variable=layer.variables
# variable[0]
layer.kernel.numpy()
layer.bias

2.自定义图层

实现自己的层的最佳方法是扩展tf.keras.Layer类并实现:

__init__,您可以在其中执行所有与输入无关的初始化

build方法,您知道输入张量的形状,并可以进行其余的初始化

call方法,在这里进行正向传播计算

请注意,您不必等到调用build来创建变量,您也可以在__init__中创建它们。但是,在build中创建它们的优点是它可以根据图层将要操作的输入的形状启用后期变量创建。另一方面,在__init__中创建变量意味着需要明确指定创建变量所需的形状。

class MyDenseLayer(tf.keras.layers.Layer):
 def __init__(self, num_outputs):
  super(MyDenseLayer, self).__init__()
  self.num_outputs = num_outputs
  
 def build(self, input_shape):
  self.kernel = self.add_variable("kernel", 
                  shape=[input_shape[-1].value, 
                      self.num_outputs])
  
 def call(self, input):
  return tf.matmul(input, self.kernel)
 
layer = MyDenseLayer(10)
print(layer(tf.zeros([10, 5])))
print(layer.variables)

3.搭建网络结构

机器学习模型中许多有趣的图层是通过组合现有层来实现的。例如,resnet中的每个residual块是卷积,批量标准化等的组合。

创建包含其他图层的类似图层的东西时使用的主类是tf.keras.Model。实现一个是通过继承自tf.keras.Model完成的。

class ResnetIdentityBlock(tf.keras.Model):
 def __init__(self, kernel_size, filters):
  super(ResnetIdentityBlock, self).__init__(name='')
  filters1, filters2, filters3 = filters
 
  self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))
  self.bn2a = tf.keras.layers.BatchNormalization()
 
  self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')
  self.bn2b = tf.keras.layers.BatchNormalization()
 
  self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))
  self.bn2c = tf.keras.layers.BatchNormalization()
 
 def call(self, input_tensor, training=False):
  x = self.conv2a(input_tensor)
  x = self.bn2a(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2b(x)
  x = self.bn2b(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2c(x)
  x = self.bn2c(x, training=training)
 
  x += input_tensor
  return tf.nn.relu(x)
 
  
block = ResnetIdentityBlock(1, [1, 2, 3])
print(block(tf.zeros([1, 2, 3, 3])))
print([x.name for x in block.variables])

以上这篇tensorflow之自定义神经网络层实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
可用于监控 mysql Master Slave 状态的python代码
Feb 10 Python
web.py获取上传文件名的正确方法
Aug 26 Python
Python 模板引擎的注入问题分析
Jan 01 Python
python去除空格和换行符的实现方法(推荐)
Jan 04 Python
Python异常处理操作实例详解
May 10 Python
Django中使用Celery的教程详解
Aug 24 Python
python 对类的成员函数开启线程的方法
Jan 22 Python
Django Rest framework三种分页方式详解
Jul 26 Python
elasticsearch python 查询的两种方法
Aug 04 Python
python之pymysql模块简单应用示例代码
Dec 16 Python
python 获取字典特定值对应的键的实现
Sep 29 Python
Python创建自己的加密货币的示例
Mar 01 Python
在tensorflow中设置使用某一块GPU、多GPU、CPU的操作
Feb 07 #Python
谈一谈数组拼接tf.concat()和np.concatenate()的区别
Feb 07 #Python
python文件和文件夹复制函数
Feb 07 #Python
tf.concat中axis的含义与使用详解
Feb 07 #Python
浅谈tensorflow 中tf.concat()的使用
Feb 07 #Python
Python for循环通过序列索引迭代过程解析
Feb 07 #Python
python中with用法讲解
Feb 07 #Python
You might like
php 魔术函数使用说明
2010/02/21 PHP
浅析PHP substr,mb_substr以及mb_strcut的区别和用法
2013/06/21 PHP
php获取文件类型和文件信息的方法
2015/07/10 PHP
thinkPHP分页功能实例详解
2017/05/05 PHP
ThinkPHP5 验证器的具体使用
2018/05/31 PHP
javascript函数中的arguments参数
2010/08/01 Javascript
16个最流行的JavaScript框架[推荐]
2011/05/29 Javascript
javascript是怎么继承的介绍
2012/01/05 Javascript
js函数的延迟加载实现代码
2012/10/11 Javascript
jQuery中Dom的基本操作小结
2014/01/23 Javascript
Jquery实现地铁线路指示灯提示牌效果的方法
2015/03/02 Javascript
JS实现新浪微博效果带遮罩层的弹出框代码
2015/10/12 Javascript
使用JavaScript脚本判断页面是否在微信中被打开
2016/03/06 Javascript
Web打印解决方案之证件套打的实现思路
2016/08/29 Javascript
javascript 判断是否是微信浏览器的方法
2016/10/09 Javascript
jQuery实现动态生成表格并为行绑定单击变色动作的方法
2017/04/17 jQuery
Ionic3 UI组件之Gallery Modal详解
2017/06/07 Javascript
layui结合form,table的全选、反选v1.0示例讲解
2018/08/15 Javascript
JS异步执行结果获取的3种解决方式
2019/02/19 Javascript
微信小程序学习笔记之函数定义、页面渲染图文详解
2019/03/28 Javascript
layui按条件隐藏表格列的实例
2019/09/19 Javascript
详解Vue 数据更新了但页面没有更新的 7 种情况汇总及延伸总结
2020/05/28 Javascript
Python随手笔记之标准类型内建函数
2015/12/02 Python
python自动化脚本安装指定版本python环境详解
2017/09/14 Python
Python 获得命令行参数的方法(推荐)
2018/01/24 Python
Python3字符串encode与decode的讲解
2019/04/02 Python
python图形开发GUI库wxpython使用方法详解
2020/02/14 Python
利用pyecharts读取csv并进行数据统计可视化的实现
2020/04/17 Python
naturalizer加拿大官网:美国娜然女鞋
2017/04/04 全球购物
英国儿童设计师服装和玩具购物网站:Zac & Lulu
2020/10/19 全球购物
荷兰美妆护肤品海淘网站:Beautinow(中文)
2020/11/22 全球购物
给全校老师的建议书
2014/03/13 职场文书
篮球比赛口号
2014/06/10 职场文书
公证委托书标准格式
2014/09/11 职场文书
2014年法院个人工作总结
2014/12/17 职场文书
羊脂球读书笔记
2015/06/30 职场文书