Pytorch实现WGAN用于动漫头像生成


Posted in Python onMarch 04, 2021

WGAN与GAN的不同

  • 去除sigmoid
  • 使用具有动量的优化方法,比如使用RMSProp
  • 要对Discriminator的权重做修整限制以确保lipschitz连续约

WGAN实战卷积生成动漫头像 

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
from anime_face_generator.dataset import ImageDataset
 
batch_size = 32
num_epoch = 100
z_dimension = 100
dir_path = './wgan_img'
 
# 创建文件夹
if not os.path.exists(dir_path):
  os.mkdir(dir_path)
 
 
def to_img(x):
  """因为我们在生成器里面用了tanh"""
  out = 0.5 * (x + 1)
  return out
 
 
dataset = ImageDataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
 
 
class Generator(nn.Module):
  def __init__(self):
    super().__init__()
 
    self.gen = nn.Sequential(
      # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
      nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
      nn.BatchNorm2d(512),
      nn.ReLU(True),
      # 上一步的输出形状:(512) x 4 x 4
      nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
      nn.BatchNorm2d(256),
      nn.ReLU(True),
      # 上一步的输出形状: (256) x 8 x 8
      nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      # 上一步的输出形状: (256) x 16 x 16
      nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      # 上一步的输出形状:(256) x 32 x 32
      nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),
      nn.Tanh() # 输出范围 -1~1 故而采用Tanh
      # nn.Sigmoid()
      # 输出形状:3 x 96 x 96
    )
 
  def forward(self, x):
    x = self.gen(x)
    return x
 
  def weight_init(m):
    # weight_initialization: important for wgan
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
      m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
      m.weight.data.normal_(1.0, 0.02)
 
 
class Discriminator(nn.Module):
  def __init__(self):
    super().__init__()
    self.dis = nn.Sequential(
      nn.Conv2d(3, 64, 5, 3, 1, bias=False),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (64) x 32 x 32
 
      nn.Conv2d(64, 128, 4, 2, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (128) x 16 x 16
 
      nn.Conv2d(128, 256, 4, 2, 1, bias=False),
      nn.BatchNorm2d(256),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (256) x 8 x 8
 
      nn.Conv2d(256, 512, 4, 2, 1, bias=False),
      nn.BatchNorm2d(512),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (512) x 4 x 4
 
      nn.Conv2d(512, 1, 4, 1, 0, bias=False),
      nn.Flatten(),
      # nn.Sigmoid() # 输出一个数(概率)
    )
 
  def forward(self, x):
    x = self.dis(x)
    return x
 
  def weight_init(m):
    # weight_initialization: important for wgan
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
      m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
      m.weight.data.normal_(1.0, 0.02)
 
 
def save(model, filename="model.pt", out_dir="out/"):
  if model is not None:
    if not os.path.exists(out_dir):
      os.mkdir(out_dir)
    torch.save({'model': model.state_dict()}, out_dir + filename)
  else:
    print("[ERROR]:Please build a model!!!")
 
 
import QuickModelBuilder as builder
 
if __name__ == '__main__':
  one = torch.FloatTensor([1]).cuda()
  mone = -1 * one
 
  is_print = True
  # 创建对象
  D = Discriminator()
  G = Generator()
  D.weight_init()
  G.weight_init()
 
  if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()
 
  lr = 2e-4
  d_optimizer = torch.optim.RMSprop(D.parameters(), lr=lr, )
  g_optimizer = torch.optim.RMSprop(G.parameters(), lr=lr, )
  d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.99)
  g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.99)
 
  fake_img = None
 
  # ##########################进入训练##判别器的判断过程#####################
  for epoch in range(num_epoch): # 进行多个epoch的训练
    pbar = builder.MyTqdm(epoch=epoch, maxval=len(dataloader))
    for i, img in enumerate(dataloader):
      num_img = img.size(0)
      real_img = img.cuda() # 将tensor变成Variable放入计算图中
      # 这里的优化器是D的优化器
      for param in D.parameters():
        param.requires_grad = True
      # ########判别器训练train#####################
      # 分为两部分:1、真的图像判别为真;2、假的图像判别为假
 
      # 计算真实图片的损失
      d_optimizer.zero_grad() # 在反向传播之前,先将梯度归0
      real_out = D(real_img) # 将真实图片放入判别器中
      d_loss_real = real_out.mean(0).view(1)
      d_loss_real.backward(one)
 
      # 计算生成图片的损失
      z = torch.randn(num_img, z_dimension).cuda() # 随机生成一些噪声
      z = z.reshape(num_img, z_dimension, 1, 1)
      fake_img = G(z).detach() # 随机噪声放入生成网络中,生成一张假的图片。 # 避免梯度传到G,因为G不用更新, detach分离
      fake_out = D(fake_img) # 判别器判断假的图片,
      d_loss_fake = fake_out.mean(0).view(1)
      d_loss_fake.backward(mone)
 
      d_loss = d_loss_fake - d_loss_real
      d_optimizer.step() # 更新参数
 
      # 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c=0.01
      for parm in D.parameters():
        parm.data.clamp_(-0.01, 0.01)
 
      # ==================训练生成器============================
      # ###############################生成网络的训练###############################
      for param in D.parameters():
        param.requires_grad = False
 
      # 这里的优化器是G的优化器,所以不需要冻结D的梯度,因为不是D的优化器,不会更新D
      g_optimizer.zero_grad() # 梯度归0
 
      z = torch.randn(num_img, z_dimension).cuda()
      z = z.reshape(num_img, z_dimension, 1, 1)
      fake_img = G(z) # 随机噪声输入到生成器中,得到一副假的图片
      output = D(fake_img) # 经过判别器得到的结果
      # g_loss = criterion(output, real_label) # 得到的假的图片与真实的图片的label的loss
      g_loss = torch.mean(output).view(1)
      # bp and optimize
      g_loss.backward(one) # 进行反向传播
      g_optimizer.step() # .step()一般用在反向传播后面,用于更新生成网络的参数
 
      # 打印中间的损失
      pbar.set_right_info(d_loss=d_loss.data.item(),
                g_loss=g_loss.data.item(),
                real_scores=real_out.data.mean().item(),
                fake_scores=fake_out.data.mean().item(),
                )
      pbar.update()
      try:
        fake_images = to_img(fake_img.cpu())
        save_image(fake_images, dir_path + '/fake_images-{}.png'.format(epoch + 1))
      except:
        pass
      if is_print:
        is_print = False
        real_images = to_img(real_img.cpu())
        save_image(real_images, dir_path + '/real_images.png')
    pbar.finish()
    d_scheduler.step()
    g_scheduler.step()
    save(D, "wgan_D.pt")
    save(G, "wgan_G.pt")

到此这篇关于Pytorch实现WGAN用于动漫头像生成的文章就介绍到这了,更多相关Pytorch实现WGAN用于动漫头像生成内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
深入Python解释器理解Python中的字节码
Apr 01 Python
Python解析json文件相关知识学习
Mar 01 Python
Python中asyncio模块的深入讲解
Jun 10 Python
selenium 安装与chromedriver安装的方法步骤
Jun 12 Python
python判断所输入的任意一个正整数是否为素数的两种方法
Jun 27 Python
如何用Python做一个微信机器人自动拉群
Jul 03 Python
Python替换NumPy数组中大于某个值的所有元素实例
Jun 08 Python
python能在浏览器能运行吗
Jun 17 Python
Python3基于print打印带颜色字符串
Jul 06 Python
python cookie反爬处理的实现
Nov 01 Python
使用Python制作一盏 3D 花灯喜迎元宵佳节
Feb 26 Python
TensorFlow中tf.batch_matmul()的用法
Jun 02 Python
基于PyInstaller各参数的含义说明
Mar 04 #Python
解决Pyinstaller打包软件失败的一个坑
Mar 04 #Python
selenium+python自动化78-autoit参数化与批量上传功能的实现
Mar 04 #Python
解决PDF 转图片时丢文字的一种可能方式
Mar 04 #Python
pandas数据分组groupby()和统计函数agg()的使用
Mar 04 #Python
pyx文件 生成pyd 文件用于 cython调用的实现
Mar 04 #Python
Python .py生成.pyd文件并打包.exe 的注意事项说明
Mar 04 #Python
You might like
php将csv文件导入到mysql数据库的方法
2014/12/24 PHP
使用php的HTTP请求的库Requests实现美女图片墙
2015/02/22 PHP
php 如何设置一个严格控制过期时间的session
2017/05/05 PHP
PHP编程文件处理类SplFileObject和SplFileInfo用法实例分析
2017/07/22 PHP
ThinkPHP 3.2.3实现加减乘除图片验证码
2018/12/05 PHP
PHP使用 Imagick 扩展实现图片合成,圆角处理功能示例
2019/09/09 PHP
TNC vs IO BO3 第一场2.13
2021/03/10 DOTA
javascript OFFICE控件测试代码
2009/12/08 Javascript
jquery中dom操作和事件的实例学习 仿yahoo邮箱登录框的提示效果
2011/11/30 Javascript
js添加table的行和列 具体实现方法
2013/07/22 Javascript
js window.open弹出新的网页窗口
2014/01/16 Javascript
关于onchange事件在IE和FF下的表现及解决方法
2014/03/08 Javascript
jQuery自动完成插件completer附源码下载
2016/01/04 Javascript
vuejs指令详解
2017/02/07 Javascript
Nodejs中使用phantom将html转为pdf或图片格式的方法
2017/09/18 NodeJs
Node.js如何优雅的封装一个实用函数的npm包的方法
2019/04/29 Javascript
Vue 2.0 侦听器 watch属性代码详解
2019/06/19 Javascript
判断JavaScript中的两个变量是否相等的操作符
2019/12/21 Javascript
详解为什么Vue中不要用index作为key(diff算法)
2020/04/04 Javascript
JavaScript 判断浏览器是否是IE
2021/02/19 Javascript
Python竟能画这么漂亮的花,帅呆了(代码分享)
2017/11/15 Python
TensorFlow平台下Python实现神经网络
2018/03/10 Python
selenium+python自动化测试之页面元素定位
2019/01/23 Python
用django-allauth实现第三方登录的示例代码
2019/06/24 Python
春节到了 教你使用python来抢票回家
2020/01/06 Python
python数据分析:关键字提取方式
2020/02/24 Python
什么是python的函数体
2020/06/19 Python
Python创建临时文件和文件夹
2020/08/05 Python
俄罗斯金苹果网上化妆品和香水商店:Goldapple
2019/12/01 全球购物
在SQL Server中创建数据库主要有那种方式
2013/09/10 面试题
经贸日语专业自荐信
2014/09/02 职场文书
2014迎接教师节演讲稿
2014/09/10 职场文书
一年级语文上册复习计划
2015/01/17 职场文书
婚礼答谢礼品
2015/01/20 职场文书
SQL Server使用导出向导功能
2022/04/08 SQL Server
Python sklearn分类决策树方法详解
2022/09/23 Python