解决Keras的自定义lambda层去reshape张量时model保存出错问题


Posted in Python onJuly 01, 2020

前几天忙着参加一个AI Challenger比赛,一直没有更新博客,忙了将近一个月的时间,也没有取得很好的成绩,不过这这段时间内的确学到了很多,就在决赛结束的前一天晚上,准备复现使用一个新的网络UPerNet的时候出现了一个很匪夷所思,莫名其妙的一个问题。谷歌很久都没有解决,最后在一个日语网站上看到了解决方法。

事后想想,这个问题在后面搭建网络的时候会很常见,但是网上却没有人提出解决办法,So, I think that's very necessary for me to note this.

背景

分割网络在进行上采样的时候我用的是双线性插值上采样的,而Keras里面并没有实现双线性插值的函数,所以要自己调用tensorflow里面的tf.image.resize_bilinear()函数来进行resize,如果直接用tf.image.resize_bilinear()函数对Keras张量进行resize的话,会报出异常,大概意思是tenorflow张量不能转换为Keras张量,要想将Kears Tensor转换为 Tensorflow Tensor需要进行自定义层,Keras自定义层的时候需要用到Lambda层来包装。

大概源码(只是大概意思)如下:

from keras.layers import Lambda
import tensorflow as tf
 
first_layer=Input(batch_shape=(None, 64, 32, 3))
f=Conv2D(filters, 3, activation = None, padding = 'same', kernel_initializer = 'glorot_normal',name='last_conv_3')(x)
upsample_bilinear = Lambda(lambda x: tf.image.resize_bilinear(x,size=first_layer.get_shape().as_list()[1:3]))
f=upsample_bilinear(f)

然后编译 这个源码:

optimizer = SGD(lr=0.01, momentum=0.9)
model.compile(optimizer = optimizer, loss = model_dice, metrics = ['accuracy'])
model.save('model.hdf5')

其中要注意到这个tf.image.resize_bilinear()里面的size,我用的是根据张量(first_layer)的形状来做为reshape后的形状,保存模型用的是model.save().然后就会出现以下错误!

异常描述:

在一个epoch完成后保存model时出现下面错误,五个错误提示随机出现:

TypeError: cannot serialize ‘_io.TextIOWrapper' object

TypeError: object.new(PyCapsule) is not safe, use PyCapsule.new()

AttributeError: ‘NoneType' object has no attribute ‘update'

TypeError: cannot deepcopy this pattern object

TypeError: can't pickle module objects

问题分析:

这个有两方面原因:

tf.image.resize_bilinear()中的size不应该用另一个张量的size去指定。

如果用了另一个张量去指定size,用model.save()来保存model是不能序列化的。那么保存model的时候只能保存权重——model.save_weights('mode_weights.hdf5')

解决办法(两种):

1.tf.image.resize_bilinear()的size用常数去指定

upsample_bilinear = Lambda(lambda x: tf.image.resize_bilinear(x,size=[64,32]))

2.如果用了另一个张量去指定size,那么就修改保存模型的函数,变成只保存权重

model.save_weights('model_weights.hdf5')

总结:

​​​​我想使用keras的Lambda层去reshape一个张量

如果为重塑形状指定了张量,则保存模型(保存)将失败

您可以使用save_weights而不是save进行保存

补充知识:Keras 添加一个自定义的loss层(output及compile中,输出及loss的表示方法)

例如:

计算两个层之间的距离,作为一个loss

distance=keras.layers.Lambda(lambda x: tf.norm(x, axis=0))(keras.layers.Subtract(Dense1-Dense2))

这是添加的一个loss层,这个distance就直接作为loss

model=Model(input=[,,,], output=[distance])

model.compile(....., loss=lambda y_true, y_pred: ypred)

以上这篇解决Keras的自定义lambda层去reshape张量时model保存出错问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python程序中操作文件之flush()方法的使用教程
May 24 Python
python用模块zlib压缩与解压字符串和文件的方法
Dec 16 Python
Python叠加两幅栅格图像的实现方法
Jul 05 Python
Python Django的安装配置教程图文详解
Jul 17 Python
Python高级特性——详解多维数组切片(Slice)
Nov 26 Python
Pytorch evaluation每次运行结果不同的解决
Jan 02 Python
Python类的绑定方法和非绑定方法实例解析
Mar 04 Python
Python3之乱码\xe6\x97\xa0\xe6\xb3\x95处理方式
May 11 Python
Python类super()及私有属性原理解析
Jun 15 Python
django form和field具体方法和属性说明
Jul 09 Python
利用python批量爬取百度任意类别的图片的实现方法
Oct 07 Python
Django restful framework生成API文档过程详解
Nov 12 Python
完美解决keras 读取多个hdf5文件进行训练的问题
Jul 01 #Python
学python需要去培训机构吗
Jul 01 #Python
详解python logging日志传输
Jul 01 #Python
python怎么调用自己的函数
Jul 01 #Python
解决keras模型保存h5文件提示无此目录问题
Jul 01 #Python
如何解决安装python3.6.1失败
Jul 01 #Python
python如何求圆的面积
Jul 01 #Python
You might like
我的论坛源代码(十)
2006/10/09 PHP
php网页后退不再出现过期
2007/03/08 PHP
php中实现记住密码自动登录的代码
2011/03/02 PHP
深入理解PHP内核(一)
2015/11/10 PHP
Javascript 个人笔记(没有整理,很乱)
2007/07/07 Javascript
javascript 常用关键字列表集合
2007/12/04 Javascript
JavaScript 实现类的多种方法实例
2013/05/01 Javascript
JS实现从连接中获取youtube的key实例
2015/07/02 Javascript
JS实现iframe自适应高度的方法(兼容IE与FireFox)
2016/06/24 Javascript
javascript和jQuery中的AJAX技术详解【包含AJAX各种跨域技术】
2016/12/15 Javascript
VueJs单页应用实现微信网页授权及微信分享功能示例
2017/07/26 Javascript
JS实现图片手风琴效果
2020/04/17 Javascript
深入理解Vue.js源码之事件机制
2017/09/27 Javascript
微信小程序MUI导航栏透明渐变功能示例(通过改变rgba的a值实现)
2019/01/24 Javascript
ES2020系列之空值合并运算符 '??'
2020/07/22 Javascript
[02:45]DOTA2英雄基础教程 伐木机
2013/12/23 DOTA
Python获取Windows或Linux主机名称通用函数分享
2014/11/22 Python
Python中动态获取对象的属性和方法的教程
2015/04/09 Python
Python基于whois模块简单识别网站域名及所有者的方法
2018/04/23 Python
Tensorflow卷积神经网络实例进阶
2018/05/24 Python
pandas.DataFrame.to_json按行转json的方法
2018/06/05 Python
详解DeBug Python神级工具PySnooper
2019/07/03 Python
浅谈Python3实现两个矩形的交并比(IoU)
2020/01/18 Python
python Autopep8实现按PEP8风格自动排版Python代码
2021/03/02 Python
HTML5 Blob 实现文件下载功能的示例代码
2019/11/29 HTML / CSS
DHC美国官网:日本通信销售第一的化妆品品牌
2017/11/12 全球购物
澳大利亚最大的在线美发和美容零售商之一:My Hair Care & Beauty
2019/08/24 全球购物
Shopping happy life西班牙:以最优惠的价格提供最好的时尚配饰
2020/03/13 全球购物
使用索引有什么好处
2016/07/27 面试题
酒店服务实习自我鉴定
2013/09/22 职场文书
工程资料员岗位职责
2014/03/10 职场文书
相亲大会策划方案
2014/06/05 职场文书
2015年安全生产目标责任书
2015/01/29 职场文书
新员工试用期工作总结2015
2015/05/28 职场文书
【海涛dota解说】DCG联赛第一周 LGD VS DH
2022/04/01 DOTA
uniapp开发打包多端应用完整方法指南
2022/12/24 Javascript