tensorflow之自定义神经网络层实例


Posted in Python onFebruary 07, 2020

如下所示:

import tensorflow as tf
tfe = tf.contrib.eager

tf.enable_eager_execution()

大多数情况下,在为机器学习模型编写代码时,您希望在比单个操作和单个变量操作更高的抽象级别上操作。

1.关于图层的一些有用操作

许多机器学习模型可以表达为相对简单的图层的组合和堆叠,TensorFlow提供了一组许多常用图层,以及您从头开始或作为组合创建自己的应用程序特定图层的简单方法。TensorFlow在tf.keras包中包含完整的Keras API,而Keras层在构建自己的模型时非常有用。

#在tf.keras.layers包中,图层是对象。要构造一个图层,只需构造一个对象。大多数层将输出维度/通道的数量作为第一个参数。
layer=tf.keras.layers.Dense(100)
#输入维度的数量通常是不必要的,因为它可以在第一次使用图层时推断出来,但如果您想手动指定它,则可以提供它,这在某些复杂模型中很有用。
layer=tf.keras.layers.Dense(10,input_shape=(None,5))
#调用层
layer(tf.zeros([10,5]))
 

#图层有许多有用的方法。例如,您可以通过调用layer.variables来检查图层中的所有变量。在这种情况下,完全连接的层将具有权重和偏差的变量。
variable=layer.variables
# variable[0]
layer.kernel.numpy()
layer.bias

2.自定义图层

实现自己的层的最佳方法是扩展tf.keras.Layer类并实现:

__init__,您可以在其中执行所有与输入无关的初始化

build方法,您知道输入张量的形状,并可以进行其余的初始化

call方法,在这里进行正向传播计算

请注意,您不必等到调用build来创建变量,您也可以在__init__中创建它们。但是,在build中创建它们的优点是它可以根据图层将要操作的输入的形状启用后期变量创建。另一方面,在__init__中创建变量意味着需要明确指定创建变量所需的形状。

class MyDenseLayer(tf.keras.layers.Layer):
 def __init__(self, num_outputs):
  super(MyDenseLayer, self).__init__()
  self.num_outputs = num_outputs
  
 def build(self, input_shape):
  self.kernel = self.add_variable("kernel", 
                  shape=[input_shape[-1].value, 
                      self.num_outputs])
  
 def call(self, input):
  return tf.matmul(input, self.kernel)
 
layer = MyDenseLayer(10)
print(layer(tf.zeros([10, 5])))
print(layer.variables)

3.搭建网络结构

机器学习模型中许多有趣的图层是通过组合现有层来实现的。例如,resnet中的每个residual块是卷积,批量标准化等的组合。

创建包含其他图层的类似图层的东西时使用的主类是tf.keras.Model。实现一个是通过继承自tf.keras.Model完成的。

class ResnetIdentityBlock(tf.keras.Model):
 def __init__(self, kernel_size, filters):
  super(ResnetIdentityBlock, self).__init__(name='')
  filters1, filters2, filters3 = filters
 
  self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))
  self.bn2a = tf.keras.layers.BatchNormalization()
 
  self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')
  self.bn2b = tf.keras.layers.BatchNormalization()
 
  self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))
  self.bn2c = tf.keras.layers.BatchNormalization()
 
 def call(self, input_tensor, training=False):
  x = self.conv2a(input_tensor)
  x = self.bn2a(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2b(x)
  x = self.bn2b(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2c(x)
  x = self.bn2c(x, training=training)
 
  x += input_tensor
  return tf.nn.relu(x)
 
  
block = ResnetIdentityBlock(1, [1, 2, 3])
print(block(tf.zeros([1, 2, 3, 3])))
print([x.name for x in block.variables])

以上这篇tensorflow之自定义神经网络层实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现远程调用MetaSploit的方法
Aug 22 Python
Django框架下在URLconf中指定视图缓存的方法
Jul 23 Python
Django中的Signal代码详解
Feb 05 Python
解决python3 urllib 链接中有中文的问题
Jul 16 Python
Flask和Django框架中自定义模型类的表名、父类相关问题分析
Jul 19 Python
Python 删除整个文本中的空格,并实现按行显示
Jul 24 Python
pytorch permute维度转换方法
Dec 14 Python
pandas的连接函数concat()函数的具体使用方法
Jul 09 Python
python对矩阵进行转置的2种处理方法
Jul 17 Python
Python 实现毫秒级淘宝抢购脚本的示例代码
Sep 16 Python
Python 可变类型和不可变类型及引用过程解析
Sep 27 Python
利用Selenium添加cookie实现自动登录的示例代码(fofa)
May 08 Python
在tensorflow中设置使用某一块GPU、多GPU、CPU的操作
Feb 07 #Python
谈一谈数组拼接tf.concat()和np.concatenate()的区别
Feb 07 #Python
python文件和文件夹复制函数
Feb 07 #Python
tf.concat中axis的含义与使用详解
Feb 07 #Python
浅谈tensorflow 中tf.concat()的使用
Feb 07 #Python
Python for循环通过序列索引迭代过程解析
Feb 07 #Python
python中with用法讲解
Feb 07 #Python
You might like
PHP+Ajax实现的无刷新分页功能详解【附demo源码下载】
2017/07/03 PHP
JS中confirm,alert,prompt函数区别分析
2011/01/17 Javascript
js 获取时间间隔实现代码
2014/05/12 Javascript
nodejs教程之异步I/O
2014/11/21 NodeJs
了解Javascript的模块化开发
2015/03/02 Javascript
JavaScript数组前面插入元素的方法
2015/04/06 Javascript
jquery实现Ctrl+Enter提交表单的方法
2015/07/21 Javascript
js表单验证实例讲解
2016/03/31 Javascript
JS留言功能的简单实现案例(推荐)
2016/06/23 Javascript
jQuery 3.0十大新特性最终版发布
2016/07/14 Javascript
深入理解React中es6创建组件this的方法
2016/08/29 Javascript
KnockoutJS 3.X API 第四章之表单value绑定
2016/10/10 Javascript
通过修改360抢票的刷新频率和突破8车次限制实现方法
2017/01/04 Javascript
Vue.js路由vue-router使用方法详解
2017/03/20 Javascript
详解node-ccap模块生成captcha验证码
2017/07/01 Javascript
JS模拟超市简易收银台小程序代码解析
2017/08/18 Javascript
微信小程序实现签到功能
2018/10/31 Javascript
Vue.js上传图片到阿里云OSS存储的方法示例
2018/12/13 Javascript
vue props 单项数据流实例分享
2020/02/16 Javascript
js的Object.assign用法示例分析
2020/03/05 Javascript
python使用knn实现特征向量分类
2018/12/26 Python
Pandas之Fillna填充缺失数据的方法
2019/06/25 Python
Django文件存储 默认存储系统解析
2019/08/02 Python
Python中sys模块功能与用法实例详解
2020/02/26 Python
jupyter notebook 实现matplotlib图动态刷新
2020/04/22 Python
python 串行执行和并行执行实例
2020/04/30 Python
html5自定义video标签的海报与播放按钮功能
2019/12/04 HTML / CSS
Mio Skincare美国官网:身体紧致及孕期身体护理
2017/03/05 全球购物
三维科技面试题
2013/07/27 面试题
中学校庆方案
2014/03/17 职场文书
学校食堂食品安全责任书
2014/07/28 职场文书
父亲节活动策划方案
2014/08/24 职场文书
2014机关干部学习“焦裕禄精神”思想汇报
2014/09/19 职场文书
聚会通知怎么写
2015/04/23 职场文书
百善孝为先:关于孝道的经典语录
2019/10/18 职场文书
使用python+pygame开发消消乐游戏附完整源码
2021/06/10 Python