解决Keras的自定义lambda层去reshape张量时model保存出错问题


Posted in Python onJuly 01, 2020

前几天忙着参加一个AI Challenger比赛,一直没有更新博客,忙了将近一个月的时间,也没有取得很好的成绩,不过这这段时间内的确学到了很多,就在决赛结束的前一天晚上,准备复现使用一个新的网络UPerNet的时候出现了一个很匪夷所思,莫名其妙的一个问题。谷歌很久都没有解决,最后在一个日语网站上看到了解决方法。

事后想想,这个问题在后面搭建网络的时候会很常见,但是网上却没有人提出解决办法,So, I think that's very necessary for me to note this.

背景

分割网络在进行上采样的时候我用的是双线性插值上采样的,而Keras里面并没有实现双线性插值的函数,所以要自己调用tensorflow里面的tf.image.resize_bilinear()函数来进行resize,如果直接用tf.image.resize_bilinear()函数对Keras张量进行resize的话,会报出异常,大概意思是tenorflow张量不能转换为Keras张量,要想将Kears Tensor转换为 Tensorflow Tensor需要进行自定义层,Keras自定义层的时候需要用到Lambda层来包装。

大概源码(只是大概意思)如下:

from keras.layers import Lambda
import tensorflow as tf
 
first_layer=Input(batch_shape=(None, 64, 32, 3))
f=Conv2D(filters, 3, activation = None, padding = 'same', kernel_initializer = 'glorot_normal',name='last_conv_3')(x)
upsample_bilinear = Lambda(lambda x: tf.image.resize_bilinear(x,size=first_layer.get_shape().as_list()[1:3]))
f=upsample_bilinear(f)

然后编译 这个源码:

optimizer = SGD(lr=0.01, momentum=0.9)
model.compile(optimizer = optimizer, loss = model_dice, metrics = ['accuracy'])
model.save('model.hdf5')

其中要注意到这个tf.image.resize_bilinear()里面的size,我用的是根据张量(first_layer)的形状来做为reshape后的形状,保存模型用的是model.save().然后就会出现以下错误!

异常描述:

在一个epoch完成后保存model时出现下面错误,五个错误提示随机出现:

TypeError: cannot serialize ‘_io.TextIOWrapper' object

TypeError: object.new(PyCapsule) is not safe, use PyCapsule.new()

AttributeError: ‘NoneType' object has no attribute ‘update'

TypeError: cannot deepcopy this pattern object

TypeError: can't pickle module objects

问题分析:

这个有两方面原因:

tf.image.resize_bilinear()中的size不应该用另一个张量的size去指定。

如果用了另一个张量去指定size,用model.save()来保存model是不能序列化的。那么保存model的时候只能保存权重——model.save_weights('mode_weights.hdf5')

解决办法(两种):

1.tf.image.resize_bilinear()的size用常数去指定

upsample_bilinear = Lambda(lambda x: tf.image.resize_bilinear(x,size=[64,32]))

2.如果用了另一个张量去指定size,那么就修改保存模型的函数,变成只保存权重

model.save_weights('model_weights.hdf5')

总结:

​​​​我想使用keras的Lambda层去reshape一个张量

如果为重塑形状指定了张量,则保存模型(保存)将失败

您可以使用save_weights而不是save进行保存

补充知识:Keras 添加一个自定义的loss层(output及compile中,输出及loss的表示方法)

例如:

计算两个层之间的距离,作为一个loss

distance=keras.layers.Lambda(lambda x: tf.norm(x, axis=0))(keras.layers.Subtract(Dense1-Dense2))

这是添加的一个loss层,这个distance就直接作为loss

model=Model(input=[,,,], output=[distance])

model.compile(....., loss=lambda y_true, y_pred: ypred)

以上这篇解决Keras的自定义lambda层去reshape张量时model保存出错问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的is和id用法分析
Jan 26 Python
在Django的URLconf中进行函数导入的方法
Jul 18 Python
Python结合ImageMagick实现多张图片合并为一个pdf文件的方法
Apr 24 Python
Python Flask前后端Ajax交互的方法示例
Jul 31 Python
使用python opencv对目录下图片进行去重的方法
Jan 12 Python
python高斯分布概率密度函数的使用详解
Jul 10 Python
关于pytorch中网络loss传播和参数更新的理解
Aug 20 Python
pytorch 求网络模型参数实例
Dec 30 Python
Python实现点云投影到平面显示
Jan 18 Python
python实现滑雪游戏
Feb 22 Python
Django模型中字段属性choice使用说明
Mar 30 Python
python语言中pandas字符串分割str.split()函数
Aug 05 Python
完美解决keras 读取多个hdf5文件进行训练的问题
Jul 01 #Python
学python需要去培训机构吗
Jul 01 #Python
详解python logging日志传输
Jul 01 #Python
python怎么调用自己的函数
Jul 01 #Python
解决keras模型保存h5文件提示无此目录问题
Jul 01 #Python
如何解决安装python3.6.1失败
Jul 01 #Python
python如何求圆的面积
Jul 01 #Python
You might like
如何将数据从文本导入到mysql
2006/10/09 PHP
PHP面向对象的使用教程 简单数据库连接
2006/11/25 PHP
PHP获取url的函数代码
2011/08/02 PHP
php的array_multisort()使用方法介绍
2012/05/16 PHP
php实现的xml操作类
2016/01/15 PHP
php实现留言板功能(会话控制)
2017/05/23 PHP
javascript使用eval或者new Function进行语法检查
2010/10/16 Javascript
JavaScript高级程序设计阅读笔记(十六) javascript检测浏览器和操作系统-detect.js
2012/08/14 Javascript
自定义jQuery选项卡插件实例
2013/03/27 Javascript
JavaScript 开发工具webstrom使用指南
2014/12/09 Javascript
JS拖拽组件学习使用
2016/01/19 Javascript
JavaScript File API文件上传预览
2016/02/02 Javascript
跨域资源共享 CORS 详解
2016/04/26 Javascript
jQuery grep()方法详解及实例代码
2016/10/30 Javascript
Javascript 严格模式use strict详解
2017/09/16 Javascript
js实现简单掷骰子效果
2019/10/24 Javascript
Python类的专用方法实例分析
2015/01/09 Python
基于python实现名片管理系统
2018/11/30 Python
解决python字典对值(值为列表)赋值出现重复的问题
2019/01/20 Python
用django-allauth实现第三方登录的示例代码
2019/06/24 Python
linux环境下Django的安装配置详解
2019/07/22 Python
python logging通过json文件配置的步骤
2020/04/27 Python
python解释器安装教程的方法步骤
2020/07/02 Python
浅析PyCharm 的初始设置(知道)
2020/10/12 Python
python中如何使用虚拟环境
2020/10/14 Python
3D动画《斗罗大陆》上线当日播放过亿
2021/03/16 国漫
HTML5 通过Vedio标签实现视频循环播放的示例代码
2020/08/05 HTML / CSS
传媒专业推荐信范文
2013/11/23 职场文书
《盲人摸象》教学反思
2014/02/16 职场文书
材料工程专业毕业生求职信
2014/03/04 职场文书
申论倡议书范文
2014/05/13 职场文书
承诺书模板
2014/08/30 职场文书
党员干部四风问题整改措施思想汇报
2014/10/12 职场文书
2015年社区创卫工作总结
2015/04/21 职场文书
python3实现Dijkstra算法最短路径的实现
2021/05/12 Python
MySQL数据库查询之多表查询总结
2022/08/05 MySQL