Pytorch转keras的有效方法,以FlowNet为例讲解


Posted in Python onMay 26, 2020

Pytorch凭借动态图机制,获得了广泛的使用,大有超越tensorflow的趋势,不过在工程应用上,TF仍然占据优势。有的时候我们会遇到这种情况,需要把模型应用到工业中,运用到实际项目上,TF支持的PB文件和TF的C++接口就成为了有效的工具。今天就给大家讲解一下Pytorch转成Keras的方法,进而我们也可以获得Pb文件,因为Keras是支持tensorflow的,我将会在下一篇博客讲解获得Pb文件,并使用Pb文件的方法。

Pytorch To Keras

首先,我们必须有清楚的认识,网上以及github上一些所谓的pytorch转换Keras或者Keras转换成Pytorch的工具代码几乎不能运行或者有使用的局限性(比如仅仅能转换某一些模型),但是我们是可以用这些转换代码中看出一些端倪来,比如二者的参数的尺寸(shape)的形式、channel的排序(first or last)是否一样,掌握到差异性,就能根据这些差异自己编写转换代码,没错,自己编写转换代码,是最稳妥的办法。整个过程也就分为两个部分。笔者将会以Nvidia开源的FlowNet为例,将开源的Pytorch代码转化为Keras模型。

按照Pytorch中模型的结构,编写对应的Keras代码,用keras的函数式API,构建起来会非常方便。

把Pytorch的模型参数,按照层的名称依次赋值给Keras的模型

以上两步虽然看上去简单,但实际我也走了不少弯路。这里一个关键的地方,就是参数的shape在两个框架中是否统一,那当然是不统一的。下面我以FlowNet为例。

Pytorch中的FlowNet代码

我们仅仅展示层名称和层参数,就不把整个结构贴出来了,否则会占很多的空间,形成水文。

先看用Keras搭建的flowNet模型,直接用model.summary()输出模型信息

__________________________________________________________________________________________________
Layer (type)   Output Shape  Param # Connected to   
==================================================================================================
input_1 (InputLayer)  (None, 6, 512, 512) 0      
__________________________________________________________________________________________________
conv0 (Conv2D)   (None, 64, 512, 512) 3520 input_1[0][0]   
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 64, 512, 512) 0  conv0[0][0]   
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, 64, 514, 514) 0  leaky_re_lu_1[0][0]  
__________________________________________________________________________________________________
conv1 (Conv2D)   (None, 64, 256, 256) 36928 zero_padding2d_1[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 64, 256, 256) 0  conv1[0][0]   
__________________________________________________________________________________________________
conv1_1 (Conv2D)  (None, 128, 256, 256 73856 leaky_re_lu_2[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 128, 256, 256 0  conv1_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D (None, 128, 258, 258 0  leaky_re_lu_3[0][0]  
__________________________________________________________________________________________________
conv2 (Conv2D)   (None, 128, 128, 128 147584 zero_padding2d_2[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU) (None, 128, 128, 128 0  conv2[0][0]   
__________________________________________________________________________________________________
conv2_1 (Conv2D)  (None, 128, 128, 128 147584 leaky_re_lu_4[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU) (None, 128, 128, 128 0  conv2_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_3 (ZeroPadding2D (None, 128, 130, 130 0  leaky_re_lu_5[0][0]  
__________________________________________________________________________________________________
conv3 (Conv2D)   (None, 256, 64, 64) 295168 zero_padding2d_3[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU) (None, 256, 64, 64) 0  conv3[0][0]   
__________________________________________________________________________________________________
conv3_1 (Conv2D)  (None, 256, 64, 64) 590080 leaky_re_lu_6[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU) (None, 256, 64, 64) 0  conv3_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_4 (ZeroPadding2D (None, 256, 66, 66) 0  leaky_re_lu_7[0][0]  
__________________________________________________________________________________________________
conv4 (Conv2D)   (None, 512, 32, 32) 1180160 zero_padding2d_4[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU) (None, 512, 32, 32) 0  conv4[0][0]   
__________________________________________________________________________________________________
conv4_1 (Conv2D)  (None, 512, 32, 32) 2359808 leaky_re_lu_8[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU) (None, 512, 32, 32) 0  conv4_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_5 (ZeroPadding2D (None, 512, 34, 34) 0  leaky_re_lu_9[0][0]  
__________________________________________________________________________________________________
conv5 (Conv2D)   (None, 512, 16, 16) 2359808 zero_padding2d_5[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_10 (LeakyReLU) (None, 512, 16, 16) 0  conv5[0][0]   
__________________________________________________________________________________________________
conv5_1 (Conv2D)  (None, 512, 16, 16) 2359808 leaky_re_lu_10[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_11 (LeakyReLU) (None, 512, 16, 16) 0  conv5_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_6 (ZeroPadding2D (None, 512, 18, 18) 0  leaky_re_lu_11[0][0]  
__________________________________________________________________________________________________
conv6 (Conv2D)   (None, 1024, 8, 8) 4719616 zero_padding2d_6[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_12 (LeakyReLU) (None, 1024, 8, 8) 0  conv6[0][0]   
__________________________________________________________________________________________________
conv6_1 (Conv2D)  (None, 1024, 8, 8) 9438208 leaky_re_lu_12[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_13 (LeakyReLU) (None, 1024, 8, 8) 0  conv6_1[0][0]   
__________________________________________________________________________________________________
deconv5 (Conv2DTranspose) (None, 512, 16, 16) 8389120 leaky_re_lu_13[0][0]  
__________________________________________________________________________________________________
predict_flow6 (Conv2D)  (None, 2, 8, 8) 18434 leaky_re_lu_13[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_14 (LeakyReLU) (None, 512, 16, 16) 0  deconv5[0][0]   
__________________________________________________________________________________________________
upsampled_flow6_to_5 (Conv2DTra (None, 2, 16, 16) 66  predict_flow6[0][0]  
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 1026, 16, 16) 0  leaky_re_lu_11[0][0]  
         leaky_re_lu_14[0][0]  
         upsampled_flow6_to_5[0][0] 
__________________________________________________________________________________________________
inter_conv5 (Conv2D)  (None, 512, 16, 16) 4728320 concatenate_1[0][0]  
__________________________________________________________________________________________________
deconv4 (Conv2DTranspose) (None, 256, 32, 32) 4202752 concatenate_1[0][0]  
__________________________________________________________________________________________________
predict_flow5 (Conv2D)  (None, 2, 16, 16) 9218 inter_conv5[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_15 (LeakyReLU) (None, 256, 32, 32) 0  deconv4[0][0]   
__________________________________________________________________________________________________
upsampled_flow5_to4 (Conv2DTran (None, 2, 32, 32) 66  predict_flow5[0][0]  
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 770, 32, 32) 0  leaky_re_lu_9[0][0]  
         leaky_re_lu_15[0][0]  
         upsampled_flow5_to4[0][0] 
__________________________________________________________________________________________________
inter_conv4 (Conv2D)  (None, 256, 32, 32) 1774336 concatenate_2[0][0]  
__________________________________________________________________________________________________
deconv3 (Conv2DTranspose) (None, 128, 64, 64) 1577088 concatenate_2[0][0]  
__________________________________________________________________________________________________
predict_flow4 (Conv2D)  (None, 2, 32, 32) 4610 inter_conv4[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_16 (LeakyReLU) (None, 128, 64, 64) 0  deconv3[0][0]   
__________________________________________________________________________________________________
upsampled_flow4_to3 (Conv2DTran (None, 2, 64, 64) 66  predict_flow4[0][0]  
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 386, 64, 64) 0  leaky_re_lu_7[0][0]  
         leaky_re_lu_16[0][0]  
         upsampled_flow4_to3[0][0] 
__________________________________________________________________________________________________
inter_conv3 (Conv2D)  (None, 128, 64, 64) 444800 concatenate_3[0][0]  
__________________________________________________________________________________________________
deconv2 (Conv2DTranspose) (None, 64, 128, 128) 395328 concatenate_3[0][0]  
__________________________________________________________________________________________________
predict_flow3 (Conv2D)  (None, 2, 64, 64) 2306 inter_conv3[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_17 (LeakyReLU) (None, 64, 128, 128) 0  deconv2[0][0]   
__________________________________________________________________________________________________
upsampled_flow3_to2 (Conv2DTran (None, 2, 128, 128) 66  predict_flow3[0][0]  
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 194, 128, 128 0  leaky_re_lu_5[0][0]  
         leaky_re_lu_17[0][0]  
         upsampled_flow3_to2[0][0] 
__________________________________________________________________________________________________
inter_conv2 (Conv2D)  (None, 64, 128, 128) 111808 concatenate_4[0][0]  
__________________________________________________________________________________________________
predict_flow2 (Conv2D)  (None, 2, 128, 128) 1154 inter_conv2[0][0]  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 2, 512, 512) 0  predict_flow2[0][0]

再看看Pytorch搭建的flownet模型

(conv0): Sequential(
 (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv1): Sequential(
 (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv1_1): Sequential(
 (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv2): Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv2_1): Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv3): Sequential(
 (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv3_1): Sequential(
 (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv4): Sequential(
 (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv4_1): Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv5): Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv5_1): Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv6): Sequential(
 (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv6_1): Sequential(
 (0): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv5): Sequential(
 (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv4): Sequential(
 (0): ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv3): Sequential(
 (0): ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv2): Sequential(
 (0): ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (inter_conv5): Sequential(
 (0): Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (inter_conv4): Sequential(
 (0): Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (inter_conv3): Sequential(
 (0): Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (inter_conv2): Sequential(
 (0): Conv2d(194, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (predict_flow6): Conv2d(1024, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow5): Conv2d(512, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow4): Conv2d(256, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow3): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow2): Conv2d(64, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (upsampled_flow6_to_5): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsampled_flow5_to_4): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsampled_flow4_to_3): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsampled_flow3_to_2): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsample1): Upsample(scale_factor=4.0, mode=bilinear)
)
conv0 Sequential(
 (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv0.0 Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv0.1 LeakyReLU(negative_slope=0.1, inplace)
conv1 Sequential(
 (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv1.0 Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv1.1 LeakyReLU(negative_slope=0.1, inplace)
conv1_1 Sequential(
 (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv1_1.0 Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv1_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv2 Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv2.0 Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv2.1 LeakyReLU(negative_slope=0.1, inplace)
conv2_1 Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv2_1.0 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv2_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv3 Sequential(
 (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv3.0 Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv3.1 LeakyReLU(negative_slope=0.1, inplace)
conv3_1 Sequential(
 (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv3_1.0 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv3_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv4 Sequential(
 (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv4.0 Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv4.1 LeakyReLU(negative_slope=0.1, inplace)
conv4_1 Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv4_1.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv4_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv5 Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv5.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv5.1 LeakyReLU(negative_slope=0.1, inplace)
conv5_1 Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv5_1.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv5_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv6 Sequential(
 (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv6.0 Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv6.1 LeakyReLU(negative_slope=0.1, inplace)
conv6_1 Sequential(
 (0): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv6_1.0 Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv6_1.1 LeakyReLU(negative_slope=0.1, inplace)
deconv5 Sequential(
 (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv5.0 ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv5.1 LeakyReLU(negative_slope=0.1, inplace)
deconv4 Sequential(
 (0): ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv4.0 ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv4.1 LeakyReLU(negative_slope=0.1, inplace)
deconv3 Sequential(
 (0): ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv3.0 ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv3.1 LeakyReLU(negative_slope=0.1, inplace)
deconv2 Sequential(
 (0): ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv2.0 ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv2.1 LeakyReLU(negative_slope=0.1, inplace)
inter_conv5 Sequential(
 (0): Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv5.0 Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv4 Sequential(
 (0): Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv4.0 Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv3 Sequential(
 (0): Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv3.0 Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv2 Sequential(
 (0): Conv2d(194, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

因为Pytorch模型用name_modules()输出不是按顺序的,动态图机制决定了只有在有数据流动之后才知道走过的路径。所以上面的顺序也是乱的。但我想表明的是,我用Keras搭建的模型确实是根据官方开源的Pytorch模型搭建的。

模型搭建完毕之后,就到了关键的步骤:给Keras模型赋值。

给Keras模型赋值

这个步骤其实注意三个点

Pytorch是channels_first的,Keras默认是channels_last,在代码开头加上这两句:

K.set_image_data_format(‘channels_first')
K.set_learning_phase(0)

众所周知,卷积层的权重是一个4维张量,那么,在Pytorch和keras中,卷积核的权重的形式是否一致的,那自然是不一致的,要不然我为啥还要写这一点。那么就涉及到Pytorch权重的变形。

既然卷积层权重形式在两个框架是不一致的,转置卷积自然也是不一致的。

我们先看看卷积层在两个框架中的形式

keras的卷积层权重形式

我们用以下代码看keras卷积层权重形式

for l in model.layers:
  print(l.name)
  for i, w in enumerate(l.get_weights()):
   print('%d'%i , w.shape)

第一个卷积层输出如下 0之后是卷积权重的shape,1之后的是偏置项

conv0
0 (3, 3, 6, 64)
1 (64,)

所以Keras的卷积层权重形式是[ height, width, input_channels, out_channels]

Pytorch的卷积层权重形式

net = FlowNet2SD()
 for n, m in net.named_parameters():
  print(n)
  print(m.data.size())

conv0.0.weight
torch.Size([64, 6, 3, 3])
conv0.0.bias
torch.Size([64])

用上面的代码得到所有层的参数的shape,同样找到第一个卷积层的参数,查看shape。

通过对比我们可以发现,Pytorch的卷积层shape是[ out_channels, input_channels, height, width]的形式。

那么我们在取出Pytorch权重之后,需要用np.transpose改变一下权重的排序,才能送到Keras模型对应的层上。

Keras中转置卷积权重形式

deconv4
0 (4, 4, 256, 1026)
1 (256,)

代码仍然和上面一样,找到转置卷积的对应的位置,查看一下

可以看出在Keras中,转置卷积形式是 [ height, width, out_channels, input_channels]

Pytorch中转置卷积权重形式

deconv4.0.weight
torch.Size([1026, 256, 4, 4])
deconv4.0.bias
torch.Size([256])

代码仍然和上面一样,找到转置卷积的对应的位置,查看一下

可以看出在Pytorch中,转置卷积形式是 [ input_channels,out_channels,height, width]

小结

对于卷积层来说,Pytorch的权重需要使用

np.transpose(weight.data.numpy(), [2, 3, 1, 0])

才能赋值给keras模型对应的层的权重。

对于转置卷积来说,通过对比其实也是一样的。不信你去试试嘛。O(∩_∩)O哈哈~

对于偏置项,两种模块都是一维的向量,不需要处理。

有的情况还可能需要通道颠倒一下,但是很少需要这样做。

weights[::-1,::-1,:,:]

赋值

结束了预处理之后,我们就进入第二步,开始赋值了。

先看预处理的代码:

for k,v in weights_from_torch.items():
 if 'bias' not in k:
  weights_from_torch[k] = v.data.numpy().transpose(2, 3, 1, 0)

赋值代码我只截了一部分供大家参考:

k_model = k_model()
for layer in k_model.layers:
 current_layer_name = layer.name
 if current_layer_name=='conv0':
  weights = [weights_from_torch['conv0.0.weight'],weights_from_torch['conv0.0.bias']]
  layer.set_weights(weights)
 elif current_layer_name=='conv1':
  weights = [weights_from_torch['conv1.0.weight'],weights_from_torch['conv1.0.bias']]
  layer.set_weights(weights)
 elif current_layer_name=='conv1_1':
  weights = [weights_from_torch['conv1_1.0.weight'],weights_from_torch['conv1_1.0.bias']]
  layer.set_weights(weights)

首先就是定义Keras模型,用layers获得所有层的迭代器。

遍历迭代器,对一个层赋予相应的值。

赋值需要用save_weights,其参数需要是一个列表,形式和get_weights的返回结果一致,即 [ conv_weights, bias_weights]

最后祝愿大家能实现自己模型的迁移。工程开源在了个人Github,有详细的使用介绍,并且包含使用数据,大家可以直接运行。

以上这篇Pytorch转keras的有效方法,以FlowNet为例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中的对象拷贝示例 python引用传递
Jan 23 Python
跟老齐学Python之集合(set)
Sep 24 Python
python实现的简单猜数字游戏
Apr 04 Python
在Python中使用正则表达式的方法
Aug 13 Python
python实现一行输入多个值和一行输出多个值的例子
Jul 16 Python
详解Python3迁移接口变化采坑记
Oct 11 Python
python matplotlib如何给图中的点加标签
Nov 14 Python
Python+OpenCV实现旋转文本校正方式
Jan 09 Python
django项目中新增app的2种实现方法
Apr 01 Python
详解如何在PyCharm控制台中输出彩色文字和背景
Aug 17 Python
基于python模拟bfs和dfs代码实例
Nov 19 Python
python中常用的数据结构介绍
Jan 12 Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
基于pandas向csv添加新的行和列
May 25 #Python
You might like
PHP 获取文件路径(灵活应用__FILE__)
2013/02/15 PHP
php 启动报错如何解决
2014/01/17 PHP
PHP下的浮点运算不准的解决方法
2016/10/27 PHP
Jquery ui css framework
2010/06/28 Javascript
JqGrid web打印实现代码
2011/05/31 Javascript
编写高效jQuery代码的4个原则和5个技巧
2014/04/24 Javascript
Bootstrap基础学习
2015/06/16 Javascript
javascript中 try catch用法
2015/08/16 Javascript
比较常见的javascript中定义函数的区别
2015/11/09 Javascript
深入理解JavaScript程序中内存泄漏
2016/03/17 Javascript
websocket+node.js实现实时聊天系统问题咨询
2017/05/17 Javascript
使用nvm管理不同版本的node与npm的方法
2017/10/31 Javascript
react-router4 配合webpack require.ensure 实现异步加载的示例
2018/01/18 Javascript
jQuery zTree树插件的使用教程
2019/08/16 jQuery
js+css实现全屏侧边栏
2020/06/16 Javascript
js实现扫雷源代码
2020/11/27 Javascript
[28:42]Ti4正赛VG vs NEWBEE1
2014/07/19 DOTA
详解Python的Django框架中Manager方法的使用
2015/07/21 Python
使用Python来编写HTTP服务器的超级指南
2016/02/18 Python
python3+PyQt5使用数据库窗口视图
2018/04/24 Python
Python中使用Counter进行字典创建以及key数量统计的方法
2018/07/06 Python
python处理“
2019/06/10 Python
Python实现直方图均衡基本原理解析
2019/08/08 Python
Python+numpy实现矩阵的行列扩展方式
2019/11/29 Python
pytorch中获取模型input/output shape实例
2019/12/30 Python
python GUI库图形界面开发之PyQt5选项卡控件QTabWidget详细使用方法与实例
2020/03/01 Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
2020/06/23 Python
免税水晶:Duty Free Crystal
2019/05/13 全球购物
size?法国官网:英国伦敦的球鞋精品店
2020/03/15 全球购物
Java如何格式化日期
2012/08/07 面试题
介绍一下#error预处理
2015/09/25 面试题
四风问题个人对照检查剖析材料
2014/09/27 职场文书
幼儿园家长工作总结2015
2015/04/25 职场文书
小学数学教师研修感悟
2015/11/18 职场文书
Python list去重且保持原顺序不变的方法
2021/04/03 Python
Java对文件的读写操作方法
2022/04/29 Java/Android