解决Keras的自定义lambda层去reshape张量时model保存出错问题


Posted in Python onJuly 01, 2020

前几天忙着参加一个AI Challenger比赛,一直没有更新博客,忙了将近一个月的时间,也没有取得很好的成绩,不过这这段时间内的确学到了很多,就在决赛结束的前一天晚上,准备复现使用一个新的网络UPerNet的时候出现了一个很匪夷所思,莫名其妙的一个问题。谷歌很久都没有解决,最后在一个日语网站上看到了解决方法。

事后想想,这个问题在后面搭建网络的时候会很常见,但是网上却没有人提出解决办法,So, I think that's very necessary for me to note this.

背景

分割网络在进行上采样的时候我用的是双线性插值上采样的,而Keras里面并没有实现双线性插值的函数,所以要自己调用tensorflow里面的tf.image.resize_bilinear()函数来进行resize,如果直接用tf.image.resize_bilinear()函数对Keras张量进行resize的话,会报出异常,大概意思是tenorflow张量不能转换为Keras张量,要想将Kears Tensor转换为 Tensorflow Tensor需要进行自定义层,Keras自定义层的时候需要用到Lambda层来包装。

大概源码(只是大概意思)如下:

from keras.layers import Lambda
import tensorflow as tf
 
first_layer=Input(batch_shape=(None, 64, 32, 3))
f=Conv2D(filters, 3, activation = None, padding = 'same', kernel_initializer = 'glorot_normal',name='last_conv_3')(x)
upsample_bilinear = Lambda(lambda x: tf.image.resize_bilinear(x,size=first_layer.get_shape().as_list()[1:3]))
f=upsample_bilinear(f)

然后编译 这个源码:

optimizer = SGD(lr=0.01, momentum=0.9)
model.compile(optimizer = optimizer, loss = model_dice, metrics = ['accuracy'])
model.save('model.hdf5')

其中要注意到这个tf.image.resize_bilinear()里面的size,我用的是根据张量(first_layer)的形状来做为reshape后的形状,保存模型用的是model.save().然后就会出现以下错误!

异常描述:

在一个epoch完成后保存model时出现下面错误,五个错误提示随机出现:

TypeError: cannot serialize ‘_io.TextIOWrapper' object

TypeError: object.new(PyCapsule) is not safe, use PyCapsule.new()

AttributeError: ‘NoneType' object has no attribute ‘update'

TypeError: cannot deepcopy this pattern object

TypeError: can't pickle module objects

问题分析:

这个有两方面原因:

tf.image.resize_bilinear()中的size不应该用另一个张量的size去指定。

如果用了另一个张量去指定size,用model.save()来保存model是不能序列化的。那么保存model的时候只能保存权重——model.save_weights('mode_weights.hdf5')

解决办法(两种):

1.tf.image.resize_bilinear()的size用常数去指定

upsample_bilinear = Lambda(lambda x: tf.image.resize_bilinear(x,size=[64,32]))

2.如果用了另一个张量去指定size,那么就修改保存模型的函数,变成只保存权重

model.save_weights('model_weights.hdf5')

总结:

​​​​我想使用keras的Lambda层去reshape一个张量

如果为重塑形状指定了张量,则保存模型(保存)将失败

您可以使用save_weights而不是save进行保存

补充知识:Keras 添加一个自定义的loss层(output及compile中,输出及loss的表示方法)

例如:

计算两个层之间的距离,作为一个loss

distance=keras.layers.Lambda(lambda x: tf.norm(x, axis=0))(keras.layers.Subtract(Dense1-Dense2))

这是添加的一个loss层,这个distance就直接作为loss

model=Model(input=[,,,], output=[distance])

model.compile(....., loss=lambda y_true, y_pred: ypred)

以上这篇解决Keras的自定义lambda层去reshape张量时model保存出错问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
高性能web服务器框架Tornado简单实现restful接口及开发实例
Jul 16 Python
Python实现115网盘自动下载的方法
Sep 30 Python
简单介绍Python中的RSS处理
Apr 13 Python
python利用paramiko连接远程服务器执行命令的方法
Oct 16 Python
Python之Scrapy爬虫框架安装及使用详解
Nov 16 Python
分享一个简单的python读写文件脚本
Nov 25 Python
python中找出numpy array数组的最值及其索引方法
Apr 17 Python
如何利用Boost.Python实现Python C/C++混合编程详解
Nov 08 Python
pygame游戏之旅 游戏中添加显示文字
Nov 20 Python
python把ipynb文件转换成pdf文件过程详解
Jul 09 Python
python cv2截取不规则区域图片实例
Dec 21 Python
python中with用法讲解
Feb 07 Python
完美解决keras 读取多个hdf5文件进行训练的问题
Jul 01 #Python
学python需要去培训机构吗
Jul 01 #Python
详解python logging日志传输
Jul 01 #Python
python怎么调用自己的函数
Jul 01 #Python
解决keras模型保存h5文件提示无此目录问题
Jul 01 #Python
如何解决安装python3.6.1失败
Jul 01 #Python
python如何求圆的面积
Jul 01 #Python
You might like
用PHP实现Ftp用户的在线管理的代码
2007/03/06 PHP
php的curl实现get和post的代码
2008/08/23 PHP
php获取网页中图片、DIV内容的简单方法
2014/06/19 PHP
WordPress中访客登陆实现邮件提醒的PHP脚本实例分享
2015/12/14 PHP
PHP响应post请求上传文件的方法
2015/12/17 PHP
THINKPHP截取中文字符串函数实例代码
2017/03/20 PHP
JQuery 学习笔记 选择器之五
2009/07/23 Javascript
javascript Math.random()随机数函数
2009/11/04 Javascript
javaScript 删除字符串空格多种方法小结
2012/10/24 Javascript
Jquery插件实现点击获取验证码后60秒内禁止重新获取
2015/03/13 Javascript
JS实现的3D拖拽翻页效果代码
2015/10/31 Javascript
jQuery 1.9.1源码分析系列(十)事件系统之主动触发事件和模拟冒泡处理
2015/11/24 Javascript
js实现对ajax请求面向对象的封装
2016/01/08 Javascript
jquery 正整数数字校验正则表达式
2017/01/10 Javascript
vue.js实现数据动态响应 Vue.set的简单应用
2017/06/15 Javascript
Vue.js常用指令的使用小结
2017/06/23 Javascript
JavaScript事件委托实现原理及优点进行
2020/08/29 Javascript
vue keep-alive的简单总结
2021/01/25 Vue.js
Python中请使用isinstance()判断变量类型
2014/08/25 Python
Python使用scrapy采集时伪装成HTTP/1.1的方法
2015/04/08 Python
python使用socket远程连接错误处理方法
2015/04/29 Python
Django内容增加富文本功能的实例
2017/10/17 Python
Python向MySQL批量插数据的实例讲解
2018/03/31 Python
如何实现删除numpy.array中的行或列
2018/05/08 Python
在Python中Dataframe通过print输出多行时显示省略号的实例
2018/12/22 Python
python高斯分布概率密度函数的使用详解
2019/07/10 Python
Django对数据库进行添加与更新的例子
2019/07/12 Python
Django单元测试工具test client使用详解
2019/08/02 Python
基于pandas向csv添加新的行和列
2020/05/25 Python
Spring http服务远程调用实现过程解析
2020/06/11 Python
详解CSS3选择器:nth-child和:nth-of-type之间的差异
2017/09/18 HTML / CSS
下面代码从性能上考虑,有什么问题
2015/04/03 面试题
六月份红领巾广播稿
2014/02/03 职场文书
毕业寄语大全
2014/04/09 职场文书
关于运动会广播稿50字
2014/10/18 职场文书
SpringBoot 整合mongoDB并自定义连接池的示例代码
2022/02/28 MongoDB