pytorch GAN伪造手写体mnist数据集方式


Posted in Python onJanuary 10, 2020

一,mnist数据集

pytorch GAN伪造手写体mnist数据集方式

形如上图的数字手写体就是mnist数据集。

二,GAN原理(生成对抗网络)

GAN网络一共由两部分组成:一个是伪造器(Generator,简称G),一个是判别器(Discrimniator,简称D)

一开始,G由服从某几个分布(如高斯分布)的噪音组成,生成的图片不断送给D判断是否正确,直到G生成的图片连D都判断以为是真的。D每一轮除了看过G生成的假图片以外,还要见数据集中的真图片,以前者和后者得到的损失函数值为依据更新D网络中的权值。因此G和D都在不停地更新权值。以下图为例:

pytorch GAN伪造手写体mnist数据集方式

在v1时的G只不过是 一堆噪声,见过数据集(real images)的D肯定能判断出G所生成的是假的。当然G也能知道D判断它是假的这个结果,因此G就会更新权值,到v2的时候,G就能生成更逼真的图片来让D判断,当然在v2时D也是会先看一次真图片,再去判断G所生成的图片。以此类推,不断循环就是GAN的思想。

三,训练代码

import argparse
import os
import numpy as np
import math
 
import torchvision.transforms as transforms
from torchvision.utils import save_image
 
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
 
import torch.nn as nn
import torch.nn.functional as F
import torch
 
os.makedirs("images", exist_ok=True)
 
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)
 
img_shape = (opt.channels, opt.img_size, opt.img_size) # 确定图片输入的格式为(1,28,28),由于mnist数据集是灰度图所以通道为1
cuda = True if torch.cuda.is_available() else False
 
 
class Generator(nn.Module):
 def __init__(self):
  super(Generator, self).__init__()
 
  def block(in_feat, out_feat, normalize=True):
   layers = [nn.Linear(in_feat, out_feat)]
   if normalize:
    layers.append(nn.BatchNorm1d(out_feat, 0.8))
   layers.append(nn.LeakyReLU(0.2, inplace=True))
   return layers
 
  self.model = nn.Sequential(
   *block(opt.latent_dim, 128, normalize=False),
   *block(128, 256),
   *block(256, 512),
   *block(512, 1024),
   nn.Linear(1024, int(np.prod(img_shape))),
   nn.Tanh()
  )
 
 def forward(self, z):
  img = self.model(z)
  img = img.view(img.size(0), *img_shape)
  return img
 
 
class Discriminator(nn.Module):
 def __init__(self):
  super(Discriminator, self).__init__()
 
  self.model = nn.Sequential(
   nn.Linear(int(np.prod(img_shape)), 512),
   nn.LeakyReLU(0.2, inplace=True),
   nn.Linear(512, 256),
   nn.LeakyReLU(0.2, inplace=True),
   nn.Linear(256, 1),
   nn.Sigmoid(),
  )
 
 def forward(self, img):
  img_flat = img.view(img.size(0), -1)
  validity = self.model(img_flat)
  return validity
 
 
# Loss function
adversarial_loss = torch.nn.BCELoss()
 
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
 
if cuda:
 generator.cuda()
 discriminator.cuda()
 adversarial_loss.cuda()
 
# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
 datasets.MNIST(
  "../../data/mnist",
  train=True,
  download=True,
  transform=transforms.Compose(
   [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
  ),
 ),
 batch_size=opt.batch_size,
 shuffle=True,
)
 
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
 
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
 
# ----------
# Training
# ----------
if __name__ == '__main__':
 for epoch in range(opt.n_epochs):
  for i, (imgs, _) in enumerate(dataloader):
   # print(imgs.shape)
   # Adversarial ground truths
   valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) # 全1
   fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) # 全0
   # Configure input
   real_imgs = Variable(imgs.type(Tensor))
 
   # -----------------
   # Train Generator
   # -----------------
 
   optimizer_G.zero_grad() # 清空G网络 上一个batch的梯度
 
   # Sample noise as generator input
   z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) # 生成的噪音,均值为0方差为1维度为(64,100)的噪音
   # Generate a batch of images
   gen_imgs = generator(z)
   # Loss measures generator's ability to fool the discriminator
   g_loss = adversarial_loss(discriminator(gen_imgs), valid)
 
   g_loss.backward() # g_loss用于更新G网络的权值,g_loss于D网络的判断结果 有关
   optimizer_G.step()
 
   # ---------------------
   # Train Discriminator
   # ---------------------
 
   optimizer_D.zero_grad() # 清空D网络 上一个batch的梯度
   # Measure discriminator's ability to classify real from generated samples
   real_loss = adversarial_loss(discriminator(real_imgs), valid)
   fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
   d_loss = (real_loss + fake_loss) / 2
 
   d_loss.backward() # d_loss用于更新D网络的权值
   optimizer_D.step()
 
   print(
    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
    % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
   )
 
   batches_done = epoch * len(dataloader) + i
   if batches_done % opt.sample_interval == 0:
    save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True) # 保存一个batchsize中的25张
   if (epoch+1) %2 ==0:
    print('save..')
    torch.save(generator,'g%d.pth' % epoch)
    torch.save(discriminator,'d%d.pth' % epoch)

运行结果:

一开始时,G生成的全是杂音:

pytorch GAN伪造手写体mnist数据集方式

然后逐渐呈现数字的雏形:

pytorch GAN伪造手写体mnist数据集方式

最后一次生成的结果:

pytorch GAN伪造手写体mnist数据集方式

四,测试代码:

导入最后保存生成器的模型:

from gan import Generator,Discriminator
import torch
import matplotlib.pyplot as plt
from torch.autograd import Variable
import numpy as np
from torchvision.utils import save_image
 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Tensor = torch.cuda.FloatTensor
g = torch.load('g199.pth') #导入生成器Generator模型
#d = torch.load('d.pth')
g = g.to(device)
#d = d.to(device)
 
z = Variable(Tensor(np.random.normal(0, 1, (64, 100)))) #输入的噪音
gen_imgs =g(z) #生产图片
save_image(gen_imgs.data[:25], "images.png" , nrow=5, normalize=True)

生成结果:

pytorch GAN伪造手写体mnist数据集方式

以上这篇pytorch GAN伪造手写体mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之正规地说一句话
Sep 28 Python
零基础写python爬虫之爬虫框架Scrapy安装配置
Nov 06 Python
分享Python文本生成二维码实例
Jan 06 Python
使用python中的in ,not in来检查元素是不是在列表中的方法
Jul 06 Python
Python封装原理与实现方法详解
Aug 28 Python
Python3.5多进程原理与用法实例分析
Apr 05 Python
python  logging日志打印过程解析
Oct 22 Python
Python GUI自动化实现绕过验证码登录
Jan 10 Python
基于python监控程序是否关闭
Jan 14 Python
python3+opencv生成不规则黑白mask实例
Feb 19 Python
使用OpenCV实现人脸图像卡通化的示例代码
Jan 15 Python
一篇文章带你搞懂Python类的相关知识
May 20 Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
You might like
PHP查询MySQL大量数据的时候内存占用分析
2011/07/22 PHP
php class中self,parent,this的区别以及实例介绍
2013/04/24 PHP
判断php数组是否为索引数组的实现方法
2013/06/13 PHP
php文件夹的创建与删除方法
2015/01/24 PHP
PHP实现多文件上传的方法
2015/07/08 PHP
php实现将Session写入数据库
2015/07/26 PHP
PHP通过CURL实现定时任务的图片抓取功能示例
2016/10/03 PHP
PHP实现的简单操作SQLite数据库类与用法示例
2017/06/19 PHP
Laravel中encrypt和decrypt的实现方法
2017/09/24 PHP
PHP实现获取文件mime类型多种方法解析
2020/05/28 PHP
解决jquery .ajax 在IE下卡死问题的解决方法
2009/10/26 Javascript
Jquery+CSS3实现一款简洁大气带滑动效果的弹出层
2013/05/15 Javascript
extjs两个tbar问题探讨
2013/08/08 Javascript
分享使用AngularJS创建应用的5个框架
2015/12/05 Javascript
JavaScript动态创建div等元素实例讲解
2016/01/06 Javascript
简单的jQuery banner图片轮播实例代码
2016/03/04 Javascript
详谈Node.js之操作文件系统
2017/08/29 Javascript
typescript nodejs 依赖注入实现方法代码详解
2019/07/21 NodeJs
[01:00:30]完美世界DOTA2联赛循环赛 Inki vs Matador BO2第二场 10.31
2020/11/02 DOTA
python实现socket客户端和服务端简单示例
2014/02/24 Python
Python实现新浪博客备份的方法
2016/04/27 Python
python使用Flask操作mysql实现登录功能
2018/05/14 Python
对python while循环和双重循环的实例详解
2019/08/23 Python
python 表格打印代码实例解析
2019/10/12 Python
tensorflow dataset.shuffle、dataset.batch、dataset.repeat顺序区别详解
2020/06/03 Python
CSS3实现精美横向滚动菜单按钮
2017/04/14 HTML / CSS
BabyBjörn婴儿背带法国官网:BabyBjorn法国
2018/06/16 全球购物
捷克厨房用品购物网站:Tescoma
2018/07/13 全球购物
Foot Locker英国官网:美国知名运动产品零售商
2019/02/21 全球购物
会计学自我鉴定
2014/02/06 职场文书
环保公益广告语
2014/03/13 职场文书
教师产假请假条
2014/04/10 职场文书
公司领导班子民主生活会对照检查材料
2014/10/02 职场文书
遗嘱格式范本
2015/08/07 职场文书
信息技术课教学反思
2016/02/23 职场文书
Python中可变和不可变对象的深入讲解
2021/08/02 Python