pytorch GAN伪造手写体mnist数据集方式


Posted in Python onJanuary 10, 2020

一,mnist数据集

pytorch GAN伪造手写体mnist数据集方式

形如上图的数字手写体就是mnist数据集。

二,GAN原理(生成对抗网络)

GAN网络一共由两部分组成:一个是伪造器(Generator,简称G),一个是判别器(Discrimniator,简称D)

一开始,G由服从某几个分布(如高斯分布)的噪音组成,生成的图片不断送给D判断是否正确,直到G生成的图片连D都判断以为是真的。D每一轮除了看过G生成的假图片以外,还要见数据集中的真图片,以前者和后者得到的损失函数值为依据更新D网络中的权值。因此G和D都在不停地更新权值。以下图为例:

pytorch GAN伪造手写体mnist数据集方式

在v1时的G只不过是 一堆噪声,见过数据集(real images)的D肯定能判断出G所生成的是假的。当然G也能知道D判断它是假的这个结果,因此G就会更新权值,到v2的时候,G就能生成更逼真的图片来让D判断,当然在v2时D也是会先看一次真图片,再去判断G所生成的图片。以此类推,不断循环就是GAN的思想。

三,训练代码

import argparse
import os
import numpy as np
import math
 
import torchvision.transforms as transforms
from torchvision.utils import save_image
 
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
 
import torch.nn as nn
import torch.nn.functional as F
import torch
 
os.makedirs("images", exist_ok=True)
 
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)
 
img_shape = (opt.channels, opt.img_size, opt.img_size) # 确定图片输入的格式为(1,28,28),由于mnist数据集是灰度图所以通道为1
cuda = True if torch.cuda.is_available() else False
 
 
class Generator(nn.Module):
 def __init__(self):
  super(Generator, self).__init__()
 
  def block(in_feat, out_feat, normalize=True):
   layers = [nn.Linear(in_feat, out_feat)]
   if normalize:
    layers.append(nn.BatchNorm1d(out_feat, 0.8))
   layers.append(nn.LeakyReLU(0.2, inplace=True))
   return layers
 
  self.model = nn.Sequential(
   *block(opt.latent_dim, 128, normalize=False),
   *block(128, 256),
   *block(256, 512),
   *block(512, 1024),
   nn.Linear(1024, int(np.prod(img_shape))),
   nn.Tanh()
  )
 
 def forward(self, z):
  img = self.model(z)
  img = img.view(img.size(0), *img_shape)
  return img
 
 
class Discriminator(nn.Module):
 def __init__(self):
  super(Discriminator, self).__init__()
 
  self.model = nn.Sequential(
   nn.Linear(int(np.prod(img_shape)), 512),
   nn.LeakyReLU(0.2, inplace=True),
   nn.Linear(512, 256),
   nn.LeakyReLU(0.2, inplace=True),
   nn.Linear(256, 1),
   nn.Sigmoid(),
  )
 
 def forward(self, img):
  img_flat = img.view(img.size(0), -1)
  validity = self.model(img_flat)
  return validity
 
 
# Loss function
adversarial_loss = torch.nn.BCELoss()
 
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
 
if cuda:
 generator.cuda()
 discriminator.cuda()
 adversarial_loss.cuda()
 
# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
 datasets.MNIST(
  "../../data/mnist",
  train=True,
  download=True,
  transform=transforms.Compose(
   [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
  ),
 ),
 batch_size=opt.batch_size,
 shuffle=True,
)
 
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
 
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
 
# ----------
# Training
# ----------
if __name__ == '__main__':
 for epoch in range(opt.n_epochs):
  for i, (imgs, _) in enumerate(dataloader):
   # print(imgs.shape)
   # Adversarial ground truths
   valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) # 全1
   fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) # 全0
   # Configure input
   real_imgs = Variable(imgs.type(Tensor))
 
   # -----------------
   # Train Generator
   # -----------------
 
   optimizer_G.zero_grad() # 清空G网络 上一个batch的梯度
 
   # Sample noise as generator input
   z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) # 生成的噪音,均值为0方差为1维度为(64,100)的噪音
   # Generate a batch of images
   gen_imgs = generator(z)
   # Loss measures generator's ability to fool the discriminator
   g_loss = adversarial_loss(discriminator(gen_imgs), valid)
 
   g_loss.backward() # g_loss用于更新G网络的权值,g_loss于D网络的判断结果 有关
   optimizer_G.step()
 
   # ---------------------
   # Train Discriminator
   # ---------------------
 
   optimizer_D.zero_grad() # 清空D网络 上一个batch的梯度
   # Measure discriminator's ability to classify real from generated samples
   real_loss = adversarial_loss(discriminator(real_imgs), valid)
   fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
   d_loss = (real_loss + fake_loss) / 2
 
   d_loss.backward() # d_loss用于更新D网络的权值
   optimizer_D.step()
 
   print(
    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
    % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
   )
 
   batches_done = epoch * len(dataloader) + i
   if batches_done % opt.sample_interval == 0:
    save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True) # 保存一个batchsize中的25张
   if (epoch+1) %2 ==0:
    print('save..')
    torch.save(generator,'g%d.pth' % epoch)
    torch.save(discriminator,'d%d.pth' % epoch)

运行结果:

一开始时,G生成的全是杂音:

pytorch GAN伪造手写体mnist数据集方式

然后逐渐呈现数字的雏形:

pytorch GAN伪造手写体mnist数据集方式

最后一次生成的结果:

pytorch GAN伪造手写体mnist数据集方式

四,测试代码:

导入最后保存生成器的模型:

from gan import Generator,Discriminator
import torch
import matplotlib.pyplot as plt
from torch.autograd import Variable
import numpy as np
from torchvision.utils import save_image
 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Tensor = torch.cuda.FloatTensor
g = torch.load('g199.pth') #导入生成器Generator模型
#d = torch.load('d.pth')
g = g.to(device)
#d = d.to(device)
 
z = Variable(Tensor(np.random.normal(0, 1, (64, 100)))) #输入的噪音
gen_imgs =g(z) #生产图片
save_image(gen_imgs.data[:25], "images.png" , nrow=5, normalize=True)

生成结果:

pytorch GAN伪造手写体mnist数据集方式

以上这篇pytorch GAN伪造手写体mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中global与nonlocal比较
Nov 21 Python
利用Python演示数型数据结构的教程
Apr 03 Python
举例讲解Python设计模式编程的代理模式与抽象工厂模式
Jan 16 Python
对numpy中数组元素的统一赋值实例
Apr 04 Python
python实现守护进程、守护线程、守护非守护并行
May 05 Python
python pandas模块基础学习详解
Jul 03 Python
python numpy 常用随机数的产生方法的实现
Aug 21 Python
python多线程实现代码(模拟银行服务操作流程)
Jan 13 Python
pycharm 更改创建文件默认路径的操作
Feb 15 Python
TensorFlow2.X使用图片制作简单的数据集训练模型
Apr 08 Python
Python-openpyxl表格读取写入的案例详解
Nov 02 Python
Python 视频画质增强
Apr 28 Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
You might like
PHP文件类型检查及fileinfo模块安装使用详解
2019/05/09 PHP
详解使用php-cs-fixer格式化代码
2020/09/16 PHP
JQuery扩展插件Validate 5添加自定义验证方法
2011/09/05 Javascript
Jquery 例外被抛出且未被接住原因介绍
2013/09/04 Javascript
运用JQuery的toggle实现网页加载完成自动弹窗
2014/03/18 Javascript
浅谈JavaScript中的字符编码转换问题
2015/07/07 Javascript
jquery图片滚动放大代码分享(1)
2015/08/25 Javascript
jquery点击缩略图切换视频播放特效代码分享
2015/09/15 Javascript
js实现图片上传并正常显示
2015/12/19 Javascript
AngularJS 过滤与排序详解及实例代码
2016/09/14 Javascript
浅析Javascript的自动分号插入(ASI)机制
2016/09/29 Javascript
vue2组件实现懒加载浅析
2017/03/29 Javascript
Vue.js使用$.ajax和vue-resource实现OAuth的注册、登录、注销和API调用
2017/05/10 Javascript
react redux入门示例
2018/04/19 Javascript
Vue中util的工具函数实例详解
2019/07/08 Javascript
vue实现评论列表功能
2019/10/25 Javascript
vue 输入电话号码自动按3-4-4分割功能的实现代码
2020/04/30 Javascript
javascript实现时间日期的格式化的方法汇总
2020/08/06 Javascript
python实现根据指定字符截取对应的行的内容方法
2018/10/23 Python
Python使用sklearn库实现的各种分类算法简单应用小结
2019/07/04 Python
Python实现画图软件功能方法详解
2020/07/28 Python
纯CSS实现预加载动画效果
2017/09/06 HTML / CSS
canvas绘制视频封面的方法
2018/02/05 HTML / CSS
同程旅游英文网站:LY.com
2018/11/13 全球购物
新西兰便宜隐形眼镜购买网站:QUICKLENS New Zealand
2019/03/02 全球购物
Liu Jo西班牙官网:意大利服装品牌
2019/09/11 全球购物
介绍一下XMLHttpRequest对象的常用方法和属性
2013/05/24 面试题
机械专业个人求职自荐信格式
2013/09/21 职场文书
消防安全管理制度
2014/02/01 职场文书
优秀的个人求职信范文
2014/05/09 职场文书
服装发布会策划方案
2014/05/22 职场文书
企业活动策划方案
2014/06/02 职场文书
学习保证书怎么写
2015/02/26 职场文书
2015年社区文体活动总结
2015/03/25 职场文书
2015秋季运动会通讯稿
2015/07/18 职场文书
Tensorflow与RNN、双向LSTM等的踩坑记录及解决
2021/05/31 Python