pytorch GAN伪造手写体mnist数据集方式


Posted in Python onJanuary 10, 2020

一,mnist数据集

pytorch GAN伪造手写体mnist数据集方式

形如上图的数字手写体就是mnist数据集。

二,GAN原理(生成对抗网络)

GAN网络一共由两部分组成:一个是伪造器(Generator,简称G),一个是判别器(Discrimniator,简称D)

一开始,G由服从某几个分布(如高斯分布)的噪音组成,生成的图片不断送给D判断是否正确,直到G生成的图片连D都判断以为是真的。D每一轮除了看过G生成的假图片以外,还要见数据集中的真图片,以前者和后者得到的损失函数值为依据更新D网络中的权值。因此G和D都在不停地更新权值。以下图为例:

pytorch GAN伪造手写体mnist数据集方式

在v1时的G只不过是 一堆噪声,见过数据集(real images)的D肯定能判断出G所生成的是假的。当然G也能知道D判断它是假的这个结果,因此G就会更新权值,到v2的时候,G就能生成更逼真的图片来让D判断,当然在v2时D也是会先看一次真图片,再去判断G所生成的图片。以此类推,不断循环就是GAN的思想。

三,训练代码

import argparse
import os
import numpy as np
import math
 
import torchvision.transforms as transforms
from torchvision.utils import save_image
 
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
 
import torch.nn as nn
import torch.nn.functional as F
import torch
 
os.makedirs("images", exist_ok=True)
 
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)
 
img_shape = (opt.channels, opt.img_size, opt.img_size) # 确定图片输入的格式为(1,28,28),由于mnist数据集是灰度图所以通道为1
cuda = True if torch.cuda.is_available() else False
 
 
class Generator(nn.Module):
 def __init__(self):
  super(Generator, self).__init__()
 
  def block(in_feat, out_feat, normalize=True):
   layers = [nn.Linear(in_feat, out_feat)]
   if normalize:
    layers.append(nn.BatchNorm1d(out_feat, 0.8))
   layers.append(nn.LeakyReLU(0.2, inplace=True))
   return layers
 
  self.model = nn.Sequential(
   *block(opt.latent_dim, 128, normalize=False),
   *block(128, 256),
   *block(256, 512),
   *block(512, 1024),
   nn.Linear(1024, int(np.prod(img_shape))),
   nn.Tanh()
  )
 
 def forward(self, z):
  img = self.model(z)
  img = img.view(img.size(0), *img_shape)
  return img
 
 
class Discriminator(nn.Module):
 def __init__(self):
  super(Discriminator, self).__init__()
 
  self.model = nn.Sequential(
   nn.Linear(int(np.prod(img_shape)), 512),
   nn.LeakyReLU(0.2, inplace=True),
   nn.Linear(512, 256),
   nn.LeakyReLU(0.2, inplace=True),
   nn.Linear(256, 1),
   nn.Sigmoid(),
  )
 
 def forward(self, img):
  img_flat = img.view(img.size(0), -1)
  validity = self.model(img_flat)
  return validity
 
 
# Loss function
adversarial_loss = torch.nn.BCELoss()
 
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
 
if cuda:
 generator.cuda()
 discriminator.cuda()
 adversarial_loss.cuda()
 
# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
 datasets.MNIST(
  "../../data/mnist",
  train=True,
  download=True,
  transform=transforms.Compose(
   [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
  ),
 ),
 batch_size=opt.batch_size,
 shuffle=True,
)
 
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
 
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
 
# ----------
# Training
# ----------
if __name__ == '__main__':
 for epoch in range(opt.n_epochs):
  for i, (imgs, _) in enumerate(dataloader):
   # print(imgs.shape)
   # Adversarial ground truths
   valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) # 全1
   fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) # 全0
   # Configure input
   real_imgs = Variable(imgs.type(Tensor))
 
   # -----------------
   # Train Generator
   # -----------------
 
   optimizer_G.zero_grad() # 清空G网络 上一个batch的梯度
 
   # Sample noise as generator input
   z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) # 生成的噪音,均值为0方差为1维度为(64,100)的噪音
   # Generate a batch of images
   gen_imgs = generator(z)
   # Loss measures generator's ability to fool the discriminator
   g_loss = adversarial_loss(discriminator(gen_imgs), valid)
 
   g_loss.backward() # g_loss用于更新G网络的权值,g_loss于D网络的判断结果 有关
   optimizer_G.step()
 
   # ---------------------
   # Train Discriminator
   # ---------------------
 
   optimizer_D.zero_grad() # 清空D网络 上一个batch的梯度
   # Measure discriminator's ability to classify real from generated samples
   real_loss = adversarial_loss(discriminator(real_imgs), valid)
   fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
   d_loss = (real_loss + fake_loss) / 2
 
   d_loss.backward() # d_loss用于更新D网络的权值
   optimizer_D.step()
 
   print(
    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
    % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
   )
 
   batches_done = epoch * len(dataloader) + i
   if batches_done % opt.sample_interval == 0:
    save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True) # 保存一个batchsize中的25张
   if (epoch+1) %2 ==0:
    print('save..')
    torch.save(generator,'g%d.pth' % epoch)
    torch.save(discriminator,'d%d.pth' % epoch)

运行结果:

一开始时,G生成的全是杂音:

pytorch GAN伪造手写体mnist数据集方式

然后逐渐呈现数字的雏形:

pytorch GAN伪造手写体mnist数据集方式

最后一次生成的结果:

pytorch GAN伪造手写体mnist数据集方式

四,测试代码:

导入最后保存生成器的模型:

from gan import Generator,Discriminator
import torch
import matplotlib.pyplot as plt
from torch.autograd import Variable
import numpy as np
from torchvision.utils import save_image
 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Tensor = torch.cuda.FloatTensor
g = torch.load('g199.pth') #导入生成器Generator模型
#d = torch.load('d.pth')
g = g.to(device)
#d = d.to(device)
 
z = Variable(Tensor(np.random.normal(0, 1, (64, 100)))) #输入的噪音
gen_imgs =g(z) #生产图片
save_image(gen_imgs.data[:25], "images.png" , nrow=5, normalize=True)

生成结果:

pytorch GAN伪造手写体mnist数据集方式

以上这篇pytorch GAN伪造手写体mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
linux环境下安装pyramid和新建项目的步骤
Nov 27 Python
使用Python编写类UNIX系统的命令行工具的教程
Apr 15 Python
Python中类型检查的详细介绍
Feb 13 Python
python中将\\uxxxx转换为Unicode字符串的方法
Sep 06 Python
Python 离线工作环境搭建的方法步骤
Jul 29 Python
python try except返回异常的信息字符串代码实例
Aug 15 Python
对django layer弹窗组件的使用详解
Aug 31 Python
Python控制台输出时刷新当前行内容而不是输出新行的实现
Feb 21 Python
pycharm中import呈现灰色原因的解决方法
Mar 04 Python
keras自定义损失函数并且模型加载的写法介绍
Jun 15 Python
Django后端分离 使用element-ui文件上传方式
Jul 12 Python
详解Python openpyxl库的基本应用
Feb 26 Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
You might like
PHP的PSR规范中文版
2013/09/28 PHP
php中实现记住密码下次自动登录的例子
2014/11/06 PHP
php采集自中央气象台范围覆盖全国的天气预报代码实例
2015/01/04 PHP
php命令行模式代码实例详解
2021/02/26 PHP
javascript 防止刷新,后退,关闭
2010/08/07 Javascript
使用javascript:将其它类型值转换成布尔类型值的解决方法详解
2013/05/07 Javascript
jquery实现漂浮在网页右侧的qq在线客服插件示例
2013/05/13 Javascript
一个简单的弹性返回顶部JS代码实现介绍
2013/06/09 Javascript
JQuery操作三大控件(下拉,单选,复选)的方法
2013/08/06 Javascript
禁止按回车键提交表单的方法
2015/06/11 Javascript
JavaScript实现横向滑出的多级菜单效果
2015/10/09 Javascript
纯javascript响应式树形菜单效果
2015/11/10 Javascript
使用jquery提交form表单并自定义action的实现代码
2016/05/25 Javascript
jquery 多个radio的click事件实例
2016/12/03 Javascript
HTML中使背景图片自适应浏览器大小实例详解
2017/04/06 Javascript
vue项目中v-model父子组件通信的实现详解
2017/12/10 Javascript
JS删除数组里的某个元素方法
2018/02/03 Javascript
node.js express捕获全局异常的三种方法实例分析
2019/12/27 Javascript
vue中使用v-for时为什么不能用index作为key
2020/04/04 Javascript
vue 子组件和父组件传值的示例
2020/09/11 Javascript
JavaScript语法约定和程序调试原理解析
2020/11/03 Javascript
动态实现element ui的el-table某列数据不同样式的示例
2021/01/22 Javascript
Python中利用原始套接字进行网络编程的示例
2015/05/04 Python
Python2和Python3中print的用法示例总结
2017/10/25 Python
Django将默认的SQLite更换为MySQL的实现
2019/11/18 Python
Python turtle画图库&&画姓名实例
2020/01/19 Python
通过实例解析Python return运行原理
2020/03/04 Python
美国和加拿大房车出售在线分类广告:RVT.com
2018/04/23 全球购物
汉森冲浪板:Hansen Surfboards
2018/05/19 全球购物
自荐信写法介绍
2014/01/25 职场文书
保密承诺书
2014/03/27 职场文书
2015年社区综治工作总结
2015/04/21 职场文书
2016学习全国教书育人楷模先进事迹心得体会
2016/01/21 职场文书
html+css合并表格边框的示例代码
2021/03/31 HTML / CSS
springboot+WebMagic+MyBatis爬虫框架的使用
2021/08/07 Java/Android
Go语言并发编程 sync.Once
2021/10/16 Golang