Python通过VGG16模型实现图像风格转换操作详解


Posted in Python onJanuary 16, 2020

本文实例讲述了Python通过VGG16模型实现图像风格转换操作。分享给大家供大家参考,具体如下:

1、图像的风格转化

卷积网络每一层的激活值可以看作一个分类器,多个分类器组成了图像在这一层的抽象表示,而且层数越深,越抽象

内容特征:图片中存在的具体元素,图像输入到CNN后在某一层的激活值

风格特征:绘制图片元素的风格,各个内容之间的共性,图像在CNN网络某一层激活值之间的关联

风格转换:在一幅图片内容特征的基础上添加另一幅图片的风格特征从而生成一幅新的图片。在卷积模型训练中,通过输入固定的图片来调整网络的参数从而达到利用图片训练网络的目的。而在生成特定风格图片时,固定已有的网络参数不变,调整图片从而使图片向目标风格转化。在内容风格转换时,调整图像的像素值,使其向目标图片在卷积网络输出的内容特征靠拢。在风格特征计算时,通过多个神经元的输出两两之间作内积求和得到Gram矩阵,然后对G矩阵做差求均值得到风格的损失函数。

Python通过VGG16模型实现图像风格转换操作详解                             Python通过VGG16模型实现图像风格转换操作详解

将内容损失函数和风格损失函数对应乘以权重再加起来就得到了总的损失函数,最后的生成图既有内容特征也有风格特征

2、通过Vgg16实现

2.1、预训练模型读取

通过预训练好的Vgg16模型来对图片进行风格转换,首先需要准备好vgg16的模型参数。链接: https://pan.baidu.com/s/1shw2M3Iv7UfGjn78dqFAkA 提取码: ejn8

通过numpy.load()导入并查看参数的内容:

import numpy as np
 
data=np.load('./vgg16_model.npy',allow_pickle=True,encoding='bytes')
# print(data.type())
data_dic=data.item()
# 查看网络层参数的键值
print(data_dic.keys())

打印键值如下,可以看到分别有不同的卷积和全连接层:

dict_keys([b'conv5_1', b'fc6', b'conv5_3', b'conv5_2', b'fc8', b'fc7', b'conv4_1',
 b'conv4_2', b'conv4_3', b'conv3_3', b'conv3_2', b'conv3_1', b'conv1_1', b'conv1_2', 
b'conv2_2', b'conv2_1'])

接着查看具体每层的参数,通过data_dic[key]可以获取到key对应层次的参数,例如可以看到卷积层1_1的权值w为3个3×3的卷积核,对应64个输出通道

# 查看卷积层1_1的参数w,b
w,b=data_dic[b'conv1_1']
print(w.shape,b.shape)   # (3, 3, 3, 64) (64,)
# 查看全连接层的参数
w,b=data_dic[b'fc8']
print(w.shape,b.shape)   # (4096, 1000) (1000,)

2.2、构建VGG网络

通过将已经训练好的参数填充到网络之中就可以搭建VGG网络了。

在类初始化函数中读取预训练模型文件中的参数到self.data_dic

首先构建卷积层,通过传入的各个卷积层name参数,读取模型中对应的卷积层参数并填充到网络中。例如读取第一个卷积层的权值和偏置值,传入name='conv1_1,则data_dic[name][0]可以得到权值weight,data_dic[name][1]得到偏置值bias。通过tf.constant构建常量,再执行卷积操作,加偏置项,经激活函数后输出。

接下来实现池化操作,由于池化不需要参数,所以直接对输入进行最大池化操作后输出即可

接着经过展开层,由于卷积池化后的数据是四维向量[batch_size,image_width,image_height,chanel],需要将最后三维展开,将最后三个维度相乘,通过tf.reshape()展开

最后需要把结果经过全连接层,它的实现和卷积层类似,读取权值和偏置参数后进行全连接操作后输出。

class VGGNet:
 def __init__(self, data_dir):
  data = np.load(data_dir, allow_pickle=True, encoding='bytes')
  self.data_dic = data.item()
 
 def conv_layer(self, x, name):
  # 实现卷积操作
  with tf.name_scope(name):
   # 从模型文件中读取各卷积层的参数值
   weight = tf.constant(self.data_dic[name][0], name='conv')
   bias = tf.constant(self.data_dic[name][1], name='bias')
   # 进行卷积操作
   y = tf.nn.conv2d(x, weight, [1, 1, 1, 1], padding='SAME')
   y = tf.nn.bias_add(y, bias)
   return tf.nn.relu(y)
 
 def pooling_layer(self, x, name):
  # 实现池化操作
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)
 
 def flatten_layer(self, x, name):
  # 实现展开层
  with tf.name_scope(name):
   # x_shape->[batch_size,image_width,image_height,chanel]
   x_shape = x.get_shape().as_list()
   dimension = 1
   # 计算x的最后三个维度积
   for d in x_shape[1:]:
    dimension *= d
   output = tf.reshape(x, [-1, dimension])
   return output
 
 def fc_layer(self, x, name, activation=tf.nn.relu):
  # 实现全连接层
  with tf.name_scope(name):
   # 从模型文件中读取各全连接层的参数值
   weight = tf.constant(self.data_dic[name][0], name='fc')
   bias = tf.constant(self.data_dic[name][1], name='bias')
   # 进行全连接操作
   y = tf.matmul(x, weight)
   y = tf.nn.bias_add(y, bias)
   if activation==None:
    return y
   else:
    return tf.nn.relu(y)

通过self.build()函数实现Vgg16网络的搭建.数据输入后首先需要进行归一化处理,将输入的RGB数据拆分为R、G、B三个通道,再将三个通道分别减去一个固定值,最后将三通道按B、G、R顺序重新拼接为一个新的数据。

接下来则是通过上面的构建函数来搭建VGG网络,依次将五层的卷积池化网络、展开层、三个全连接层的参数读入各层,并搭建起网络,最后经softmax输出

def build(self,x_rgb):
  s_time=time.time()
  # 归一化处理,在第四维上将输入的图片的三通道拆分
  r,g,b=tf.split(x_rgb,[1,1,1],axis=3)
  # 分别将三通道上减去特定值归一化后再按bgr顺序拼起来
  VGG_MEAN = [103.939, 116.779, 123.68]
  x_bgr=tf.concat(
   [b-VGG_MEAN[0],
   g-VGG_MEAN[1],
   r-VGG_MEAN[2]],
   axis=3
  )
  # 判别拼接起来的数据是否符合期望,符合再继续往下执行
  assert x_bgr.get_shape()[1:]==[668,668,3]
 
  # 构建各个卷积、池化、全连接等层
  self.conv1_1=self.conv_layer(x_bgr,b'conv1_1')
  self.conv1_2=self.conv_layer(self.conv1_1,b'conv1_2')
  self.pool1=self.pooling_layer(self.conv1_2,b'pool1')
 
  self.conv2_1=self.conv_layer(self.pool1,b'conv2_1')
  self.conv2_2=self.conv_layer(self.conv2_1,b'conv2_2')
  self.pool2=self.pooling_layer(self.conv2_2,b'pool2')
 
  self.conv3_1=self.conv_layer(self.pool2,b'conv3_1')
  self.conv3_2=self.conv_layer(self.conv3_1,b'conv3_2')
  self.conv3_3=self.conv_layer(self.conv3_2,b'conv3_3')
  self.pool3=self.pooling_layer(self.conv3_3,b'pool3')
 
  self.conv4_1 = self.conv_layer(self.pool3, b'conv4_1')
  self.conv4_2 = self.conv_layer(self.conv4_1, b'conv4_2')
  self.conv4_3 = self.conv_layer(self.conv4_2, b'conv4_3')
  self.pool4 = self.pooling_layer(self.conv4_3, b'pool4')
 
  self.conv5_1 = self.conv_layer(self.pool4, b'conv5_1')
  self.conv5_2 = self.conv_layer(self.conv5_1, b'conv5_2')
  self.conv5_3 = self.conv_layer(self.conv5_2, b'conv5_3')
  self.pool5 = self.pooling_layer(self.conv5_3, b'pool5')
 
  self.flatten=self.flatten_layer(self.pool5,b'flatten')
  self.fc6=self.fc_layer(self.flatten,b'fc6')
  self.fc7 = self.fc_layer(self.fc6, b'fc7')
  self.fc8 = self.fc_layer(self.fc7, b'fc8',activation=None)
  self.prob=tf.nn.softmax(self.fc8,name='prob')
 
  print('模型构建完成,用时%d秒'%(time.time()-s_time))

2.3、图像风格转换

首先需要定义网络的输入与输出。网络的输入是风格图像和内容图像,两张图象都是668×668的3通道图片。首先通过PIL库中的Image对象完成读入内容图像style_img和风格图像content_img,并将其转化为数组,定义对应的占位符style_in和content_in,在训练时将图片填入。

网络的输出是一张结果图片668×668的3通道,通过随机函数初始化一个结果图像的数组res_out。

利用上面定义的VGGNet类来创建图片对象,并完成build操作。

vgg16_dir = './data/vgg16_model.npy'
style_img = './data/starry_night.jpg'
content_img = './data/city_night.jpg'
output_dir = './data'
 
 
def read_image(img):
 img = Image.open(img)
 img_np = np.array(img) # 将图片转化为[668,668,3]数组
 img_np = np.asarray([img_np], ) # 转化为[1,668,668,3]的数组
 return img_np
 
 
# 输入风格、内容图像数组
style_img = read_image(style_img)
content_img = read_image(content_img)
# 定义对应的输入图像的占位符
content_in = tf.placeholder(tf.float32, shape=[1, 668, 668, 3])
style_in = tf.placeholder(tf.float32, shape=[1, 668, 668, 3])
 
# 初始化输出的图像
initial_img = tf.truncated_normal((1, 668, 668, 3), mean=127.5, stddev=20)
res_out = tf.Variable(initial_img)
 
# 构建VGG网络对象
res_net = VGGNet(vgg16_dir)
style_net = VGGNet(vgg16_dir)
content_net = VGGNet(vgg16_dir)
res_net.build(res_out)
style_net.build(style_in)
content_net.build(content_in)

接着需要定义损失函数loss

对于内容损失,先选定内容风格图像和结果图像的卷积层,要相同,比如这里选取了卷积层1_1和2_1。然后这两个特征层的后三个通道求平方差,然后取均值,就是内容损失。

对于风格损失,首先需要对风格图像和结果图像的特征层求gram矩阵,然后对gram矩阵求平方差的均值。

最后按照系数比例将两个损失函数相加即可得到loss

# 计算损失,分别需要计算内容损失和风格损失
# 提取内容图像的内容特征
content_features = [
 content_net.conv1_2,
 content_net.conv2_2
 # content_net.conv2_2
]
# 对应结果图像提取相同层的内容特征
res_content = [
 res_net.conv1_2,
 res_net.conv2_2
 # res_net.conv2_2
]
# 计算内容损失
content_loss = tf.zeros(1, tf.float32)
for c, r in zip(content_features, res_content):
 content_loss += tf.reduce_mean((c - r) ** 2, [1, 2, 3])
 
 
# 计算风格损失的gram矩阵
def gram_matrix(x):
 b, w, h, ch = x.get_shape().as_list()
 features = tf.reshape(x, [b, w * h, ch])
 # 对features矩阵作内积,再除以一个常数
 gram = tf.matmul(features, features, adjoint_a=True) / tf.constant(w * h * ch, tf.float32)
 return gram
 
 
# 对风格图像提取特征
style_features = [
 # style_net.conv1_2
 style_net.conv4_3
]
style_gram = [gram_matrix(feature) for feature in style_features]
# 提取结果图像对应层的风格特征
res_features = [
 res_net.conv4_3
]
res_gram = [gram_matrix(feature) for feature in res_features]
# 计算风格损失
style_loss = tf.zeros(1, tf.float32)
for s, r in zip(style_gram, res_gram):
 style_loss += tf.reduce_mean((s - r) ** 2, [1, 2])
 
# 模型内容、风格特征的系数
k_content = 0.1
k_style = 500
# 按照系数将两个损失值相加
loss = k_content * content_loss + k_style * style_loss

接下来开始进行100轮的训练,打印并查看过程中的总损失、内容损失、风格损失值。并将每轮的生成结果图片输出到指定目录下

# 进行训练
learning_steps = 100
learning_rate = 10
train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)
 
with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 for i in range(learning_steps):
  t_loss, c_loss, s_loss, _ = sess.run(
   [loss, content_loss, style_loss, train_op],
   feed_dict={content_in: content_img, style_in: style_img}
  )
  print('第%d轮训练,总损失:%.4f,内容损失:%.4f,风格损失:%.4f'
    % (i + 1, t_loss[0], c_loss[0], s_loss[0]))
  # 获取结果图像数组并保存
  res_arr = res_out.eval(sess)[0]
  res_arr = np.clip(res_arr, 0, 255) # 将结果数组中的值裁剪到0~255
  res_arr = np.asarray(res_arr, np.uint8) # 将图片数组转化为uint8
  img_path = os.path.join(output_dir, 'res_%d.jpg' % (i + 1))
  # 图像数组转化为图片
  res_img = Image.fromarray(res_arr)
  res_img.save(img_path)

运行结果如下可以看到依次分别为内容图片、风格图片、训练12轮、46轮、100轮结果图片

Python通过VGG16模型实现图像风格转换操作详解    Python通过VGG16模型实现图像风格转换操作详解 

Python通过VGG16模型实现图像风格转换操作详解      Python通过VGG16模型实现图像风格转换操作详解    Python通过VGG16模型实现图像风格转换操作详解

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
使用python编写批量卸载手机中安装的android应用脚本
Jul 21 Python
Python yield 使用方法浅析
May 20 Python
python3.4用循环往mysql5.7中写数据并输出的实现方法
Jun 20 Python
浅谈python内置变量-reversed(seq)
Jun 21 Python
解决Django后台ManyToManyField显示成Object的问题
Aug 09 Python
Flask框架 CSRF 保护实现方法详解
Oct 30 Python
Keras设置以及获取权重的实现
Jun 19 Python
Python3爬虫中识别图形验证码的实例讲解
Jul 30 Python
python判断一个变量是否已经设置的方法
Aug 13 Python
PyCharm中关于安装第三方包的三个建议
Sep 17 Python
python实现excel公式格式化的示例代码
Dec 23 Python
Python - 10行代码集2000张美女图
May 23 Python
Python使用turtle库绘制小猪佩奇(实例代码)
Jan 16 #Python
PyCharm汉化安装及永久激活详细教程(靠谱)
Jan 16 #Python
python如何使用Redis构建分布式锁
Jan 16 #Python
Python中url标签使用知识点总结
Jan 16 #Python
PyTorch的SoftMax交叉熵损失和梯度用法
Jan 15 #Python
pytorch方法测试——激活函数(ReLU)详解
Jan 15 #Python
pytorch的batch normalize使用详解
Jan 15 #Python
You might like
一个简单实现多条件查询的例子
2006/10/09 PHP
详解PHP显示MySQL数据的三种方法
2008/06/05 PHP
php修改时间格式的代码
2011/05/29 PHP
启用Csrf后POST数据时出现的400错误
2015/07/05 PHP
php轻松实现文件上传功能
2016/03/03 PHP
PHP批量获取网页中所有固定种子链接的方法
2016/11/18 PHP
php的无刷新操作实现方法分析
2020/02/28 PHP
实现复选框全选/全不选切换
2006/12/23 Javascript
jQuery.ajax 用户登录验证代码
2010/10/29 Javascript
Jquery getJSON方法详细分析
2013/12/26 Javascript
JS去除字符串两端空格的简单实例
2013/12/27 Javascript
jQuery避免$符和其他JS库冲突的方法对比
2014/02/20 Javascript
node.js中的fs.utimes方法使用说明
2014/12/15 Javascript
解决jQuery上传插件Uploadify出现Http Error 302错误的方法
2015/12/18 Javascript
AngularJS转换响应内容
2016/01/27 Javascript
TinyMCE汉化及本地上传图片功能实例详解
2016/05/31 Javascript
Angular和Vue双向数据绑定的实现原理(重点是vue的双向绑定)
2016/11/22 Javascript
手机软键盘弹出时影响布局的解决方法
2016/12/15 Javascript
js点击任意区域弹出层消失实现代码
2016/12/27 Javascript
利用Vue.js实现checkbox的全选反选效果
2017/01/18 Javascript
详解AngularJS1.6版本中ui-router路由中/#!/的解决方法
2017/05/22 Javascript
vue使用axios时关于this的指向问题详解
2017/12/22 Javascript
解决vue项目nginx部署到非根目录下刷新空白的问题
2018/09/27 Javascript
[02:04]2016国际邀请赛中国区预选赛VG.R晋级之路
2016/07/01 DOTA
rhythmbox中文名乱码问题解决方法
2008/09/06 Python
Python如何快速实现分布式任务
2017/07/06 Python
pandas使用get_dummies进行one-hot编码的方法
2018/07/10 Python
浅谈Scrapy网络爬虫框架的工作原理和数据采集
2019/02/07 Python
浅谈python多进程共享变量Value的使用tips
2019/07/16 Python
俄罗斯品牌服装和鞋子在线商店:BRIONITY
2020/03/26 全球购物
一封普通求职者的求职信
2013/11/20 职场文书
中学生校园广播稿
2014/01/16 职场文书
带病坚持工作事迹
2014/05/03 职场文书
音乐研修感悟
2015/11/18 职场文书
四年级数学教学反思
2016/02/16 职场文书
导游词之塘栖古镇
2019/12/04 职场文书