tensorflow之自定义神经网络层实例


Posted in Python onFebruary 07, 2020

如下所示:

import tensorflow as tf
tfe = tf.contrib.eager

tf.enable_eager_execution()

大多数情况下,在为机器学习模型编写代码时,您希望在比单个操作和单个变量操作更高的抽象级别上操作。

1.关于图层的一些有用操作

许多机器学习模型可以表达为相对简单的图层的组合和堆叠,TensorFlow提供了一组许多常用图层,以及您从头开始或作为组合创建自己的应用程序特定图层的简单方法。TensorFlow在tf.keras包中包含完整的Keras API,而Keras层在构建自己的模型时非常有用。

#在tf.keras.layers包中,图层是对象。要构造一个图层,只需构造一个对象。大多数层将输出维度/通道的数量作为第一个参数。
layer=tf.keras.layers.Dense(100)
#输入维度的数量通常是不必要的,因为它可以在第一次使用图层时推断出来,但如果您想手动指定它,则可以提供它,这在某些复杂模型中很有用。
layer=tf.keras.layers.Dense(10,input_shape=(None,5))
#调用层
layer(tf.zeros([10,5]))
 

#图层有许多有用的方法。例如,您可以通过调用layer.variables来检查图层中的所有变量。在这种情况下,完全连接的层将具有权重和偏差的变量。
variable=layer.variables
# variable[0]
layer.kernel.numpy()
layer.bias

2.自定义图层

实现自己的层的最佳方法是扩展tf.keras.Layer类并实现:

__init__,您可以在其中执行所有与输入无关的初始化

build方法,您知道输入张量的形状,并可以进行其余的初始化

call方法,在这里进行正向传播计算

请注意,您不必等到调用build来创建变量,您也可以在__init__中创建它们。但是,在build中创建它们的优点是它可以根据图层将要操作的输入的形状启用后期变量创建。另一方面,在__init__中创建变量意味着需要明确指定创建变量所需的形状。

class MyDenseLayer(tf.keras.layers.Layer):
 def __init__(self, num_outputs):
  super(MyDenseLayer, self).__init__()
  self.num_outputs = num_outputs
  
 def build(self, input_shape):
  self.kernel = self.add_variable("kernel", 
                  shape=[input_shape[-1].value, 
                      self.num_outputs])
  
 def call(self, input):
  return tf.matmul(input, self.kernel)
 
layer = MyDenseLayer(10)
print(layer(tf.zeros([10, 5])))
print(layer.variables)

3.搭建网络结构

机器学习模型中许多有趣的图层是通过组合现有层来实现的。例如,resnet中的每个residual块是卷积,批量标准化等的组合。

创建包含其他图层的类似图层的东西时使用的主类是tf.keras.Model。实现一个是通过继承自tf.keras.Model完成的。

class ResnetIdentityBlock(tf.keras.Model):
 def __init__(self, kernel_size, filters):
  super(ResnetIdentityBlock, self).__init__(name='')
  filters1, filters2, filters3 = filters
 
  self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))
  self.bn2a = tf.keras.layers.BatchNormalization()
 
  self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')
  self.bn2b = tf.keras.layers.BatchNormalization()
 
  self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))
  self.bn2c = tf.keras.layers.BatchNormalization()
 
 def call(self, input_tensor, training=False):
  x = self.conv2a(input_tensor)
  x = self.bn2a(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2b(x)
  x = self.bn2b(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2c(x)
  x = self.bn2c(x, training=training)
 
  x += input_tensor
  return tf.nn.relu(x)
 
  
block = ResnetIdentityBlock(1, [1, 2, 3])
print(block(tf.zeros([1, 2, 3, 3])))
print([x.name for x in block.variables])

以上这篇tensorflow之自定义神经网络层实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python搭建Django应用程序步骤及版本冲突问题解决
Nov 19 Python
python概率计算器实例分析
Mar 25 Python
windows系统下Python环境的搭建(Aptana Studio)
Mar 06 Python
Python字符串内置函数功能与用法总结
Apr 16 Python
python中时间、日期、时间戳的转换的实现方法
Jul 06 Python
Django的models中on_delete参数详解
Jul 16 Python
对python3中的RE(正则表达式)-详细总结
Jul 23 Python
Python 余弦相似度与皮尔逊相关系数 计算实例
Dec 23 Python
Python字符串格式化常用手段及注意事项
Jun 17 Python
新手常见Python错误及异常解决处理方案
Jun 18 Python
Python字典取键、值对的方法步骤
Sep 30 Python
Python图片检索之以图搜图
May 31 Python
在tensorflow中设置使用某一块GPU、多GPU、CPU的操作
Feb 07 #Python
谈一谈数组拼接tf.concat()和np.concatenate()的区别
Feb 07 #Python
python文件和文件夹复制函数
Feb 07 #Python
tf.concat中axis的含义与使用详解
Feb 07 #Python
浅谈tensorflow 中tf.concat()的使用
Feb 07 #Python
Python for循环通过序列索引迭代过程解析
Feb 07 #Python
python中with用法讲解
Feb 07 #Python
You might like
为IP查询添加GOOGLE地图功能的代码
2010/08/08 PHP
PHP中通过加号合并数组的一个简单方法分享
2011/01/27 PHP
PHP使用array_multisort对多个数组或多维数组进行排序
2014/12/16 PHP
ThinkPHP2.x防范XSS跨站攻击的方法
2015/09/25 PHP
PHP实现的redis主从数据库状态检测功能示例
2017/07/20 PHP
PHP __call()方法实现委托示例
2019/05/20 PHP
javascript AutoScroller 函数类
2009/05/29 Javascript
基于javascript实现九九乘法表
2016/03/27 Javascript
简单实现JS计算器功能
2016/12/21 Javascript
JavaScript调试的多个必备小Tips
2017/01/15 Javascript
JS在浏览器中解析Base64编码图像
2017/02/09 Javascript
bootstrap 通过加减按钮实现输入框组功能
2017/11/15 Javascript
ajax与jsonp的区别及用法
2018/10/16 Javascript
vue-router的使用方法及含参数的配置方法
2018/11/13 Javascript
详解JavaScript实现动态的轮播图效果
2019/04/29 Javascript
JS面向对象编程实现的Tab选项卡案例详解
2020/03/03 Javascript
解决vue中使用less/sass及使用中遇到无效的问题
2020/10/24 Javascript
Javascript实现关闭广告效果
2021/01/29 Javascript
python实现趣味图片字符化
2019/04/30 Python
利用python在大量数据文件下删除某一行的例子
2019/08/21 Python
python进行参数传递的方法
2020/05/12 Python
AmazeUI折叠式卡片布局,整合内容列表、表格组件实现
2020/08/20 HTML / CSS
Perry Ellis官网:美国男士品味服装
2016/12/09 全球购物
大女孩胸罩:Big Girls Bras
2016/12/15 全球购物
英国在线电子和小工具商店:TecoBuy
2018/10/06 全球购物
公司部门司机岗位职责
2014/01/03 职场文书
信息技术毕业生自荐信范文
2014/03/13 职场文书
小学三八妇女节活动方案
2014/03/16 职场文书
公司年会抽奖活动主持词
2014/03/31 职场文书
个人整改方案范文
2014/10/25 职场文书
2014年信访维稳工作总结
2014/12/08 职场文书
工程服务质量承诺书
2015/04/29 职场文书
入学证明
2015/06/23 职场文书
导游词之台湾安平古堡
2019/12/25 职场文书
SpringCloud Alibaba 基本开发框架搭建过程
2021/06/13 Java/Android
postman中form-data、x-www-form-urlencoded、raw、binary的区别介绍
2022/01/18 HTML / CSS