tensorflow之自定义神经网络层实例


Posted in Python onFebruary 07, 2020

如下所示:

import tensorflow as tf
tfe = tf.contrib.eager

tf.enable_eager_execution()

大多数情况下,在为机器学习模型编写代码时,您希望在比单个操作和单个变量操作更高的抽象级别上操作。

1.关于图层的一些有用操作

许多机器学习模型可以表达为相对简单的图层的组合和堆叠,TensorFlow提供了一组许多常用图层,以及您从头开始或作为组合创建自己的应用程序特定图层的简单方法。TensorFlow在tf.keras包中包含完整的Keras API,而Keras层在构建自己的模型时非常有用。

#在tf.keras.layers包中,图层是对象。要构造一个图层,只需构造一个对象。大多数层将输出维度/通道的数量作为第一个参数。
layer=tf.keras.layers.Dense(100)
#输入维度的数量通常是不必要的,因为它可以在第一次使用图层时推断出来,但如果您想手动指定它,则可以提供它,这在某些复杂模型中很有用。
layer=tf.keras.layers.Dense(10,input_shape=(None,5))
#调用层
layer(tf.zeros([10,5]))
 

#图层有许多有用的方法。例如,您可以通过调用layer.variables来检查图层中的所有变量。在这种情况下,完全连接的层将具有权重和偏差的变量。
variable=layer.variables
# variable[0]
layer.kernel.numpy()
layer.bias

2.自定义图层

实现自己的层的最佳方法是扩展tf.keras.Layer类并实现:

__init__,您可以在其中执行所有与输入无关的初始化

build方法,您知道输入张量的形状,并可以进行其余的初始化

call方法,在这里进行正向传播计算

请注意,您不必等到调用build来创建变量,您也可以在__init__中创建它们。但是,在build中创建它们的优点是它可以根据图层将要操作的输入的形状启用后期变量创建。另一方面,在__init__中创建变量意味着需要明确指定创建变量所需的形状。

class MyDenseLayer(tf.keras.layers.Layer):
 def __init__(self, num_outputs):
  super(MyDenseLayer, self).__init__()
  self.num_outputs = num_outputs
  
 def build(self, input_shape):
  self.kernel = self.add_variable("kernel", 
                  shape=[input_shape[-1].value, 
                      self.num_outputs])
  
 def call(self, input):
  return tf.matmul(input, self.kernel)
 
layer = MyDenseLayer(10)
print(layer(tf.zeros([10, 5])))
print(layer.variables)

3.搭建网络结构

机器学习模型中许多有趣的图层是通过组合现有层来实现的。例如,resnet中的每个residual块是卷积,批量标准化等的组合。

创建包含其他图层的类似图层的东西时使用的主类是tf.keras.Model。实现一个是通过继承自tf.keras.Model完成的。

class ResnetIdentityBlock(tf.keras.Model):
 def __init__(self, kernel_size, filters):
  super(ResnetIdentityBlock, self).__init__(name='')
  filters1, filters2, filters3 = filters
 
  self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))
  self.bn2a = tf.keras.layers.BatchNormalization()
 
  self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')
  self.bn2b = tf.keras.layers.BatchNormalization()
 
  self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))
  self.bn2c = tf.keras.layers.BatchNormalization()
 
 def call(self, input_tensor, training=False):
  x = self.conv2a(input_tensor)
  x = self.bn2a(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2b(x)
  x = self.bn2b(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2c(x)
  x = self.bn2c(x, training=training)
 
  x += input_tensor
  return tf.nn.relu(x)
 
  
block = ResnetIdentityBlock(1, [1, 2, 3])
print(block(tf.zeros([1, 2, 3, 3])))
print([x.name for x in block.variables])

以上这篇tensorflow之自定义神经网络层实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python处理圆角图片、圆形图片的例子
Apr 25 Python
Python文件及目录操作实例详解
Jun 04 Python
Python实现扩展内置类型的方法分析
Oct 16 Python
python表格存取的方法
Mar 07 Python
Python 批量合并多个txt文件的实例讲解
May 08 Python
Python模块、包(Package)概念与用法分析
May 31 Python
Python实现决策树并且使用Graphviz可视化的例子
Aug 09 Python
利用Python复制文件的9种方法总结
Sep 02 Python
通过python实现windows桌面截图代码实例
Jan 17 Python
Python代码一键转Jar包及Java调用Python新姿势
Mar 10 Python
Python unittest单元测试openpyxl实现过程解析
May 27 Python
PyTorch 导数应用的使用教程
Aug 31 Python
在tensorflow中设置使用某一块GPU、多GPU、CPU的操作
Feb 07 #Python
谈一谈数组拼接tf.concat()和np.concatenate()的区别
Feb 07 #Python
python文件和文件夹复制函数
Feb 07 #Python
tf.concat中axis的含义与使用详解
Feb 07 #Python
浅谈tensorflow 中tf.concat()的使用
Feb 07 #Python
Python for循环通过序列索引迭代过程解析
Feb 07 #Python
python中with用法讲解
Feb 07 #Python
You might like
第九节 绑定 [9]
2006/10/09 PHP
PHP根据树的前序遍历和中序遍历构造树并输出后序遍历的方法
2017/11/10 PHP
PHP队列场景以及实现代码实例详解
2021/02/26 PHP
firefox中用javascript实现鼠标位置的定位
2007/06/17 Javascript
javascript(jquery)利用函数修改全局变量的代码
2009/11/02 Javascript
用JS控制回车事件的代码
2011/02/20 Javascript
file控件选择上传文件确定后触发的js事件是哪个
2014/03/17 Javascript
点击button获取text内容并改变样式的js实现
2014/09/09 Javascript
jQuery实现字符串按指定长度加入特定内容的方法
2015/03/11 Javascript
JavaScript中的继承之类继承
2016/05/01 Javascript
功能强大的Bootstrap组件(结合js)
2016/08/03 Javascript
使用BootStrapValidator完成前端输入验证
2016/09/28 Javascript
JS添加或修改控件的样式(Class)实现方法
2016/10/15 Javascript
详解基于webpack搭建react运行环境
2017/06/01 Javascript
基于three.js编写的一个项目类示例代码
2018/01/05 Javascript
浅谈Webpack多页应用HMR卡住问题
2019/04/24 Javascript
[03:49]辉夜杯现场龙骑士COSER秀情商“我喜欢芬队!”
2015/12/27 DOTA
[40:55]DOTA2上海特级锦标赛主赛事日 - 2 败者组第二轮#4Newbee VS Fnatic
2016/03/03 DOTA
Python爬虫DOTA排行榜爬取实例(分享)
2017/06/13 Python
使用Django和Postgres进行全文搜索的实例代码
2020/02/13 Python
python opencv 实现读取、显示、写入图像的方法
2020/06/08 Python
Ubuntu 20.04安装Pycharm2020.2及锁定到任务栏的问题(小白级操作)
2020/10/29 Python
Sephora丝芙兰澳洲官方网站:国际知名化妆品购物
2016/10/27 全球购物
加拿大廉价机票预订网站:CheapOair.ca
2018/03/04 全球购物
Boden英国官网:英国知名原创时装品牌
2018/11/06 全球购物
幼儿园美术教学反思
2014/01/31 职场文书
大三学生做职业规划:给未来找个方向
2014/02/24 职场文书
初中班主任寄语
2014/04/04 职场文书
大型演出策划方案
2014/05/28 职场文书
学雷锋活动总结报告
2014/06/26 职场文书
药品营销专业毕业生自荐信
2014/07/02 职场文书
向女朋友道歉的话
2015/01/20 职场文书
2015医德医风个人工作总结
2015/04/02 职场文书
《桂花雨》教学反思
2016/02/19 职场文书
2016年清明节期间群众祭祀活动工作总结
2016/04/01 职场文书
JavaWeb 入门篇(3)ServletContext 详解 具体应用
2021/07/16 Java/Android