tensorflow之自定义神经网络层实例


Posted in Python onFebruary 07, 2020

如下所示:

import tensorflow as tf
tfe = tf.contrib.eager

tf.enable_eager_execution()

大多数情况下,在为机器学习模型编写代码时,您希望在比单个操作和单个变量操作更高的抽象级别上操作。

1.关于图层的一些有用操作

许多机器学习模型可以表达为相对简单的图层的组合和堆叠,TensorFlow提供了一组许多常用图层,以及您从头开始或作为组合创建自己的应用程序特定图层的简单方法。TensorFlow在tf.keras包中包含完整的Keras API,而Keras层在构建自己的模型时非常有用。

#在tf.keras.layers包中,图层是对象。要构造一个图层,只需构造一个对象。大多数层将输出维度/通道的数量作为第一个参数。
layer=tf.keras.layers.Dense(100)
#输入维度的数量通常是不必要的,因为它可以在第一次使用图层时推断出来,但如果您想手动指定它,则可以提供它,这在某些复杂模型中很有用。
layer=tf.keras.layers.Dense(10,input_shape=(None,5))
#调用层
layer(tf.zeros([10,5]))
 

#图层有许多有用的方法。例如,您可以通过调用layer.variables来检查图层中的所有变量。在这种情况下,完全连接的层将具有权重和偏差的变量。
variable=layer.variables
# variable[0]
layer.kernel.numpy()
layer.bias

2.自定义图层

实现自己的层的最佳方法是扩展tf.keras.Layer类并实现:

__init__,您可以在其中执行所有与输入无关的初始化

build方法,您知道输入张量的形状,并可以进行其余的初始化

call方法,在这里进行正向传播计算

请注意,您不必等到调用build来创建变量,您也可以在__init__中创建它们。但是,在build中创建它们的优点是它可以根据图层将要操作的输入的形状启用后期变量创建。另一方面,在__init__中创建变量意味着需要明确指定创建变量所需的形状。

class MyDenseLayer(tf.keras.layers.Layer):
 def __init__(self, num_outputs):
  super(MyDenseLayer, self).__init__()
  self.num_outputs = num_outputs
  
 def build(self, input_shape):
  self.kernel = self.add_variable("kernel", 
                  shape=[input_shape[-1].value, 
                      self.num_outputs])
  
 def call(self, input):
  return tf.matmul(input, self.kernel)
 
layer = MyDenseLayer(10)
print(layer(tf.zeros([10, 5])))
print(layer.variables)

3.搭建网络结构

机器学习模型中许多有趣的图层是通过组合现有层来实现的。例如,resnet中的每个residual块是卷积,批量标准化等的组合。

创建包含其他图层的类似图层的东西时使用的主类是tf.keras.Model。实现一个是通过继承自tf.keras.Model完成的。

class ResnetIdentityBlock(tf.keras.Model):
 def __init__(self, kernel_size, filters):
  super(ResnetIdentityBlock, self).__init__(name='')
  filters1, filters2, filters3 = filters
 
  self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))
  self.bn2a = tf.keras.layers.BatchNormalization()
 
  self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')
  self.bn2b = tf.keras.layers.BatchNormalization()
 
  self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))
  self.bn2c = tf.keras.layers.BatchNormalization()
 
 def call(self, input_tensor, training=False):
  x = self.conv2a(input_tensor)
  x = self.bn2a(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2b(x)
  x = self.bn2b(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2c(x)
  x = self.bn2c(x, training=training)
 
  x += input_tensor
  return tf.nn.relu(x)
 
  
block = ResnetIdentityBlock(1, [1, 2, 3])
print(block(tf.zeros([1, 2, 3, 3])))
print([x.name for x in block.variables])

以上这篇tensorflow之自定义神经网络层实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python logging模块学习笔记
May 24 Python
Python自动化测试ConfigParser模块读写配置文件
Aug 15 Python
python中学习K-Means和图片压缩
Nov 20 Python
python plotly绘制直方图实例详解
Jul 22 Python
安装Pycharm2019以及配置anconda教程的方法步骤
Nov 11 Python
使用Pandas的Series方法绘制图像教程
Dec 04 Python
浅谈ROC曲线的最佳阈值如何选取
Feb 28 Python
基于python实现监听Rabbitmq系统日志代码示例
Nov 28 Python
使用pandas读取表格数据并进行单行数据拼接的详细教程
Mar 03 Python
Python import模块的缓存问题解决方案
Jun 02 Python
Python可视化学习之seaborn调色盘
Feb 24 Python
python自动获取微信公众号最新文章的实现代码
Jul 15 Python
在tensorflow中设置使用某一块GPU、多GPU、CPU的操作
Feb 07 #Python
谈一谈数组拼接tf.concat()和np.concatenate()的区别
Feb 07 #Python
python文件和文件夹复制函数
Feb 07 #Python
tf.concat中axis的含义与使用详解
Feb 07 #Python
浅谈tensorflow 中tf.concat()的使用
Feb 07 #Python
Python for循环通过序列索引迭代过程解析
Feb 07 #Python
python中with用法讲解
Feb 07 #Python
You might like
UCenter中的一个可逆加密函数authcode函数代码
2010/07/20 PHP
set_include_path和get_include_path使用及注意事项
2013/02/02 PHP
php实现批量上传数据到数据库(.csv格式)的案例
2017/06/18 PHP
javascript 网页跳转的方法
2008/12/24 Javascript
js验证整数加保留小数点的简单实例
2013/12/02 Javascript
带左右箭头图片轮播的JS代码
2013/12/18 Javascript
jquery鼠标停止移动事件
2013/12/21 Javascript
超级简单实现JavaScript MVC 样式框架
2015/03/24 Javascript
JS判断输入字符串长度实例代码(汉字算两个字符,字母数字算一个)
2016/08/02 Javascript
微信小程序之拖拽排序(代码分享)
2017/01/21 Javascript
判断JavaScript中的两个变量是否相等的操作符
2019/12/21 Javascript
vue 封装 Adminlte3组件的实现
2020/03/18 Javascript
Vue左滑组件slider使用详解
2020/08/21 Javascript
详解webpack的clean-webpack-plugin插件报错
2020/10/16 Javascript
Vue实现一种简单的无限循环滚动动画的示例
2021/01/10 Vue.js
Django实现登录随机验证码的示例代码
2018/06/20 Python
python之拟合的实现
2019/07/19 Python
django使用haystack调用Elasticsearch实现索引搜索
2019/07/24 Python
Python Django实现layui风格+django分页功能的例子
2019/08/29 Python
Python如何使用Gitlab API实现批量的合并分支
2019/11/27 Python
Python装饰器用法与知识点小结
2020/03/09 Python
Python实现寻找回文数字过程解析
2020/06/09 Python
numpy的Fancy Indexing和array比较详解
2020/06/11 Python
Django使用django-simple-captcha做验证码的实现示例
2021/01/07 Python
巴西最大的珠宝连锁店:Vivara
2019/04/18 全球购物
J2EE面试题大全
2016/08/06 面试题
大专应届生个人的自我评价
2013/11/21 职场文书
物流专业大学生职业生涯规划书范文
2014/01/15 职场文书
家长对学生的评语
2014/04/18 职场文书
给老婆的保证书范文
2014/04/28 职场文书
2014年预备党员端正入党动机思想汇报
2014/09/13 职场文书
社团个人总结范文
2015/03/05 职场文书
教师自荐信范文
2015/03/06 职场文书
Python爬虫之爬取二手房信息
2021/04/27 Python
磁贴还没死, 微软Win11可修改注册表找回Win10开始菜单
2021/11/21 数码科技
python井字棋游戏实现人机对战
2022/04/28 Python