tensorflow之自定义神经网络层实例


Posted in Python onFebruary 07, 2020

如下所示:

import tensorflow as tf
tfe = tf.contrib.eager

tf.enable_eager_execution()

大多数情况下,在为机器学习模型编写代码时,您希望在比单个操作和单个变量操作更高的抽象级别上操作。

1.关于图层的一些有用操作

许多机器学习模型可以表达为相对简单的图层的组合和堆叠,TensorFlow提供了一组许多常用图层,以及您从头开始或作为组合创建自己的应用程序特定图层的简单方法。TensorFlow在tf.keras包中包含完整的Keras API,而Keras层在构建自己的模型时非常有用。

#在tf.keras.layers包中,图层是对象。要构造一个图层,只需构造一个对象。大多数层将输出维度/通道的数量作为第一个参数。
layer=tf.keras.layers.Dense(100)
#输入维度的数量通常是不必要的,因为它可以在第一次使用图层时推断出来,但如果您想手动指定它,则可以提供它,这在某些复杂模型中很有用。
layer=tf.keras.layers.Dense(10,input_shape=(None,5))
#调用层
layer(tf.zeros([10,5]))
 

#图层有许多有用的方法。例如,您可以通过调用layer.variables来检查图层中的所有变量。在这种情况下,完全连接的层将具有权重和偏差的变量。
variable=layer.variables
# variable[0]
layer.kernel.numpy()
layer.bias

2.自定义图层

实现自己的层的最佳方法是扩展tf.keras.Layer类并实现:

__init__,您可以在其中执行所有与输入无关的初始化

build方法,您知道输入张量的形状,并可以进行其余的初始化

call方法,在这里进行正向传播计算

请注意,您不必等到调用build来创建变量,您也可以在__init__中创建它们。但是,在build中创建它们的优点是它可以根据图层将要操作的输入的形状启用后期变量创建。另一方面,在__init__中创建变量意味着需要明确指定创建变量所需的形状。

class MyDenseLayer(tf.keras.layers.Layer):
 def __init__(self, num_outputs):
  super(MyDenseLayer, self).__init__()
  self.num_outputs = num_outputs
  
 def build(self, input_shape):
  self.kernel = self.add_variable("kernel", 
                  shape=[input_shape[-1].value, 
                      self.num_outputs])
  
 def call(self, input):
  return tf.matmul(input, self.kernel)
 
layer = MyDenseLayer(10)
print(layer(tf.zeros([10, 5])))
print(layer.variables)

3.搭建网络结构

机器学习模型中许多有趣的图层是通过组合现有层来实现的。例如,resnet中的每个residual块是卷积,批量标准化等的组合。

创建包含其他图层的类似图层的东西时使用的主类是tf.keras.Model。实现一个是通过继承自tf.keras.Model完成的。

class ResnetIdentityBlock(tf.keras.Model):
 def __init__(self, kernel_size, filters):
  super(ResnetIdentityBlock, self).__init__(name='')
  filters1, filters2, filters3 = filters
 
  self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))
  self.bn2a = tf.keras.layers.BatchNormalization()
 
  self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')
  self.bn2b = tf.keras.layers.BatchNormalization()
 
  self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))
  self.bn2c = tf.keras.layers.BatchNormalization()
 
 def call(self, input_tensor, training=False):
  x = self.conv2a(input_tensor)
  x = self.bn2a(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2b(x)
  x = self.bn2b(x, training=training)
  x = tf.nn.relu(x)
 
  x = self.conv2c(x)
  x = self.bn2c(x, training=training)
 
  x += input_tensor
  return tf.nn.relu(x)
 
  
block = ResnetIdentityBlock(1, [1, 2, 3])
print(block(tf.zeros([1, 2, 3, 3])))
print([x.name for x in block.variables])

以上这篇tensorflow之自定义神经网络层实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
人机交互程序 python实现人机对话
Nov 14 Python
Python使用Scrapy爬虫框架全站爬取图片并保存本地的实现代码
Mar 04 Python
PyQt5每天必学之创建窗口居中效果
Apr 19 Python
python之从文件读取数据到list的实例讲解
Apr 19 Python
解决已经安装requests,却依然提示No module named requests问题
May 18 Python
Ubuntu下升级 python3.7.1流程备忘(推荐)
Dec 10 Python
python实现DEM数据的阴影生成的方法
Jul 23 Python
Python3多线程版TCP端口扫描器
Aug 31 Python
python性能测量工具cProfile使用解析
Sep 26 Python
在keras 中获取张量 tensor 的维度大小实例
Jun 10 Python
解决pip install psycopg2出错问题
Jul 09 Python
Python 3.9的到来到底是意味着什么
Oct 14 Python
在tensorflow中设置使用某一块GPU、多GPU、CPU的操作
Feb 07 #Python
谈一谈数组拼接tf.concat()和np.concatenate()的区别
Feb 07 #Python
python文件和文件夹复制函数
Feb 07 #Python
tf.concat中axis的含义与使用详解
Feb 07 #Python
浅谈tensorflow 中tf.concat()的使用
Feb 07 #Python
Python for循环通过序列索引迭代过程解析
Feb 07 #Python
python中with用法讲解
Feb 07 #Python
You might like
一个用于网络的工具函数库
2006/10/09 PHP
php实现二进制和文本相互转换的方法
2015/04/18 PHP
超详细的php用户注册页面填写信息完整实例(附源码)
2015/11/17 PHP
PHP在线打包下载功能示例
2016/10/15 PHP
PHP解压ZIP文件到指定文件夹的方法
2016/11/17 PHP
深入浅出讲解:php的socket通信原理
2016/12/03 PHP
PHP编程实现多维数组按照某个键值排序的方法小结【2种方法】
2017/04/27 PHP
Laravel利用gulp如何构建前端资源详解
2018/06/03 PHP
JQuery为textarea添加maxlength属性的代码
2010/04/07 Javascript
javascript showModalDialog 内跳转页面的问题
2010/11/25 Javascript
疯狂Jquery第一天(Jquery学习笔记)
2012/05/11 Javascript
showModelDialog弹出文件下载窗口的使用示例
2013/11/19 Javascript
js实现图片漂浮效果的方法
2015/03/02 Javascript
JS烟花背景效果实现方法
2015/03/03 Javascript
jQuery实现的左右移动焦点图效果
2016/01/14 Javascript
js实现图片轮播效果学习笔记
2017/07/26 Javascript
vue.js 嵌套循环、if判断、动态删除的实例
2018/03/07 Javascript
Vue.js实现立体计算器
2020/02/22 Javascript
Vuex的各个模块封装的实现
2020/06/05 Javascript
vue同个按钮控制展开和折叠同个事件操作
2020/07/29 Javascript
vue v-for 点击当前行,获取当前行数据及event当前事件对象的操作
2020/09/10 Javascript
Nuxt的路由配置和参数传递方式
2020/11/06 Javascript
Python多线程实例教程
2014/09/06 Python
Django静态资源URL STATIC_ROOT的配置方法
2014/11/08 Python
用Python写王者荣耀刷金币脚本
2017/12/21 Python
Python实现定时备份mysql数据库并把备份数据库邮件发送
2018/03/08 Python
python路径的写法及目录的获取方式
2019/12/26 Python
tensorflow生成多个tfrecord文件实例
2020/02/17 Python
python中openpyxl和xlsxwriter对Excel的操作方法
2021/03/01 Python
大学生咖啡店创业计划书
2014/01/21 职场文书
低碳环保标语
2014/06/12 职场文书
临时用工协议书范本
2014/10/29 职场文书
心理健康教育培训研修感言
2015/11/18 职场文书
《刷子李》教学反思
2016/02/20 职场文书
学校2016年九九重阳节活动总结
2016/04/01 职场文书
导游词之韩国济州岛
2019/10/28 职场文书