Pytorch实现WGAN用于动漫头像生成


Posted in Python onMarch 04, 2021

WGAN与GAN的不同

  • 去除sigmoid
  • 使用具有动量的优化方法,比如使用RMSProp
  • 要对Discriminator的权重做修整限制以确保lipschitz连续约

WGAN实战卷积生成动漫头像 

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
from anime_face_generator.dataset import ImageDataset
 
batch_size = 32
num_epoch = 100
z_dimension = 100
dir_path = './wgan_img'
 
# 创建文件夹
if not os.path.exists(dir_path):
  os.mkdir(dir_path)
 
 
def to_img(x):
  """因为我们在生成器里面用了tanh"""
  out = 0.5 * (x + 1)
  return out
 
 
dataset = ImageDataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
 
 
class Generator(nn.Module):
  def __init__(self):
    super().__init__()
 
    self.gen = nn.Sequential(
      # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
      nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
      nn.BatchNorm2d(512),
      nn.ReLU(True),
      # 上一步的输出形状:(512) x 4 x 4
      nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
      nn.BatchNorm2d(256),
      nn.ReLU(True),
      # 上一步的输出形状: (256) x 8 x 8
      nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      # 上一步的输出形状: (256) x 16 x 16
      nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      # 上一步的输出形状:(256) x 32 x 32
      nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),
      nn.Tanh() # 输出范围 -1~1 故而采用Tanh
      # nn.Sigmoid()
      # 输出形状:3 x 96 x 96
    )
 
  def forward(self, x):
    x = self.gen(x)
    return x
 
  def weight_init(m):
    # weight_initialization: important for wgan
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
      m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
      m.weight.data.normal_(1.0, 0.02)
 
 
class Discriminator(nn.Module):
  def __init__(self):
    super().__init__()
    self.dis = nn.Sequential(
      nn.Conv2d(3, 64, 5, 3, 1, bias=False),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (64) x 32 x 32
 
      nn.Conv2d(64, 128, 4, 2, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (128) x 16 x 16
 
      nn.Conv2d(128, 256, 4, 2, 1, bias=False),
      nn.BatchNorm2d(256),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (256) x 8 x 8
 
      nn.Conv2d(256, 512, 4, 2, 1, bias=False),
      nn.BatchNorm2d(512),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (512) x 4 x 4
 
      nn.Conv2d(512, 1, 4, 1, 0, bias=False),
      nn.Flatten(),
      # nn.Sigmoid() # 输出一个数(概率)
    )
 
  def forward(self, x):
    x = self.dis(x)
    return x
 
  def weight_init(m):
    # weight_initialization: important for wgan
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
      m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
      m.weight.data.normal_(1.0, 0.02)
 
 
def save(model, filename="model.pt", out_dir="out/"):
  if model is not None:
    if not os.path.exists(out_dir):
      os.mkdir(out_dir)
    torch.save({'model': model.state_dict()}, out_dir + filename)
  else:
    print("[ERROR]:Please build a model!!!")
 
 
import QuickModelBuilder as builder
 
if __name__ == '__main__':
  one = torch.FloatTensor([1]).cuda()
  mone = -1 * one
 
  is_print = True
  # 创建对象
  D = Discriminator()
  G = Generator()
  D.weight_init()
  G.weight_init()
 
  if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()
 
  lr = 2e-4
  d_optimizer = torch.optim.RMSprop(D.parameters(), lr=lr, )
  g_optimizer = torch.optim.RMSprop(G.parameters(), lr=lr, )
  d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.99)
  g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.99)
 
  fake_img = None
 
  # ##########################进入训练##判别器的判断过程#####################
  for epoch in range(num_epoch): # 进行多个epoch的训练
    pbar = builder.MyTqdm(epoch=epoch, maxval=len(dataloader))
    for i, img in enumerate(dataloader):
      num_img = img.size(0)
      real_img = img.cuda() # 将tensor变成Variable放入计算图中
      # 这里的优化器是D的优化器
      for param in D.parameters():
        param.requires_grad = True
      # ########判别器训练train#####################
      # 分为两部分:1、真的图像判别为真;2、假的图像判别为假
 
      # 计算真实图片的损失
      d_optimizer.zero_grad() # 在反向传播之前,先将梯度归0
      real_out = D(real_img) # 将真实图片放入判别器中
      d_loss_real = real_out.mean(0).view(1)
      d_loss_real.backward(one)
 
      # 计算生成图片的损失
      z = torch.randn(num_img, z_dimension).cuda() # 随机生成一些噪声
      z = z.reshape(num_img, z_dimension, 1, 1)
      fake_img = G(z).detach() # 随机噪声放入生成网络中,生成一张假的图片。 # 避免梯度传到G,因为G不用更新, detach分离
      fake_out = D(fake_img) # 判别器判断假的图片,
      d_loss_fake = fake_out.mean(0).view(1)
      d_loss_fake.backward(mone)
 
      d_loss = d_loss_fake - d_loss_real
      d_optimizer.step() # 更新参数
 
      # 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c=0.01
      for parm in D.parameters():
        parm.data.clamp_(-0.01, 0.01)
 
      # ==================训练生成器============================
      # ###############################生成网络的训练###############################
      for param in D.parameters():
        param.requires_grad = False
 
      # 这里的优化器是G的优化器,所以不需要冻结D的梯度,因为不是D的优化器,不会更新D
      g_optimizer.zero_grad() # 梯度归0
 
      z = torch.randn(num_img, z_dimension).cuda()
      z = z.reshape(num_img, z_dimension, 1, 1)
      fake_img = G(z) # 随机噪声输入到生成器中,得到一副假的图片
      output = D(fake_img) # 经过判别器得到的结果
      # g_loss = criterion(output, real_label) # 得到的假的图片与真实的图片的label的loss
      g_loss = torch.mean(output).view(1)
      # bp and optimize
      g_loss.backward(one) # 进行反向传播
      g_optimizer.step() # .step()一般用在反向传播后面,用于更新生成网络的参数
 
      # 打印中间的损失
      pbar.set_right_info(d_loss=d_loss.data.item(),
                g_loss=g_loss.data.item(),
                real_scores=real_out.data.mean().item(),
                fake_scores=fake_out.data.mean().item(),
                )
      pbar.update()
      try:
        fake_images = to_img(fake_img.cpu())
        save_image(fake_images, dir_path + '/fake_images-{}.png'.format(epoch + 1))
      except:
        pass
      if is_print:
        is_print = False
        real_images = to_img(real_img.cpu())
        save_image(real_images, dir_path + '/real_images.png')
    pbar.finish()
    d_scheduler.step()
    g_scheduler.step()
    save(D, "wgan_D.pt")
    save(G, "wgan_G.pt")

到此这篇关于Pytorch实现WGAN用于动漫头像生成的文章就介绍到这了,更多相关Pytorch实现WGAN用于动漫头像生成内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python简单文本处理的方法
Jul 10 Python
python 读写、创建 文件的方法(必看)
Sep 12 Python
Python基于matplotlib绘制栈式直方图的方法示例
Aug 09 Python
Python使用pyodbc访问数据库操作方法详解
Jul 05 Python
对PyQt5中树结构的实现方法详解
Jun 17 Python
django的auth认证,authenticate和装饰器功能详解
Jul 25 Python
Python Request爬取seo.chinaz.com百度权重网站的查询结果过程解析
Aug 13 Python
150行python代码实现贪吃蛇游戏
Apr 24 Python
Django自定义YamlField实现过程解析
Nov 11 Python
Python如何使用logging为Flask增加logid
Mar 30 Python
一行代码python实现文件共享服务器
Apr 22 Python
Python实现为PDF去除水印的示例代码
Apr 03 Python
基于PyInstaller各参数的含义说明
Mar 04 #Python
解决Pyinstaller打包软件失败的一个坑
Mar 04 #Python
selenium+python自动化78-autoit参数化与批量上传功能的实现
Mar 04 #Python
解决PDF 转图片时丢文字的一种可能方式
Mar 04 #Python
pandas数据分组groupby()和统计函数agg()的使用
Mar 04 #Python
pyx文件 生成pyd 文件用于 cython调用的实现
Mar 04 #Python
Python .py生成.pyd文件并打包.exe 的注意事项说明
Mar 04 #Python
You might like
用来解析.htpasswd文件的PHP类
2012/09/05 PHP
Zend Framework处理Json数据方法详解
2016/12/09 PHP
php ajax confirm 删除实例详解
2019/03/06 PHP
拖动一个HTML元素
2006/12/22 Javascript
不间断滚动JS打包类,基本可以实现所有的滚动效果,太强了
2007/12/08 Javascript
jQuery 常见操作实现方式和常用函数方法总结
2011/05/06 Javascript
使用jQuery.fn自定义jQuery翻页插件
2013/01/20 Javascript
js Array操作的最简短最容易理解方法
2013/12/09 Javascript
原生JavaScript实现Ajax的方法
2016/04/07 Javascript
Javascript之Number对象介绍
2016/06/07 Javascript
JavaScript拖动层Div代码
2017/03/01 Javascript
jQuery插件echarts实现的单折线图效果示例【附demo源码下载】
2017/03/04 Javascript
从0到1搭建Element的后台框架的方法步骤
2019/04/10 Javascript
vue+moment实现倒计时效果
2019/08/26 Javascript
Python生成数字图片代码分享
2017/10/31 Python
python实现一个简单的并查集的示例代码
2018/03/19 Python
Python3.7中安装openCV库的方法
2018/07/11 Python
python事件驱动event实现详解
2018/11/21 Python
Python3实现从排序数组中删除重复项算法分析
2019/04/03 Python
使用Django和Postgres进行全文搜索的实例代码
2020/02/13 Python
Keras SGD 随机梯度下降优化器参数设置方式
2020/06/19 Python
Python通过类的组合模拟街道红绿灯
2020/09/16 Python
深深扎根运动世界的生活品牌:Tillys
2017/10/30 全球购物
来自美国主售篮球鞋的零售商店:KICKSUSA
2017/11/28 全球购物
澳大利亚女士时装在线:Rockmans
2018/09/26 全球购物
Notino意大利:购买香水和化妆品
2018/11/14 全球购物
现在输入n个数字,以逗号,分开;然后可选择升或者降序排序;按提交键就在另一页面显示按什么排序,结果为,提供reset
2012/11/09 面试题
早读迟到检讨书
2014/01/24 职场文书
优秀导游先进事迹材料
2014/01/25 职场文书
国际商贸专业自荐信
2014/06/09 职场文书
党员志愿者活动方案
2014/08/28 职场文书
会议室管理制度范本
2015/08/06 职场文书
小学班级口号大全
2015/12/25 职场文书
《揠苗助长》教学反思
2016/02/20 职场文书
《悬崖边的树》读后感2篇
2019/12/02 职场文书
python 实现定时任务的四种方式
2021/04/01 Python