Pytorch实现WGAN用于动漫头像生成


Posted in Python onMarch 04, 2021

WGAN与GAN的不同

  • 去除sigmoid
  • 使用具有动量的优化方法,比如使用RMSProp
  • 要对Discriminator的权重做修整限制以确保lipschitz连续约

WGAN实战卷积生成动漫头像 

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
from anime_face_generator.dataset import ImageDataset
 
batch_size = 32
num_epoch = 100
z_dimension = 100
dir_path = './wgan_img'
 
# 创建文件夹
if not os.path.exists(dir_path):
  os.mkdir(dir_path)
 
 
def to_img(x):
  """因为我们在生成器里面用了tanh"""
  out = 0.5 * (x + 1)
  return out
 
 
dataset = ImageDataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
 
 
class Generator(nn.Module):
  def __init__(self):
    super().__init__()
 
    self.gen = nn.Sequential(
      # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
      nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
      nn.BatchNorm2d(512),
      nn.ReLU(True),
      # 上一步的输出形状:(512) x 4 x 4
      nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
      nn.BatchNorm2d(256),
      nn.ReLU(True),
      # 上一步的输出形状: (256) x 8 x 8
      nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      # 上一步的输出形状: (256) x 16 x 16
      nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      # 上一步的输出形状:(256) x 32 x 32
      nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),
      nn.Tanh() # 输出范围 -1~1 故而采用Tanh
      # nn.Sigmoid()
      # 输出形状:3 x 96 x 96
    )
 
  def forward(self, x):
    x = self.gen(x)
    return x
 
  def weight_init(m):
    # weight_initialization: important for wgan
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
      m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
      m.weight.data.normal_(1.0, 0.02)
 
 
class Discriminator(nn.Module):
  def __init__(self):
    super().__init__()
    self.dis = nn.Sequential(
      nn.Conv2d(3, 64, 5, 3, 1, bias=False),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (64) x 32 x 32
 
      nn.Conv2d(64, 128, 4, 2, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (128) x 16 x 16
 
      nn.Conv2d(128, 256, 4, 2, 1, bias=False),
      nn.BatchNorm2d(256),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (256) x 8 x 8
 
      nn.Conv2d(256, 512, 4, 2, 1, bias=False),
      nn.BatchNorm2d(512),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (512) x 4 x 4
 
      nn.Conv2d(512, 1, 4, 1, 0, bias=False),
      nn.Flatten(),
      # nn.Sigmoid() # 输出一个数(概率)
    )
 
  def forward(self, x):
    x = self.dis(x)
    return x
 
  def weight_init(m):
    # weight_initialization: important for wgan
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
      m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
      m.weight.data.normal_(1.0, 0.02)
 
 
def save(model, filename="model.pt", out_dir="out/"):
  if model is not None:
    if not os.path.exists(out_dir):
      os.mkdir(out_dir)
    torch.save({'model': model.state_dict()}, out_dir + filename)
  else:
    print("[ERROR]:Please build a model!!!")
 
 
import QuickModelBuilder as builder
 
if __name__ == '__main__':
  one = torch.FloatTensor([1]).cuda()
  mone = -1 * one
 
  is_print = True
  # 创建对象
  D = Discriminator()
  G = Generator()
  D.weight_init()
  G.weight_init()
 
  if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()
 
  lr = 2e-4
  d_optimizer = torch.optim.RMSprop(D.parameters(), lr=lr, )
  g_optimizer = torch.optim.RMSprop(G.parameters(), lr=lr, )
  d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.99)
  g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.99)
 
  fake_img = None
 
  # ##########################进入训练##判别器的判断过程#####################
  for epoch in range(num_epoch): # 进行多个epoch的训练
    pbar = builder.MyTqdm(epoch=epoch, maxval=len(dataloader))
    for i, img in enumerate(dataloader):
      num_img = img.size(0)
      real_img = img.cuda() # 将tensor变成Variable放入计算图中
      # 这里的优化器是D的优化器
      for param in D.parameters():
        param.requires_grad = True
      # ########判别器训练train#####################
      # 分为两部分:1、真的图像判别为真;2、假的图像判别为假
 
      # 计算真实图片的损失
      d_optimizer.zero_grad() # 在反向传播之前,先将梯度归0
      real_out = D(real_img) # 将真实图片放入判别器中
      d_loss_real = real_out.mean(0).view(1)
      d_loss_real.backward(one)
 
      # 计算生成图片的损失
      z = torch.randn(num_img, z_dimension).cuda() # 随机生成一些噪声
      z = z.reshape(num_img, z_dimension, 1, 1)
      fake_img = G(z).detach() # 随机噪声放入生成网络中,生成一张假的图片。 # 避免梯度传到G,因为G不用更新, detach分离
      fake_out = D(fake_img) # 判别器判断假的图片,
      d_loss_fake = fake_out.mean(0).view(1)
      d_loss_fake.backward(mone)
 
      d_loss = d_loss_fake - d_loss_real
      d_optimizer.step() # 更新参数
 
      # 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c=0.01
      for parm in D.parameters():
        parm.data.clamp_(-0.01, 0.01)
 
      # ==================训练生成器============================
      # ###############################生成网络的训练###############################
      for param in D.parameters():
        param.requires_grad = False
 
      # 这里的优化器是G的优化器,所以不需要冻结D的梯度,因为不是D的优化器,不会更新D
      g_optimizer.zero_grad() # 梯度归0
 
      z = torch.randn(num_img, z_dimension).cuda()
      z = z.reshape(num_img, z_dimension, 1, 1)
      fake_img = G(z) # 随机噪声输入到生成器中,得到一副假的图片
      output = D(fake_img) # 经过判别器得到的结果
      # g_loss = criterion(output, real_label) # 得到的假的图片与真实的图片的label的loss
      g_loss = torch.mean(output).view(1)
      # bp and optimize
      g_loss.backward(one) # 进行反向传播
      g_optimizer.step() # .step()一般用在反向传播后面,用于更新生成网络的参数
 
      # 打印中间的损失
      pbar.set_right_info(d_loss=d_loss.data.item(),
                g_loss=g_loss.data.item(),
                real_scores=real_out.data.mean().item(),
                fake_scores=fake_out.data.mean().item(),
                )
      pbar.update()
      try:
        fake_images = to_img(fake_img.cpu())
        save_image(fake_images, dir_path + '/fake_images-{}.png'.format(epoch + 1))
      except:
        pass
      if is_print:
        is_print = False
        real_images = to_img(real_img.cpu())
        save_image(real_images, dir_path + '/real_images.png')
    pbar.finish()
    d_scheduler.step()
    g_scheduler.step()
    save(D, "wgan_D.pt")
    save(G, "wgan_G.pt")

到此这篇关于Pytorch实现WGAN用于动漫头像生成的文章就介绍到这了,更多相关Pytorch实现WGAN用于动漫头像生成内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python读取环境变量的方法和自定义类分享
Nov 22 Python
Python实现获取某天是某个月中的第几周
Feb 11 Python
解读Django框架中的低层次缓存API
Jul 24 Python
Python使用当前时间、随机数产生一个唯一数字的方法
Sep 18 Python
Python装饰器知识点补充
May 28 Python
Python Socket编程之多线程聊天室
Jul 28 Python
python调用其他文件函数或类的示例
Jul 16 Python
python tkinter图形界面代码统计工具
Sep 18 Python
python实现大学人员管理系统
Oct 25 Python
Python中无限循环需要什么条件
May 27 Python
python中编写函数并调用的知识点总结
Jan 13 Python
使用python对excel表格处理的一些小功能
Jan 25 Python
基于PyInstaller各参数的含义说明
Mar 04 #Python
解决Pyinstaller打包软件失败的一个坑
Mar 04 #Python
selenium+python自动化78-autoit参数化与批量上传功能的实现
Mar 04 #Python
解决PDF 转图片时丢文字的一种可能方式
Mar 04 #Python
pandas数据分组groupby()和统计函数agg()的使用
Mar 04 #Python
pyx文件 生成pyd 文件用于 cython调用的实现
Mar 04 #Python
Python .py生成.pyd文件并打包.exe 的注意事项说明
Mar 04 #Python
You might like
PHP SQLite类
2009/05/07 PHP
解析php开发中的中文编码问题
2013/08/08 PHP
php实现的验证码文件类实例
2015/06/18 PHP
function foo的原型与prototype属性解惑
2010/11/19 Javascript
20个最新的jQuery插件
2012/01/13 Javascript
JQuery.Ajax之错误调试帮助信息介绍
2013/07/04 Javascript
浏览器图片选择预览、旋转、批量上传的JS代码实现
2013/12/04 Javascript
jquery实现键盘左右翻页特效
2015/04/30 Javascript
详解JavaScript中setSeconds()方法的使用
2015/06/11 Javascript
js实现网页收藏功能
2015/12/17 Javascript
jQuery+ajax简单实现文件上传的方法
2016/06/03 Javascript
JS调用某段SQL语句的方法
2016/10/20 Javascript
jQuery实现页面倒计时并刷新效果
2017/03/13 Javascript
vue watch监听对象及对应值的变化详解
2018/02/24 Javascript
jQuery实现合并表格单元格中相同行操作示例
2019/01/28 jQuery
详解原生JS动态添加和删除类
2019/03/26 Javascript
Vue动态加载图片在跨域时无法显示的问题及解决方法
2020/03/10 Javascript
javascript实现前端成语点击验证
2020/06/24 Javascript
js实现简单音乐播放器
2020/06/30 Javascript
python统计文本文件内单词数量的方法
2015/05/30 Python
Python实现计算两个时间之间相差天数的方法
2017/05/10 Python
Python设计模式之命令模式简单示例
2018/01/10 Python
Python+OpenCV图片局部区域像素值处理详解
2019/01/23 Python
pycharm 更改创建文件默认路径的操作
2020/02/15 Python
Django模板标签{% for %}循环,获取制定条数据实例
2020/05/14 Python
CSS3实战第一波 让我们尽情的圆角吧
2010/08/27 HTML / CSS
英国领先的在线鱼贩:The Fish Society
2020/08/12 全球购物
教师绩效工资方案
2014/02/01 职场文书
中学生学雷锋活动心得体会
2014/03/10 职场文书
乡镇消防安全责任书
2014/07/23 职场文书
党的群众路线对照检查材料范文
2014/09/24 职场文书
刑事辩护授权委托书范本
2014/10/17 职场文书
党员学习新党章思想汇报
2014/10/25 职场文书
2014年档案管理工作总结
2014/11/17 职场文书
捐助倡议书
2015/01/19 职场文书
Python实现学生管理系统并生成exe可执行文件详解流程
2022/01/22 Python