Pytorch实现WGAN用于动漫头像生成


Posted in Python onMarch 04, 2021

WGAN与GAN的不同

  • 去除sigmoid
  • 使用具有动量的优化方法,比如使用RMSProp
  • 要对Discriminator的权重做修整限制以确保lipschitz连续约

WGAN实战卷积生成动漫头像 

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
from anime_face_generator.dataset import ImageDataset
 
batch_size = 32
num_epoch = 100
z_dimension = 100
dir_path = './wgan_img'
 
# 创建文件夹
if not os.path.exists(dir_path):
  os.mkdir(dir_path)
 
 
def to_img(x):
  """因为我们在生成器里面用了tanh"""
  out = 0.5 * (x + 1)
  return out
 
 
dataset = ImageDataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
 
 
class Generator(nn.Module):
  def __init__(self):
    super().__init__()
 
    self.gen = nn.Sequential(
      # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
      nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
      nn.BatchNorm2d(512),
      nn.ReLU(True),
      # 上一步的输出形状:(512) x 4 x 4
      nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
      nn.BatchNorm2d(256),
      nn.ReLU(True),
      # 上一步的输出形状: (256) x 8 x 8
      nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      # 上一步的输出形状: (256) x 16 x 16
      nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      # 上一步的输出形状:(256) x 32 x 32
      nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),
      nn.Tanh() # 输出范围 -1~1 故而采用Tanh
      # nn.Sigmoid()
      # 输出形状:3 x 96 x 96
    )
 
  def forward(self, x):
    x = self.gen(x)
    return x
 
  def weight_init(m):
    # weight_initialization: important for wgan
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
      m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
      m.weight.data.normal_(1.0, 0.02)
 
 
class Discriminator(nn.Module):
  def __init__(self):
    super().__init__()
    self.dis = nn.Sequential(
      nn.Conv2d(3, 64, 5, 3, 1, bias=False),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (64) x 32 x 32
 
      nn.Conv2d(64, 128, 4, 2, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (128) x 16 x 16
 
      nn.Conv2d(128, 256, 4, 2, 1, bias=False),
      nn.BatchNorm2d(256),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (256) x 8 x 8
 
      nn.Conv2d(256, 512, 4, 2, 1, bias=False),
      nn.BatchNorm2d(512),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (512) x 4 x 4
 
      nn.Conv2d(512, 1, 4, 1, 0, bias=False),
      nn.Flatten(),
      # nn.Sigmoid() # 输出一个数(概率)
    )
 
  def forward(self, x):
    x = self.dis(x)
    return x
 
  def weight_init(m):
    # weight_initialization: important for wgan
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
      m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
      m.weight.data.normal_(1.0, 0.02)
 
 
def save(model, filename="model.pt", out_dir="out/"):
  if model is not None:
    if not os.path.exists(out_dir):
      os.mkdir(out_dir)
    torch.save({'model': model.state_dict()}, out_dir + filename)
  else:
    print("[ERROR]:Please build a model!!!")
 
 
import QuickModelBuilder as builder
 
if __name__ == '__main__':
  one = torch.FloatTensor([1]).cuda()
  mone = -1 * one
 
  is_print = True
  # 创建对象
  D = Discriminator()
  G = Generator()
  D.weight_init()
  G.weight_init()
 
  if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()
 
  lr = 2e-4
  d_optimizer = torch.optim.RMSprop(D.parameters(), lr=lr, )
  g_optimizer = torch.optim.RMSprop(G.parameters(), lr=lr, )
  d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.99)
  g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.99)
 
  fake_img = None
 
  # ##########################进入训练##判别器的判断过程#####################
  for epoch in range(num_epoch): # 进行多个epoch的训练
    pbar = builder.MyTqdm(epoch=epoch, maxval=len(dataloader))
    for i, img in enumerate(dataloader):
      num_img = img.size(0)
      real_img = img.cuda() # 将tensor变成Variable放入计算图中
      # 这里的优化器是D的优化器
      for param in D.parameters():
        param.requires_grad = True
      # ########判别器训练train#####################
      # 分为两部分:1、真的图像判别为真;2、假的图像判别为假
 
      # 计算真实图片的损失
      d_optimizer.zero_grad() # 在反向传播之前,先将梯度归0
      real_out = D(real_img) # 将真实图片放入判别器中
      d_loss_real = real_out.mean(0).view(1)
      d_loss_real.backward(one)
 
      # 计算生成图片的损失
      z = torch.randn(num_img, z_dimension).cuda() # 随机生成一些噪声
      z = z.reshape(num_img, z_dimension, 1, 1)
      fake_img = G(z).detach() # 随机噪声放入生成网络中,生成一张假的图片。 # 避免梯度传到G,因为G不用更新, detach分离
      fake_out = D(fake_img) # 判别器判断假的图片,
      d_loss_fake = fake_out.mean(0).view(1)
      d_loss_fake.backward(mone)
 
      d_loss = d_loss_fake - d_loss_real
      d_optimizer.step() # 更新参数
 
      # 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c=0.01
      for parm in D.parameters():
        parm.data.clamp_(-0.01, 0.01)
 
      # ==================训练生成器============================
      # ###############################生成网络的训练###############################
      for param in D.parameters():
        param.requires_grad = False
 
      # 这里的优化器是G的优化器,所以不需要冻结D的梯度,因为不是D的优化器,不会更新D
      g_optimizer.zero_grad() # 梯度归0
 
      z = torch.randn(num_img, z_dimension).cuda()
      z = z.reshape(num_img, z_dimension, 1, 1)
      fake_img = G(z) # 随机噪声输入到生成器中,得到一副假的图片
      output = D(fake_img) # 经过判别器得到的结果
      # g_loss = criterion(output, real_label) # 得到的假的图片与真实的图片的label的loss
      g_loss = torch.mean(output).view(1)
      # bp and optimize
      g_loss.backward(one) # 进行反向传播
      g_optimizer.step() # .step()一般用在反向传播后面,用于更新生成网络的参数
 
      # 打印中间的损失
      pbar.set_right_info(d_loss=d_loss.data.item(),
                g_loss=g_loss.data.item(),
                real_scores=real_out.data.mean().item(),
                fake_scores=fake_out.data.mean().item(),
                )
      pbar.update()
      try:
        fake_images = to_img(fake_img.cpu())
        save_image(fake_images, dir_path + '/fake_images-{}.png'.format(epoch + 1))
      except:
        pass
      if is_print:
        is_print = False
        real_images = to_img(real_img.cpu())
        save_image(real_images, dir_path + '/real_images.png')
    pbar.finish()
    d_scheduler.step()
    g_scheduler.step()
    save(D, "wgan_D.pt")
    save(G, "wgan_G.pt")

到此这篇关于Pytorch实现WGAN用于动漫头像生成的文章就介绍到这了,更多相关Pytorch实现WGAN用于动漫头像生成内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python中使用PyHook监听鼠标和键盘事件实例
Jul 18 Python
浅谈python 线程池threadpool之实现
Nov 17 Python
Python中property属性实例解析
Feb 10 Python
Python框架Flask的基本数据库操作方法分析
Jul 13 Python
Python编程深度学习绘图库之matplotlib
Dec 28 Python
OpenCV搞定腾讯滑块验证码的实现代码
May 18 Python
Python找出列表中出现次数最多的元素三种方式
Feb 24 Python
Python实现一个简单的毕业生信息管理系统的示例代码
Jun 08 Python
通过代码实例了解Python3编程技巧
Oct 13 Python
python3中编码获取网页的实例方法
Nov 16 Python
Python实现网络聊天室的示例代码(支持多人聊天与私聊)
Jan 27 Python
python模块与C和C++动态库相互调用实现过程示例
Nov 02 Python
基于PyInstaller各参数的含义说明
Mar 04 #Python
解决Pyinstaller打包软件失败的一个坑
Mar 04 #Python
selenium+python自动化78-autoit参数化与批量上传功能的实现
Mar 04 #Python
解决PDF 转图片时丢文字的一种可能方式
Mar 04 #Python
pandas数据分组groupby()和统计函数agg()的使用
Mar 04 #Python
pyx文件 生成pyd 文件用于 cython调用的实现
Mar 04 #Python
Python .py生成.pyd文件并打包.exe 的注意事项说明
Mar 04 #Python
You might like
2020年4月放送决定!第2期TV动画《邪神酱飞踢》视觉图&主题曲情报公开!
2020/03/06 日漫
php使用PDO方法详解
2014/12/27 PHP
Prototype1.6 JS 官方下载地址
2007/11/30 Javascript
不使用中间变量,交换int型的 a, b两个变量的值。
2010/10/29 Javascript
Jquery实现兼容各大浏览器的Enter回车切换输入焦点的方法
2014/09/01 Javascript
angular.element方法汇总
2015/01/07 Javascript
javascript实现仿腾讯游戏选择
2015/05/14 Javascript
jquery实现具有嵌套功能的选项卡
2016/02/12 Javascript
使用jQuery实现WordPress中的Ctrl+Enter和@评论回复
2016/05/21 Javascript
AngularJS API之copy深拷贝详解及实例
2016/09/14 Javascript
微信小程序 倒计时组件实现代码
2016/10/24 Javascript
js控制台输出的方法(详解)
2016/11/26 Javascript
微信小程序学习之数据处理详解
2017/07/05 Javascript
angularJS实现动态添加,删除div方法
2018/02/27 Javascript
vue cli3.0 引入eslint 结合vscode使用
2019/05/27 Javascript
javascript创建元素和删除元素实例小结
2019/06/19 Javascript
LayUI switch 开关监听 获取属性值、更改状态的方法
2019/09/21 Javascript
JQuery获得内容和属性方法解析
2020/05/30 jQuery
解决vue单页面 回退页面 keeplive 缓存问题
2020/07/22 Javascript
Python Pandas找到缺失值的位置方法
2018/04/12 Python
Python按钮的响应事件详解
2019/03/04 Python
HTML5制作酷炫音频播放器插件图文教程
2014/12/30 HTML / CSS
早晨薰衣草在线女性精品店:Morning Lavender
2021/01/04 全球购物
荷兰DOD药房中文官网:DeOnlineDrogist
2020/12/27 全球购物
下面关于"联合"的题目的输出是什么
2013/08/06 面试题
运动会广播稿50字
2014/01/26 职场文书
广告学专业自荐信范文
2014/02/24 职场文书
工厂搬迁方案
2014/05/11 职场文书
媒体宣传策划方案
2014/05/25 职场文书
派出所班子党的群众路线对照检查材料思想汇报
2014/10/01 职场文书
家庭财产分割协议范文
2014/11/24 职场文书
搬迁通知
2015/04/20 职场文书
2015最新婚礼主持词
2015/06/30 职场文书
公司劳动纪律管理制度
2015/08/04 职场文书
nginx实现发布静态资源的方法
2021/03/31 Servers
如何使用php生成zip压缩包
2021/04/21 PHP