解决Keras的自定义lambda层去reshape张量时model保存出错问题


Posted in Python onJuly 01, 2020

前几天忙着参加一个AI Challenger比赛,一直没有更新博客,忙了将近一个月的时间,也没有取得很好的成绩,不过这这段时间内的确学到了很多,就在决赛结束的前一天晚上,准备复现使用一个新的网络UPerNet的时候出现了一个很匪夷所思,莫名其妙的一个问题。谷歌很久都没有解决,最后在一个日语网站上看到了解决方法。

事后想想,这个问题在后面搭建网络的时候会很常见,但是网上却没有人提出解决办法,So, I think that's very necessary for me to note this.

背景

分割网络在进行上采样的时候我用的是双线性插值上采样的,而Keras里面并没有实现双线性插值的函数,所以要自己调用tensorflow里面的tf.image.resize_bilinear()函数来进行resize,如果直接用tf.image.resize_bilinear()函数对Keras张量进行resize的话,会报出异常,大概意思是tenorflow张量不能转换为Keras张量,要想将Kears Tensor转换为 Tensorflow Tensor需要进行自定义层,Keras自定义层的时候需要用到Lambda层来包装。

大概源码(只是大概意思)如下:

from keras.layers import Lambda
import tensorflow as tf
 
first_layer=Input(batch_shape=(None, 64, 32, 3))
f=Conv2D(filters, 3, activation = None, padding = 'same', kernel_initializer = 'glorot_normal',name='last_conv_3')(x)
upsample_bilinear = Lambda(lambda x: tf.image.resize_bilinear(x,size=first_layer.get_shape().as_list()[1:3]))
f=upsample_bilinear(f)

然后编译 这个源码:

optimizer = SGD(lr=0.01, momentum=0.9)
model.compile(optimizer = optimizer, loss = model_dice, metrics = ['accuracy'])
model.save('model.hdf5')

其中要注意到这个tf.image.resize_bilinear()里面的size,我用的是根据张量(first_layer)的形状来做为reshape后的形状,保存模型用的是model.save().然后就会出现以下错误!

异常描述:

在一个epoch完成后保存model时出现下面错误,五个错误提示随机出现:

TypeError: cannot serialize ‘_io.TextIOWrapper' object

TypeError: object.new(PyCapsule) is not safe, use PyCapsule.new()

AttributeError: ‘NoneType' object has no attribute ‘update'

TypeError: cannot deepcopy this pattern object

TypeError: can't pickle module objects

问题分析:

这个有两方面原因:

tf.image.resize_bilinear()中的size不应该用另一个张量的size去指定。

如果用了另一个张量去指定size,用model.save()来保存model是不能序列化的。那么保存model的时候只能保存权重——model.save_weights('mode_weights.hdf5')

解决办法(两种):

1.tf.image.resize_bilinear()的size用常数去指定

upsample_bilinear = Lambda(lambda x: tf.image.resize_bilinear(x,size=[64,32]))

2.如果用了另一个张量去指定size,那么就修改保存模型的函数,变成只保存权重

model.save_weights('model_weights.hdf5')

总结:

​​​​我想使用keras的Lambda层去reshape一个张量

如果为重塑形状指定了张量,则保存模型(保存)将失败

您可以使用save_weights而不是save进行保存

补充知识:Keras 添加一个自定义的loss层(output及compile中,输出及loss的表示方法)

例如:

计算两个层之间的距离,作为一个loss

distance=keras.layers.Lambda(lambda x: tf.norm(x, axis=0))(keras.layers.Subtract(Dense1-Dense2))

这是添加的一个loss层,这个distance就直接作为loss

model=Model(input=[,,,], output=[distance])

model.compile(....., loss=lambda y_true, y_pred: ypred)

以上这篇解决Keras的自定义lambda层去reshape张量时model保存出错问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 错误和异常代码详解
Jan 29 Python
Django模板语言 Tags使用详解
Sep 09 Python
python智联招聘爬虫并导入到excel代码实例
Sep 09 Python
Python 字符串、列表、元组的截取与切片操作示例
Sep 17 Python
python网络爬虫 CrawlSpider使用详解
Sep 27 Python
使用python制作一个解压缩软件
Nov 13 Python
python 实现多线程下载m3u8格式视频并使用fmmpeg合并
Nov 15 Python
python求质数列表的例子
Nov 24 Python
在 Python 中接管键盘中断信号的实现方法
Feb 04 Python
使用sklearn的cross_val_score进行交叉验证实例
Feb 28 Python
scrapy利用selenium爬取豆瓣阅读的全步骤
Sep 20 Python
使用python tkinter开发一个爬取B站直播弹幕工具的实现代码
Feb 07 Python
完美解决keras 读取多个hdf5文件进行训练的问题
Jul 01 #Python
学python需要去培训机构吗
Jul 01 #Python
详解python logging日志传输
Jul 01 #Python
python怎么调用自己的函数
Jul 01 #Python
解决keras模型保存h5文件提示无此目录问题
Jul 01 #Python
如何解决安装python3.6.1失败
Jul 01 #Python
python如何求圆的面积
Jul 01 #Python
You might like
php防止CC攻击代码 php防止网页频繁刷新
2015/12/21 PHP
PHP表单数据写入MySQL数据库的代码
2016/05/31 PHP
PHP编程实现计算抽奖概率算法完整实例
2017/08/09 PHP
PHP抽象类基本用法示例
2018/12/28 PHP
php写入mysql中文乱码的实例解决方法
2019/09/17 PHP
很酷的javascript loading效果代码
2008/06/18 Javascript
Javascript绝句欣赏 一些经典的js代码
2012/02/22 Javascript
JavaScript中的私有/静态属性介绍
2012/07/26 Javascript
JS中生成随机数的用法及相关函数
2016/01/09 Javascript
完美的js div拖拽实例代码
2016/09/24 Javascript
Bootstrap图片轮播组件Carousel使用方法详解
2016/10/20 Javascript
VUE+elementui面包屑实现动态路由详解
2019/11/04 Javascript
JS实现小米轮播图
2020/09/21 Javascript
原生小程序封装跑马灯效果
2020/10/21 Javascript
Python基础中所出现的异常报错总结
2016/11/19 Python
Python 3实战爬虫之爬取京东图书的图片详解
2017/10/09 Python
Django数据库连接丢失问题的解决方法
2018/12/29 Python
Django REST framework内置路由用法
2019/07/26 Python
django创建最简单HTML页面跳转方法
2019/08/16 Python
pytorch 自定义数据集加载方法
2019/08/18 Python
python多线程并发及测试框架案例
2019/10/15 Python
Scrapy框架基本命令与settings.py设置
2020/02/06 Python
Python MOCK SERVER moco模拟接口测试过程解析
2020/04/13 Python
丝芙兰巴西官方商城:SEPHORA巴西
2016/10/31 全球购物
澳大利亚网上玩具商店:Mr Toys Toyworld
2018/03/25 全球购物
意大利在线眼镜精品店:Ottica Lipari
2019/11/11 全球购物
校友会欢迎辞
2014/01/13 职场文书
自考生自我评价分享
2014/01/18 职场文书
公司财务流程之主管工作流程
2014/03/03 职场文书
少先队学雷锋活动总结范文
2014/03/09 职场文书
学习三严三实对照检查材料思想汇报
2014/09/22 职场文书
预备党员2014年第四季度思想汇报范文
2014/10/25 职场文书
付款承诺函范文
2015/01/21 职场文书
幼儿园食品安全责任书
2015/05/08 职场文书
党小组推荐意见
2015/06/02 职场文书
七年级数学教学反思
2016/02/17 职场文书