pytorch:实现简单的GAN示例(MNIST数据集)


Posted in Python onJanuary 10, 2020

我就废话不多说了,直接上代码吧!

# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
 
import torch
from torch import nn
from torch.autograd import Variable
 
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
 
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置画图的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
 
def show_images(images): # 定义画图工具
  images = np.reshape(images, [images.shape[0], -1])
  sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
  sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
 
  fig = plt.figure(figsize=(sqrtn, sqrtn))
  gs = gridspec.GridSpec(sqrtn, sqrtn)
  gs.update(wspace=0.05, hspace=0.05)
 
  for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))
  return 
  
def preprocess_img(x):
  x = tfs.ToTensor()(x)
  return (x - 0.5) / 0.5
 
def deprocess_img(x):
  return (x + 1.0) / 2.0
 
class ChunkSampler(sampler.Sampler): # 定义一个取样的函数
  """Samples elements sequentially from some offset. 
  Arguments:
    num_samples: # of desired datapoints
    start: offset where we should start selecting from
  """
  def __init__(self, num_samples, start=0):
    self.num_samples = num_samples
    self.start = start
 
  def __iter__(self):
    return iter(range(self.start, self.start + self.num_samples))
 
  def __len__(self):
    return self.num_samples
    
NUM_TRAIN = 50000
NUM_VAL = 5000
 
NOISE_DIM = 96
batch_size = 128
 
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
 
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
 
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可视化图片效果
show_images(imgs)
 
#判别网络
def discriminator():
  net = nn.Sequential(    
      nn.Linear(784, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 1)
    )
  return net
  
#生成网络
def generator(noise_dim=NOISE_DIM):  
  net = nn.Sequential(
    nn.Linear(noise_dim, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 784),
    nn.Tanh()
  )
  return net
  
#判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1
 
bce_loss = nn.BCEWithLogitsLoss()#交叉熵损失函数
 
def discriminator_loss(logits_real, logits_fake): # 判别器的 loss
  size = logits_real.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  false_labels = Variable(torch.zeros(size, 1)).float()
  loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
  return loss
  
def generator_loss(logits_fake): # 生成器的 loss 
  size = logits_fake.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  loss = bce_loss(logits_fake, true_labels)
  return loss
  
# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
  optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
  return optimizer
  
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, 
        noise_size=96, num_epochs=10):
  iter_count = 0
  for epoch in range(num_epochs):
    for x, _ in train_data:
      bs = x.shape[0]
      # 判别网络
      real_data = Variable(x).view(bs, -1) # 真实数据
      logits_real = D_net(real_data) # 判别网络得分
      
      sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均匀分布
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
      logits_fake = D_net(fake_images) # 判别网络得分
 
      d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 loss
      D_optimizer.zero_grad()
      d_total_error.backward()
      D_optimizer.step() # 优化判别网络
      
      # 生成网络
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
 
      gen_logits_fake = D_net(fake_images)
      g_error = generator_loss(gen_logits_fake) # 生成网络的 loss
      G_optimizer.zero_grad()
      g_error.backward()
      G_optimizer.step() # 优化生成网络
 
      if (iter_count % show_every == 0):
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
        imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
        show_images(imgs_numpy[0:16])
        plt.show()
        print()
      iter_count += 1
 
D = discriminator()
G = generator()
 
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
 
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

以上这篇pytorch:实现简单的GAN示例(MNIST数据集)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现简单状态框架的方法
Mar 19 Python
python中字典(Dictionary)用法实例详解
May 30 Python
python with提前退出遇到的坑与解决方案
Jan 05 Python
python画出三角形外接圆和内切圆的方法
Jan 25 Python
Python中if elif else及缩进的使用简述
May 31 Python
Python随机生成身份证号码及校验功能
Dec 04 Python
使用Python自动化破解自定义字体混淆信息的方法实例
Feb 13 Python
Python2.7版os.path.isdir中文路径返回false的解决方法
Jun 21 Python
python模块hashlib(加密服务)知识点讲解
Nov 25 Python
python实现取余操作的简单实例
Aug 16 Python
python 实现mysql自动增删分区的方法
Apr 01 Python
Pytorch 中net.train 和 net.eval的使用说明
May 22 Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
pytorch的梯度计算以及backward方法详解
Jan 10 #Python
Python如何获取Win7,Win10系统缩放大小
Jan 10 #Python
You might like
php 论坛采集程序 模拟登陆,抓取页面 实现代码
2009/07/09 PHP
兼容性比较好的PHP生成缩略图的代码
2011/01/12 PHP
thinkphp的c方法使用示例
2014/02/24 PHP
PHP里8个鲜为人知的安全函数分析
2014/12/09 PHP
php利用递归实现删除文件目录的方法
2016/09/23 PHP
PHP切割整数工具类似微信红包金额分配的思路详解
2019/09/18 PHP
TNC vs BOOM BO3 第二场2.13
2021/03/10 DOTA
Javascript合并表格中具有相同内容单元格示例
2013/08/11 Javascript
javascript用户注册提示效果的简单实例
2013/08/17 Javascript
JS判断浏览器是否支持某一个CSS3属性的方法
2014/10/17 Javascript
JavaScript变量的作用域全解析
2015/08/14 Javascript
jQuery实现的导航下拉菜单效果示例
2016/09/05 Javascript
JS二叉树的简单实现方法示例
2017/04/05 Javascript
JS实现弹出下载对话框及常见文件类型的下载
2017/07/13 Javascript
通过 JS 判断页面是否有滚动条的实现方法
2018/04/05 Javascript
解决jQuery使用append添加的元素事件无效的问题
2018/08/30 jQuery
layer弹出子iframe层父子页面传值的实现方法
2018/11/22 Javascript
layer插件实现在弹出层中弹出一警告提示并关闭弹出层的方法
2019/09/24 Javascript
python操作日期和时间的方法
2014/03/11 Python
Python AES加密实例解析
2018/01/18 Python
基于python的Paxos算法实现
2019/07/03 Python
Python学习笔记之Break和Continue用法分析
2019/08/14 Python
python创建ArcGIS shape文件的实现
2019/12/06 Python
pytorch实现mnist分类的示例讲解
2020/01/10 Python
Django Model中字段(field)的各种选项说明
2020/05/19 Python
荷兰本土平价百货:HEMA
2017/10/23 全球购物
怀俄明州飞钓:Platte River Fly Shop
2017/12/28 全球购物
Otticanet意大利:最顶尖的世界名牌眼镜, 能得到打折季的价格
2019/03/10 全球购物
WebSphere面试题:在WebSphere里面如何部署一个应用
2015/08/02 面试题
深圳茁壮笔试题
2015/05/28 面试题
应用数学自荐书范文
2013/11/24 职场文书
环卫个人总结
2015/03/03 职场文书
2015学校六五普法工作总结
2015/04/22 职场文书
2015年学校减负工作总结
2015/05/19 职场文书
Python 实现Mac 屏幕截图详解
2021/10/05 Python
处理canvas绘制图片模糊问题
2022/05/11 Javascript