Pytorch转keras的有效方法,以FlowNet为例讲解


Posted in Python onMay 26, 2020

Pytorch凭借动态图机制,获得了广泛的使用,大有超越tensorflow的趋势,不过在工程应用上,TF仍然占据优势。有的时候我们会遇到这种情况,需要把模型应用到工业中,运用到实际项目上,TF支持的PB文件和TF的C++接口就成为了有效的工具。今天就给大家讲解一下Pytorch转成Keras的方法,进而我们也可以获得Pb文件,因为Keras是支持tensorflow的,我将会在下一篇博客讲解获得Pb文件,并使用Pb文件的方法。

Pytorch To Keras

首先,我们必须有清楚的认识,网上以及github上一些所谓的pytorch转换Keras或者Keras转换成Pytorch的工具代码几乎不能运行或者有使用的局限性(比如仅仅能转换某一些模型),但是我们是可以用这些转换代码中看出一些端倪来,比如二者的参数的尺寸(shape)的形式、channel的排序(first or last)是否一样,掌握到差异性,就能根据这些差异自己编写转换代码,没错,自己编写转换代码,是最稳妥的办法。整个过程也就分为两个部分。笔者将会以Nvidia开源的FlowNet为例,将开源的Pytorch代码转化为Keras模型。

按照Pytorch中模型的结构,编写对应的Keras代码,用keras的函数式API,构建起来会非常方便。

把Pytorch的模型参数,按照层的名称依次赋值给Keras的模型

以上两步虽然看上去简单,但实际我也走了不少弯路。这里一个关键的地方,就是参数的shape在两个框架中是否统一,那当然是不统一的。下面我以FlowNet为例。

Pytorch中的FlowNet代码

我们仅仅展示层名称和层参数,就不把整个结构贴出来了,否则会占很多的空间,形成水文。

先看用Keras搭建的flowNet模型,直接用model.summary()输出模型信息

__________________________________________________________________________________________________
Layer (type)   Output Shape  Param # Connected to   
==================================================================================================
input_1 (InputLayer)  (None, 6, 512, 512) 0      
__________________________________________________________________________________________________
conv0 (Conv2D)   (None, 64, 512, 512) 3520 input_1[0][0]   
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 64, 512, 512) 0  conv0[0][0]   
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, 64, 514, 514) 0  leaky_re_lu_1[0][0]  
__________________________________________________________________________________________________
conv1 (Conv2D)   (None, 64, 256, 256) 36928 zero_padding2d_1[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 64, 256, 256) 0  conv1[0][0]   
__________________________________________________________________________________________________
conv1_1 (Conv2D)  (None, 128, 256, 256 73856 leaky_re_lu_2[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 128, 256, 256 0  conv1_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D (None, 128, 258, 258 0  leaky_re_lu_3[0][0]  
__________________________________________________________________________________________________
conv2 (Conv2D)   (None, 128, 128, 128 147584 zero_padding2d_2[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU) (None, 128, 128, 128 0  conv2[0][0]   
__________________________________________________________________________________________________
conv2_1 (Conv2D)  (None, 128, 128, 128 147584 leaky_re_lu_4[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU) (None, 128, 128, 128 0  conv2_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_3 (ZeroPadding2D (None, 128, 130, 130 0  leaky_re_lu_5[0][0]  
__________________________________________________________________________________________________
conv3 (Conv2D)   (None, 256, 64, 64) 295168 zero_padding2d_3[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU) (None, 256, 64, 64) 0  conv3[0][0]   
__________________________________________________________________________________________________
conv3_1 (Conv2D)  (None, 256, 64, 64) 590080 leaky_re_lu_6[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU) (None, 256, 64, 64) 0  conv3_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_4 (ZeroPadding2D (None, 256, 66, 66) 0  leaky_re_lu_7[0][0]  
__________________________________________________________________________________________________
conv4 (Conv2D)   (None, 512, 32, 32) 1180160 zero_padding2d_4[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU) (None, 512, 32, 32) 0  conv4[0][0]   
__________________________________________________________________________________________________
conv4_1 (Conv2D)  (None, 512, 32, 32) 2359808 leaky_re_lu_8[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU) (None, 512, 32, 32) 0  conv4_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_5 (ZeroPadding2D (None, 512, 34, 34) 0  leaky_re_lu_9[0][0]  
__________________________________________________________________________________________________
conv5 (Conv2D)   (None, 512, 16, 16) 2359808 zero_padding2d_5[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_10 (LeakyReLU) (None, 512, 16, 16) 0  conv5[0][0]   
__________________________________________________________________________________________________
conv5_1 (Conv2D)  (None, 512, 16, 16) 2359808 leaky_re_lu_10[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_11 (LeakyReLU) (None, 512, 16, 16) 0  conv5_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_6 (ZeroPadding2D (None, 512, 18, 18) 0  leaky_re_lu_11[0][0]  
__________________________________________________________________________________________________
conv6 (Conv2D)   (None, 1024, 8, 8) 4719616 zero_padding2d_6[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_12 (LeakyReLU) (None, 1024, 8, 8) 0  conv6[0][0]   
__________________________________________________________________________________________________
conv6_1 (Conv2D)  (None, 1024, 8, 8) 9438208 leaky_re_lu_12[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_13 (LeakyReLU) (None, 1024, 8, 8) 0  conv6_1[0][0]   
__________________________________________________________________________________________________
deconv5 (Conv2DTranspose) (None, 512, 16, 16) 8389120 leaky_re_lu_13[0][0]  
__________________________________________________________________________________________________
predict_flow6 (Conv2D)  (None, 2, 8, 8) 18434 leaky_re_lu_13[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_14 (LeakyReLU) (None, 512, 16, 16) 0  deconv5[0][0]   
__________________________________________________________________________________________________
upsampled_flow6_to_5 (Conv2DTra (None, 2, 16, 16) 66  predict_flow6[0][0]  
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 1026, 16, 16) 0  leaky_re_lu_11[0][0]  
         leaky_re_lu_14[0][0]  
         upsampled_flow6_to_5[0][0] 
__________________________________________________________________________________________________
inter_conv5 (Conv2D)  (None, 512, 16, 16) 4728320 concatenate_1[0][0]  
__________________________________________________________________________________________________
deconv4 (Conv2DTranspose) (None, 256, 32, 32) 4202752 concatenate_1[0][0]  
__________________________________________________________________________________________________
predict_flow5 (Conv2D)  (None, 2, 16, 16) 9218 inter_conv5[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_15 (LeakyReLU) (None, 256, 32, 32) 0  deconv4[0][0]   
__________________________________________________________________________________________________
upsampled_flow5_to4 (Conv2DTran (None, 2, 32, 32) 66  predict_flow5[0][0]  
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 770, 32, 32) 0  leaky_re_lu_9[0][0]  
         leaky_re_lu_15[0][0]  
         upsampled_flow5_to4[0][0] 
__________________________________________________________________________________________________
inter_conv4 (Conv2D)  (None, 256, 32, 32) 1774336 concatenate_2[0][0]  
__________________________________________________________________________________________________
deconv3 (Conv2DTranspose) (None, 128, 64, 64) 1577088 concatenate_2[0][0]  
__________________________________________________________________________________________________
predict_flow4 (Conv2D)  (None, 2, 32, 32) 4610 inter_conv4[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_16 (LeakyReLU) (None, 128, 64, 64) 0  deconv3[0][0]   
__________________________________________________________________________________________________
upsampled_flow4_to3 (Conv2DTran (None, 2, 64, 64) 66  predict_flow4[0][0]  
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 386, 64, 64) 0  leaky_re_lu_7[0][0]  
         leaky_re_lu_16[0][0]  
         upsampled_flow4_to3[0][0] 
__________________________________________________________________________________________________
inter_conv3 (Conv2D)  (None, 128, 64, 64) 444800 concatenate_3[0][0]  
__________________________________________________________________________________________________
deconv2 (Conv2DTranspose) (None, 64, 128, 128) 395328 concatenate_3[0][0]  
__________________________________________________________________________________________________
predict_flow3 (Conv2D)  (None, 2, 64, 64) 2306 inter_conv3[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_17 (LeakyReLU) (None, 64, 128, 128) 0  deconv2[0][0]   
__________________________________________________________________________________________________
upsampled_flow3_to2 (Conv2DTran (None, 2, 128, 128) 66  predict_flow3[0][0]  
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 194, 128, 128 0  leaky_re_lu_5[0][0]  
         leaky_re_lu_17[0][0]  
         upsampled_flow3_to2[0][0] 
__________________________________________________________________________________________________
inter_conv2 (Conv2D)  (None, 64, 128, 128) 111808 concatenate_4[0][0]  
__________________________________________________________________________________________________
predict_flow2 (Conv2D)  (None, 2, 128, 128) 1154 inter_conv2[0][0]  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 2, 512, 512) 0  predict_flow2[0][0]

再看看Pytorch搭建的flownet模型

(conv0): Sequential(
 (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv1): Sequential(
 (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv1_1): Sequential(
 (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv2): Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv2_1): Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv3): Sequential(
 (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv3_1): Sequential(
 (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv4): Sequential(
 (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv4_1): Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv5): Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv5_1): Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv6): Sequential(
 (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv6_1): Sequential(
 (0): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv5): Sequential(
 (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv4): Sequential(
 (0): ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv3): Sequential(
 (0): ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv2): Sequential(
 (0): ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (inter_conv5): Sequential(
 (0): Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (inter_conv4): Sequential(
 (0): Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (inter_conv3): Sequential(
 (0): Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (inter_conv2): Sequential(
 (0): Conv2d(194, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (predict_flow6): Conv2d(1024, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow5): Conv2d(512, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow4): Conv2d(256, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow3): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow2): Conv2d(64, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (upsampled_flow6_to_5): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsampled_flow5_to_4): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsampled_flow4_to_3): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsampled_flow3_to_2): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsample1): Upsample(scale_factor=4.0, mode=bilinear)
)
conv0 Sequential(
 (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv0.0 Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv0.1 LeakyReLU(negative_slope=0.1, inplace)
conv1 Sequential(
 (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv1.0 Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv1.1 LeakyReLU(negative_slope=0.1, inplace)
conv1_1 Sequential(
 (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv1_1.0 Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv1_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv2 Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv2.0 Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv2.1 LeakyReLU(negative_slope=0.1, inplace)
conv2_1 Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv2_1.0 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv2_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv3 Sequential(
 (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv3.0 Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv3.1 LeakyReLU(negative_slope=0.1, inplace)
conv3_1 Sequential(
 (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv3_1.0 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv3_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv4 Sequential(
 (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv4.0 Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv4.1 LeakyReLU(negative_slope=0.1, inplace)
conv4_1 Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv4_1.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv4_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv5 Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv5.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv5.1 LeakyReLU(negative_slope=0.1, inplace)
conv5_1 Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv5_1.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv5_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv6 Sequential(
 (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv6.0 Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv6.1 LeakyReLU(negative_slope=0.1, inplace)
conv6_1 Sequential(
 (0): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv6_1.0 Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv6_1.1 LeakyReLU(negative_slope=0.1, inplace)
deconv5 Sequential(
 (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv5.0 ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv5.1 LeakyReLU(negative_slope=0.1, inplace)
deconv4 Sequential(
 (0): ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv4.0 ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv4.1 LeakyReLU(negative_slope=0.1, inplace)
deconv3 Sequential(
 (0): ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv3.0 ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv3.1 LeakyReLU(negative_slope=0.1, inplace)
deconv2 Sequential(
 (0): ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv2.0 ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv2.1 LeakyReLU(negative_slope=0.1, inplace)
inter_conv5 Sequential(
 (0): Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv5.0 Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv4 Sequential(
 (0): Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv4.0 Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv3 Sequential(
 (0): Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv3.0 Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv2 Sequential(
 (0): Conv2d(194, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

因为Pytorch模型用name_modules()输出不是按顺序的,动态图机制决定了只有在有数据流动之后才知道走过的路径。所以上面的顺序也是乱的。但我想表明的是,我用Keras搭建的模型确实是根据官方开源的Pytorch模型搭建的。

模型搭建完毕之后,就到了关键的步骤:给Keras模型赋值。

给Keras模型赋值

这个步骤其实注意三个点

Pytorch是channels_first的,Keras默认是channels_last,在代码开头加上这两句:

K.set_image_data_format(‘channels_first')
K.set_learning_phase(0)

众所周知,卷积层的权重是一个4维张量,那么,在Pytorch和keras中,卷积核的权重的形式是否一致的,那自然是不一致的,要不然我为啥还要写这一点。那么就涉及到Pytorch权重的变形。

既然卷积层权重形式在两个框架是不一致的,转置卷积自然也是不一致的。

我们先看看卷积层在两个框架中的形式

keras的卷积层权重形式

我们用以下代码看keras卷积层权重形式

for l in model.layers:
  print(l.name)
  for i, w in enumerate(l.get_weights()):
   print('%d'%i , w.shape)

第一个卷积层输出如下 0之后是卷积权重的shape,1之后的是偏置项

conv0
0 (3, 3, 6, 64)
1 (64,)

所以Keras的卷积层权重形式是[ height, width, input_channels, out_channels]

Pytorch的卷积层权重形式

net = FlowNet2SD()
 for n, m in net.named_parameters():
  print(n)
  print(m.data.size())

conv0.0.weight
torch.Size([64, 6, 3, 3])
conv0.0.bias
torch.Size([64])

用上面的代码得到所有层的参数的shape,同样找到第一个卷积层的参数,查看shape。

通过对比我们可以发现,Pytorch的卷积层shape是[ out_channels, input_channels, height, width]的形式。

那么我们在取出Pytorch权重之后,需要用np.transpose改变一下权重的排序,才能送到Keras模型对应的层上。

Keras中转置卷积权重形式

deconv4
0 (4, 4, 256, 1026)
1 (256,)

代码仍然和上面一样,找到转置卷积的对应的位置,查看一下

可以看出在Keras中,转置卷积形式是 [ height, width, out_channels, input_channels]

Pytorch中转置卷积权重形式

deconv4.0.weight
torch.Size([1026, 256, 4, 4])
deconv4.0.bias
torch.Size([256])

代码仍然和上面一样,找到转置卷积的对应的位置,查看一下

可以看出在Pytorch中,转置卷积形式是 [ input_channels,out_channels,height, width]

小结

对于卷积层来说,Pytorch的权重需要使用

np.transpose(weight.data.numpy(), [2, 3, 1, 0])

才能赋值给keras模型对应的层的权重。

对于转置卷积来说,通过对比其实也是一样的。不信你去试试嘛。O(∩_∩)O哈哈~

对于偏置项,两种模块都是一维的向量,不需要处理。

有的情况还可能需要通道颠倒一下,但是很少需要这样做。

weights[::-1,::-1,:,:]

赋值

结束了预处理之后,我们就进入第二步,开始赋值了。

先看预处理的代码:

for k,v in weights_from_torch.items():
 if 'bias' not in k:
  weights_from_torch[k] = v.data.numpy().transpose(2, 3, 1, 0)

赋值代码我只截了一部分供大家参考:

k_model = k_model()
for layer in k_model.layers:
 current_layer_name = layer.name
 if current_layer_name=='conv0':
  weights = [weights_from_torch['conv0.0.weight'],weights_from_torch['conv0.0.bias']]
  layer.set_weights(weights)
 elif current_layer_name=='conv1':
  weights = [weights_from_torch['conv1.0.weight'],weights_from_torch['conv1.0.bias']]
  layer.set_weights(weights)
 elif current_layer_name=='conv1_1':
  weights = [weights_from_torch['conv1_1.0.weight'],weights_from_torch['conv1_1.0.bias']]
  layer.set_weights(weights)

首先就是定义Keras模型,用layers获得所有层的迭代器。

遍历迭代器,对一个层赋予相应的值。

赋值需要用save_weights,其参数需要是一个列表,形式和get_weights的返回结果一致,即 [ conv_weights, bias_weights]

最后祝愿大家能实现自己模型的迁移。工程开源在了个人Github,有详细的使用介绍,并且包含使用数据,大家可以直接运行。

以上这篇Pytorch转keras的有效方法,以FlowNet为例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用python分割TXT文件成4K的TXT文件
May 23 Python
Python实现优先级队列结构的方法详解
Jun 02 Python
用matplotlib画等高线图详解
Dec 14 Python
python人民币小写转大写辅助工具
Jun 20 Python
Python并发:多线程与多进程的详解
Jan 24 Python
python 实现二维字典的键值合并等函数
Dec 06 Python
Python concurrent.futures模块使用实例
Dec 24 Python
python文件处理fileinput使用方法详解
Jan 02 Python
Python tkinter三种布局实例详解
Jan 06 Python
python 正则表达式参数替换实例详解
Jan 17 Python
python使用建议与技巧分享(二)
Aug 17 Python
如何通过python计算圆周率PI
Nov 11 Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
基于pandas向csv添加新的行和列
May 25 #Python
You might like
什么是调频(FM)、调幅(AM)、短波(SW)、长波(LW)
2021/03/01 无线电
PHP实现浏览器中直接输出图片的方法示例
2018/03/14 PHP
PHP addAttribute()函数讲解
2019/02/03 PHP
PHP实现的文件浏览器功能简单示例
2019/09/12 PHP
限制复选框的最大可选数
2006/07/01 Javascript
Javascript 验证上传图片大小[客户端]
2009/08/01 Javascript
js模拟类继承小例子
2010/07/17 Javascript
jquery $.ajax各个事件执行顺序
2010/10/15 Javascript
javascript中的关于类型转换的性能优化
2010/12/14 Javascript
jquery事件preventDefault()方法用法实例
2015/01/16 Javascript
javascript实现检验的各种规则
2015/07/31 Javascript
jquery动态切换背景图片的简单实现方法
2016/05/14 Javascript
Bootstrap中的表单验证插件bootstrapValidator使用方法整理(推荐)
2016/06/21 Javascript
jQuery中delegate()方法的用法详解
2016/10/13 Javascript
Bootstrap fileinput文件上传预览插件使用详解
2017/05/16 Javascript
JS与HTML结合实现流程进度展示条思路详解
2017/09/03 Javascript
详解vue的diff算法原理
2018/05/20 Javascript
layui: layer.open加载窗体时出现遮罩层的解决方法
2019/09/26 Javascript
[54:25]Ti4 循环赛第三日LGD vs MOUZ
2014/07/12 DOTA
[57:41]Secret vs Serenity 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
简单介绍Python的轻便web框架Bottle
2015/04/08 Python
Python3中使用PyMongo的方法详解
2017/07/28 Python
深入理解Python爬虫代理池服务
2018/02/28 Python
Python hexstring-list-str之间的转换方法
2019/06/12 Python
Python数组拼接np.concatenate实现过程
2020/04/18 Python
使用Keras建立模型并训练等一系列操作方式
2020/07/02 Python
CSS3制作皮卡丘动画壁纸的示例
2020/11/02 HTML / CSS
HTML5 Video/Audio播放本地文件示例介绍
2013/11/18 HTML / CSS
机电一体化专业应届生求职信
2013/11/27 职场文书
中华美德颂演讲稿
2014/05/20 职场文书
建筑工地大门标语
2014/06/18 职场文书
2014年综治维稳工作总结
2014/11/17 职场文书
让人感觉高大上的讲话稿怎么写?
2019/07/08 职场文书
导游词之湖北武当山
2019/09/23 职场文书
golang fmt格式“占位符”的实例用法详解
2021/07/04 Golang
Java练习之潜艇小游戏的实现
2022/03/16 Java/Android