解决Keras的自定义lambda层去reshape张量时model保存出错问题


Posted in Python onJuly 01, 2020

前几天忙着参加一个AI Challenger比赛,一直没有更新博客,忙了将近一个月的时间,也没有取得很好的成绩,不过这这段时间内的确学到了很多,就在决赛结束的前一天晚上,准备复现使用一个新的网络UPerNet的时候出现了一个很匪夷所思,莫名其妙的一个问题。谷歌很久都没有解决,最后在一个日语网站上看到了解决方法。

事后想想,这个问题在后面搭建网络的时候会很常见,但是网上却没有人提出解决办法,So, I think that's very necessary for me to note this.

背景

分割网络在进行上采样的时候我用的是双线性插值上采样的,而Keras里面并没有实现双线性插值的函数,所以要自己调用tensorflow里面的tf.image.resize_bilinear()函数来进行resize,如果直接用tf.image.resize_bilinear()函数对Keras张量进行resize的话,会报出异常,大概意思是tenorflow张量不能转换为Keras张量,要想将Kears Tensor转换为 Tensorflow Tensor需要进行自定义层,Keras自定义层的时候需要用到Lambda层来包装。

大概源码(只是大概意思)如下:

from keras.layers import Lambda
import tensorflow as tf
 
first_layer=Input(batch_shape=(None, 64, 32, 3))
f=Conv2D(filters, 3, activation = None, padding = 'same', kernel_initializer = 'glorot_normal',name='last_conv_3')(x)
upsample_bilinear = Lambda(lambda x: tf.image.resize_bilinear(x,size=first_layer.get_shape().as_list()[1:3]))
f=upsample_bilinear(f)

然后编译 这个源码:

optimizer = SGD(lr=0.01, momentum=0.9)
model.compile(optimizer = optimizer, loss = model_dice, metrics = ['accuracy'])
model.save('model.hdf5')

其中要注意到这个tf.image.resize_bilinear()里面的size,我用的是根据张量(first_layer)的形状来做为reshape后的形状,保存模型用的是model.save().然后就会出现以下错误!

异常描述:

在一个epoch完成后保存model时出现下面错误,五个错误提示随机出现:

TypeError: cannot serialize ‘_io.TextIOWrapper' object

TypeError: object.new(PyCapsule) is not safe, use PyCapsule.new()

AttributeError: ‘NoneType' object has no attribute ‘update'

TypeError: cannot deepcopy this pattern object

TypeError: can't pickle module objects

问题分析:

这个有两方面原因:

tf.image.resize_bilinear()中的size不应该用另一个张量的size去指定。

如果用了另一个张量去指定size,用model.save()来保存model是不能序列化的。那么保存model的时候只能保存权重——model.save_weights('mode_weights.hdf5')

解决办法(两种):

1.tf.image.resize_bilinear()的size用常数去指定

upsample_bilinear = Lambda(lambda x: tf.image.resize_bilinear(x,size=[64,32]))

2.如果用了另一个张量去指定size,那么就修改保存模型的函数,变成只保存权重

model.save_weights('model_weights.hdf5')

总结:

​​​​我想使用keras的Lambda层去reshape一个张量

如果为重塑形状指定了张量,则保存模型(保存)将失败

您可以使用save_weights而不是save进行保存

补充知识:Keras 添加一个自定义的loss层(output及compile中,输出及loss的表示方法)

例如:

计算两个层之间的距离,作为一个loss

distance=keras.layers.Lambda(lambda x: tf.norm(x, axis=0))(keras.layers.Subtract(Dense1-Dense2))

这是添加的一个loss层,这个distance就直接作为loss

model=Model(input=[,,,], output=[distance])

model.compile(....., loss=lambda y_true, y_pred: ypred)

以上这篇解决Keras的自定义lambda层去reshape张量时model保存出错问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python计算书页码的统计数字问题实例
Sep 26 Python
对于Python的Django框架使用的一些实用建议
Apr 03 Python
探究Python中isalnum()方法的使用
May 18 Python
Python中使用bidict模块双向字典结构的奇技淫巧
Jul 12 Python
ansible作为python模块库使用的方法实例
Jan 17 Python
Python使用matplotlib填充图形指定区域代码示例
Jan 16 Python
用tensorflow搭建CNN的方法
Mar 05 Python
ZABBIX3.2使用python脚本实现监控报表的方法
Jul 02 Python
Django基于Models定制Admin后台实现过程解析
Nov 11 Python
只需要100行Python代码就可以实现的贪吃蛇小游戏
May 27 Python
Python Django ORM连表正反操作技巧
Jun 13 Python
Python字典的基础操作
Nov 01 Python
完美解决keras 读取多个hdf5文件进行训练的问题
Jul 01 #Python
学python需要去培训机构吗
Jul 01 #Python
详解python logging日志传输
Jul 01 #Python
python怎么调用自己的函数
Jul 01 #Python
解决keras模型保存h5文件提示无此目录问题
Jul 01 #Python
如何解决安装python3.6.1失败
Jul 01 #Python
python如何求圆的面积
Jul 01 #Python
You might like
php jquery 实现新闻标签分类与无刷新分页
2009/12/18 PHP
五款常用mysql slow log分析工具的比较分析
2011/05/22 PHP
详谈PHP面向对象中常用的关键字和魔术方法
2017/02/04 PHP
php基于自定义函数记录log日志方法
2017/07/21 PHP
PHP实现QQ、微信和支付宝三合一收款码实例代码
2018/02/19 PHP
关于JavaScript的gzip静态压缩方法
2007/01/05 Javascript
js动态创建表格,删除行列的小例子
2013/07/20 Javascript
jquery indexOf使用方法
2013/08/19 Javascript
jquery实现div阴影效果示例代码
2013/09/16 Javascript
js实现文字垂直滚动和鼠标悬停效果
2015/12/31 Javascript
使用three.js 画渐变的直线
2016/06/05 Javascript
BootStrap 可编辑表Table格
2016/11/24 Javascript
深入学习jQuery中的data()
2016/12/22 Javascript
基于jQuery实现火焰灯效果导航菜单
2017/01/04 Javascript
Vue-Access-Control 前端用户权限控制解决方案
2017/12/01 Javascript
nodejs更改项目端口号的方法
2018/05/13 NodeJs
AngularJS自定义过滤器用法经典实例总结
2018/05/17 Javascript
vxe-table vue table 表格组件功能
2019/05/26 Javascript
微信小程序全局变量GLOBALDATA的定义和调用过程解析
2019/09/23 Javascript
javascript实现简单打字游戏
2019/10/29 Javascript
深度解读vue-resize的具体用法
2020/07/08 Javascript
[02:27]2018DOTA2亚洲邀请赛趣味视频之钓鱼大赛 谁是垂钓冠军?
2018/04/05 DOTA
python根据文件大小打log日志
2014/10/09 Python
python 禁止函数修改列表的实现方法
2017/08/03 Python
Python实现读取Properties配置文件的方法
2018/03/29 Python
python实现关键词提取的示例讲解
2018/04/28 Python
Linux下python与C++使用dlib实现人脸检测
2018/06/29 Python
python MNIST手写识别数据调用API的方法
2018/08/08 Python
对python打乱数据集中X,y标签对的方法详解
2018/12/14 Python
What is view? why do we have view?
2012/06/22 面试题
护士求职自荐信范文
2014/03/19 职场文书
党委班子纠正“四风”问题整改措施
2014/10/28 职场文书
《清澈的湖水》教学反思
2016/02/17 职场文书
MySQL 如何限制一张表的记录数
2021/09/14 MySQL
php双向队列实例讲解
2021/11/17 PHP
戴尔Win11系统no bootable devices found解决教程
2022/09/23 数码科技