tensorflow之自定义神经网络层实例


Posted in Python onFebruary 07, 2020

如下所示:

import tensorflow as tf
tfe = tf.contrib.eager

tf.enable_eager_execution()

大多数情况下,在为机器学习模型编写代码时,您希望在比单个操作和单个变量操作更高的抽象级别上操作。

1.关于图层的一些有用操作

许多机器学习模型可以表达为相对简单的图层的组合和堆叠,TensorFlow提供了一组许多常用图层,以及您从头开始或作为组合创建自己的应用程序特定图层的简单方法。TensorFlow在tf.keras包中包含完整的Keras API,而Keras层在构建自己的模型时非常有用。

#在tf.keras.layers包中,图层是对象。要构造一个图层,只需构造一个对象。大多数层将输出维度/通道的数量作为第一个参数。
layer=tf.keras.layers.Dense(100)
#输入维度的数量通常是不必要的,因为它可以在第一次使用图层时推断出来,但如果您想手动指定它,则可以提供它,这在某些复杂模型中很有用。
layer=tf.keras.layers.Dense(10,input_shape=(None,5))
#调用层
layer(tf.zeros([10,5]))
 

#图层有许多有用的方法。例如,您可以通过调用layer.variables来检查图层中的所有变量。在这种情况下,完全连接的层将具有权重和偏差的变量。
variable=layer.variables
# variable[0]
layer.kernel.numpy()
layer.bias

2.自定义图层

实现自己的层的最佳方法是扩展tf.keras.Layer类并实现:

__init__,您可以在其中执行所有与输入无关的初始化

build方法,您知道输入张量的形状,并可以进行其余的初始化

call方法,在这里进行正向传播计算

请注意,您不必等到调用build来创建变量,您也可以在__init__中创建它们。但是,在build中创建它们的优点是它可以根据图层将要操作的输入的形状启用后期变量创建。另一方面,在__init__中创建变量意味着需要明确指定创建变量所需的形状。

class MyDenseLayer(tf.keras.layers.Layer):
 def __init__(self, num_outputs):
  super(MyDenseLayer, self).__init__()
  self.num_outputs = num_outputs
  
 def build(self, input_shape):
  self.kernel = self.add_variable("kernel", 
                  shape=[input_shape[-1].value, 
                      self.num_outputs])
  
 def call(self, input):
  return tf.matmul(input, self.kernel)
 
layer = MyDenseLayer(10)
print(layer(tf.zeros([10, 5])))
print(layer.variables)

3.搭建网络结构

机器学习模型中许多有趣的图层是通过组合现有层来实现的。例如,resnet中的每个residual块是卷积,批量标准化等的组合。

创建包含其他图层的类似图层的东西时使用的主类是tf.keras.Model。实现一个是通过继承自tf.keras.Model完成的。

class ResnetIdentityBlock(tf.keras.Model):
 def __init__(self, kernel_size, filters):
  super(ResnetIdentityBlock, self).__init__(name='')
  filters1, filters2, filters3 = filters
 
  self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))
  self.bn2a = tf.keras.layers.BatchNormalization()
 
  self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')
  self.bn2b = tf.keras.layers.BatchNormalization()
 
  self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))
  self.bn2c = tf.keras.layers.BatchNormalization()
 
 def call(self, input_tensor, training=False):
  x = self.conv2a(input_tensor)
  x = self.bn2a(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2b(x)
  x = self.bn2b(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2c(x)
  x = self.bn2c(x, training=training)
 
  x += input_tensor
  return tf.nn.relu(x)
 
  
block = ResnetIdentityBlock(1, [1, 2, 3])
print(block(tf.zeros([1, 2, 3, 3])))
print([x.name for x in block.variables])

以上这篇tensorflow之自定义神经网络层实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python开发编码规范
Sep 08 Python
Linux下用Python脚本监控目录变化代码分享
May 21 Python
python读写json文件的简单实现
Apr 11 Python
python Socket之客户端和服务端握手详解
Sep 18 Python
python学习之hook钩子的原理和使用
Oct 25 Python
Python实现DDos攻击实例详解
Feb 02 Python
python使用Plotly绘图工具绘制气泡图
Apr 01 Python
python爬取基于m3u8协议的ts文件并合并
Apr 26 Python
Python函数中参数是传递值还是引用详解
Jul 02 Python
Mac 使用python3的matplot画图不显示的解决
Nov 23 Python
Python读写操作csv和excle文件代码实例
Mar 16 Python
python求numpy中array按列非零元素的平均值案例
Jun 08 Python
在tensorflow中设置使用某一块GPU、多GPU、CPU的操作
Feb 07 #Python
谈一谈数组拼接tf.concat()和np.concatenate()的区别
Feb 07 #Python
python文件和文件夹复制函数
Feb 07 #Python
tf.concat中axis的含义与使用详解
Feb 07 #Python
浅谈tensorflow 中tf.concat()的使用
Feb 07 #Python
Python for循环通过序列索引迭代过程解析
Feb 07 #Python
python中with用法讲解
Feb 07 #Python
You might like
phpexcel导出excel的颜色和网页中的颜色显示不一致
2012/12/11 PHP
ThinkPHP通过AJAX返回JSON的两种实现方法
2014/12/18 PHP
PHP获取文件行数的方法
2015/06/10 PHP
php实现微信企业号支付个人的方法详解
2017/07/26 PHP
php基于Redis消息队列实现的消息推送的方法
2018/11/28 PHP
thinkphp框架表单数组实现图片批量上传功能示例
2020/04/04 PHP
jQuery asp.net 用json格式返回自定义对象
2010/04/07 Javascript
实例讲解避免javascript冲突的方法
2016/01/03 Javascript
js和jquery实现监听键盘事件示例代码
2020/06/24 Javascript
JavaScript编写一个简易购物车功能
2016/09/17 Javascript
jQuery插件FusionCharts实现的2D饼状图效果【附demo源码下载】
2017/03/03 Javascript
js实现股票实时刷新数据案例
2017/05/14 Javascript
JS获取填报扩展单元格控件的值的解决办法
2017/07/14 Javascript
JS中精巧的自动柯里化实现方法
2017/12/12 Javascript
vue源码解析之事件机制原理
2018/04/21 Javascript
微信小程序实现简单评论功能
2018/11/28 Javascript
JS跨域请求的问题解析
2018/12/03 Javascript
Vue form表单动态添加组件实战案例
2019/09/02 Javascript
vue实现网络图片瀑布流 + 下拉刷新 + 上拉加载更多(步骤详解)
2020/01/14 Javascript
[41:17]完美世界DOTA2联赛PWL S3 access vs CPG 第二场 12.13
2020/12/17 DOTA
一篇文章入门Python生态系统(Python新手入门指导)
2015/12/11 Python
Python3用tkinter和PIL实现看图工具
2018/06/21 Python
win10安装tensorflow-gpu1.8.0详细完整步骤
2020/01/20 Python
Python对Tornado请求与响应的数据处理
2020/02/12 Python
Python如何把多个PDF文件合并代码实例
2020/02/13 Python
Python requests模块基础使用方法实例及高级应用(自动登陆,抓取网页源码)实例详解
2020/02/14 Python
2分钟教你实现环形/扇形菜单(基础版)
2020/01/15 HTML / CSS
ColourPop美国官网:卡拉泡泡,洛杉矶彩妆品牌
2019/04/28 全球购物
SQL面试题
2013/04/30 面试题
校园奶茶店创业计划书
2014/01/23 职场文书
查摆问题自查报告范文
2014/10/13 职场文书
2014年纪检监察工作总结
2014/11/11 职场文书
幼儿园大班毕业评语
2014/12/31 职场文书
运动员代表致辞
2015/07/29 职场文书
MySQL优化之慢日志查询
2022/06/10 MySQL
Java多线程并发FutureTask使用详解
2022/06/28 Java/Android