tensorflow之自定义神经网络层实例


Posted in Python onFebruary 07, 2020

如下所示:

import tensorflow as tf
tfe = tf.contrib.eager

tf.enable_eager_execution()

大多数情况下,在为机器学习模型编写代码时,您希望在比单个操作和单个变量操作更高的抽象级别上操作。

1.关于图层的一些有用操作

许多机器学习模型可以表达为相对简单的图层的组合和堆叠,TensorFlow提供了一组许多常用图层,以及您从头开始或作为组合创建自己的应用程序特定图层的简单方法。TensorFlow在tf.keras包中包含完整的Keras API,而Keras层在构建自己的模型时非常有用。

#在tf.keras.layers包中,图层是对象。要构造一个图层,只需构造一个对象。大多数层将输出维度/通道的数量作为第一个参数。
layer=tf.keras.layers.Dense(100)
#输入维度的数量通常是不必要的,因为它可以在第一次使用图层时推断出来,但如果您想手动指定它,则可以提供它,这在某些复杂模型中很有用。
layer=tf.keras.layers.Dense(10,input_shape=(None,5))
#调用层
layer(tf.zeros([10,5]))
 

#图层有许多有用的方法。例如,您可以通过调用layer.variables来检查图层中的所有变量。在这种情况下,完全连接的层将具有权重和偏差的变量。
variable=layer.variables
# variable[0]
layer.kernel.numpy()
layer.bias

2.自定义图层

实现自己的层的最佳方法是扩展tf.keras.Layer类并实现:

__init__,您可以在其中执行所有与输入无关的初始化

build方法,您知道输入张量的形状,并可以进行其余的初始化

call方法,在这里进行正向传播计算

请注意,您不必等到调用build来创建变量,您也可以在__init__中创建它们。但是,在build中创建它们的优点是它可以根据图层将要操作的输入的形状启用后期变量创建。另一方面,在__init__中创建变量意味着需要明确指定创建变量所需的形状。

class MyDenseLayer(tf.keras.layers.Layer):
 def __init__(self, num_outputs):
  super(MyDenseLayer, self).__init__()
  self.num_outputs = num_outputs
  
 def build(self, input_shape):
  self.kernel = self.add_variable("kernel", 
                  shape=[input_shape[-1].value, 
                      self.num_outputs])
  
 def call(self, input):
  return tf.matmul(input, self.kernel)
 
layer = MyDenseLayer(10)
print(layer(tf.zeros([10, 5])))
print(layer.variables)

3.搭建网络结构

机器学习模型中许多有趣的图层是通过组合现有层来实现的。例如,resnet中的每个residual块是卷积,批量标准化等的组合。

创建包含其他图层的类似图层的东西时使用的主类是tf.keras.Model。实现一个是通过继承自tf.keras.Model完成的。

class ResnetIdentityBlock(tf.keras.Model):
 def __init__(self, kernel_size, filters):
  super(ResnetIdentityBlock, self).__init__(name='')
  filters1, filters2, filters3 = filters
 
  self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))
  self.bn2a = tf.keras.layers.BatchNormalization()
 
  self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')
  self.bn2b = tf.keras.layers.BatchNormalization()
 
  self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))
  self.bn2c = tf.keras.layers.BatchNormalization()
 
 def call(self, input_tensor, training=False):
  x = self.conv2a(input_tensor)
  x = self.bn2a(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2b(x)
  x = self.bn2b(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2c(x)
  x = self.bn2c(x, training=training)
 
  x += input_tensor
  return tf.nn.relu(x)
 
  
block = ResnetIdentityBlock(1, [1, 2, 3])
print(block(tf.zeros([1, 2, 3, 3])))
print([x.name for x in block.variables])

以上这篇tensorflow之自定义神经网络层实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
分析Python编程时利用wxPython来支持多线程的方法
Apr 07 Python
Python实现快速排序算法及去重的快速排序的简单示例
Jun 26 Python
Python多进程库multiprocessing中进程池Pool类的使用详解
Nov 24 Python
python 设置文件编码格式的实现方法
Dec 21 Python
Python从文件中读取数据的方法讲解
Feb 14 Python
python获取微信企业号打卡数据并生成windows计划任务
Apr 30 Python
如何通过python画loss曲线的方法
Jun 26 Python
python变量的存储原理详解
Jul 10 Python
深入学习python多线程与GIL
Aug 26 Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 Python
Python基于yield遍历多个可迭代对象
Mar 12 Python
Python3爬虫带上cookie的实例代码
Jul 28 Python
在tensorflow中设置使用某一块GPU、多GPU、CPU的操作
Feb 07 #Python
谈一谈数组拼接tf.concat()和np.concatenate()的区别
Feb 07 #Python
python文件和文件夹复制函数
Feb 07 #Python
tf.concat中axis的含义与使用详解
Feb 07 #Python
浅谈tensorflow 中tf.concat()的使用
Feb 07 #Python
Python for循环通过序列索引迭代过程解析
Feb 07 #Python
python中with用法讲解
Feb 07 #Python
You might like
轻松修复Discuz!数据库
2008/05/03 PHP
php数组去重实例及分析
2013/11/26 PHP
浅析ThinkPHP缓存之快速缓存(F方法)和动态缓存(S方法)(日常整理)
2015/10/26 PHP
php 中的closure用法详解
2017/06/12 PHP
PHP解密支付宝小程序的加密数据、手机号的示例代码
2021/02/26 PHP
treepanel动态加载数据实现代码
2012/12/15 Javascript
js 通用订单代码
2013/12/23 Javascript
JavaScript获取网页支持表单字符集的方法
2015/04/02 Javascript
jquery实现的仿天猫侧导航tab切换效果
2015/08/24 Javascript
jQuery添加和删除指定标签的方法
2015/12/16 Javascript
vue.js组件vue-waterfall-easy实现瀑布流效果
2017/08/22 Javascript
解决jquery的ajax调取后端数据成功却渲染失败的问题
2018/08/08 jQuery
vue展示dicom文件医疗系统的实现代码
2018/08/27 Javascript
详解vue父子组件关于模态框状态的绑定方案
2019/06/05 Javascript
vue服务端渲染操作简单入门实例分析
2019/08/28 Javascript
使用Layui搭建后台管理界面的操作方法
2019/09/20 Javascript
[32:56]完美世界DOTA2联赛PWL S3 Rebirth vs CPG 第二场 12.11
2020/12/16 DOTA
浅谈Python 字符串格式化输出(format/printf)
2016/07/21 Python
Python实现运行其他程序的四种方式实例分析
2017/08/17 Python
13个最常用的Python深度学习库介绍
2017/10/28 Python
利用Python如何实现数据驱动的接口自动化测试
2018/05/11 Python
python 处理telnet返回的More,以及get想要的那个参数方法
2019/02/14 Python
Python子类继承父类构造函数详解
2019/02/19 Python
pyqt5让图片自适应QLabel大小上以及移除已显示的图片方法
2019/06/21 Python
pytorch 实现张量tensor,图片,CPU,GPU,数组等的转换
2020/01/13 Python
python tkinter实现连连看游戏
2020/11/16 Python
HTML5拖拽的简单实例
2016/05/30 HTML / CSS
HTML5 weui使用笔记
2019/11/21 HTML / CSS
柒牌官方商城:中国男装优秀品牌
2017/06/30 全球购物
瀑布模型都有哪些优缺点
2014/06/23 面试题
《和我们一样享受春天》教学反思
2014/02/07 职场文书
《祁黄羊》教学反思
2014/04/22 职场文书
2015年度合同管理工作总结
2015/05/22 职场文书
学习弘扬焦裕禄精神心得体会
2016/01/23 职场文书
初三语文教学反思
2016/03/03 职场文书
2019数学教师下学期工作总结
2019/06/27 职场文书