pytorch:实现简单的GAN示例(MNIST数据集)


Posted in Python onJanuary 10, 2020

我就废话不多说了,直接上代码吧!

# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
 
import torch
from torch import nn
from torch.autograd import Variable
 
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
 
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置画图的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
 
def show_images(images): # 定义画图工具
  images = np.reshape(images, [images.shape[0], -1])
  sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
  sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
 
  fig = plt.figure(figsize=(sqrtn, sqrtn))
  gs = gridspec.GridSpec(sqrtn, sqrtn)
  gs.update(wspace=0.05, hspace=0.05)
 
  for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))
  return 
  
def preprocess_img(x):
  x = tfs.ToTensor()(x)
  return (x - 0.5) / 0.5
 
def deprocess_img(x):
  return (x + 1.0) / 2.0
 
class ChunkSampler(sampler.Sampler): # 定义一个取样的函数
  """Samples elements sequentially from some offset. 
  Arguments:
    num_samples: # of desired datapoints
    start: offset where we should start selecting from
  """
  def __init__(self, num_samples, start=0):
    self.num_samples = num_samples
    self.start = start
 
  def __iter__(self):
    return iter(range(self.start, self.start + self.num_samples))
 
  def __len__(self):
    return self.num_samples
    
NUM_TRAIN = 50000
NUM_VAL = 5000
 
NOISE_DIM = 96
batch_size = 128
 
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
 
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
 
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可视化图片效果
show_images(imgs)
 
#判别网络
def discriminator():
  net = nn.Sequential(    
      nn.Linear(784, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 1)
    )
  return net
  
#生成网络
def generator(noise_dim=NOISE_DIM):  
  net = nn.Sequential(
    nn.Linear(noise_dim, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 784),
    nn.Tanh()
  )
  return net
  
#判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1
 
bce_loss = nn.BCEWithLogitsLoss()#交叉熵损失函数
 
def discriminator_loss(logits_real, logits_fake): # 判别器的 loss
  size = logits_real.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  false_labels = Variable(torch.zeros(size, 1)).float()
  loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
  return loss
  
def generator_loss(logits_fake): # 生成器的 loss 
  size = logits_fake.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  loss = bce_loss(logits_fake, true_labels)
  return loss
  
# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
  optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
  return optimizer
  
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, 
        noise_size=96, num_epochs=10):
  iter_count = 0
  for epoch in range(num_epochs):
    for x, _ in train_data:
      bs = x.shape[0]
      # 判别网络
      real_data = Variable(x).view(bs, -1) # 真实数据
      logits_real = D_net(real_data) # 判别网络得分
      
      sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均匀分布
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
      logits_fake = D_net(fake_images) # 判别网络得分
 
      d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 loss
      D_optimizer.zero_grad()
      d_total_error.backward()
      D_optimizer.step() # 优化判别网络
      
      # 生成网络
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
 
      gen_logits_fake = D_net(fake_images)
      g_error = generator_loss(gen_logits_fake) # 生成网络的 loss
      G_optimizer.zero_grad()
      g_error.backward()
      G_optimizer.step() # 优化生成网络
 
      if (iter_count % show_every == 0):
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
        imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
        show_images(imgs_numpy[0:16])
        plt.show()
        print()
      iter_count += 1
 
D = discriminator()
G = generator()
 
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
 
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

以上这篇pytorch:实现简单的GAN示例(MNIST数据集)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 从远程服务器下载东西的代码
Feb 10 Python
Python中的random()方法的使用介绍
May 15 Python
Python随机生成手机号、数字的方法详解
Jul 21 Python
PyQt5利用QPainter绘制各种图形的实例
Oct 19 Python
详解如何使用Python编写vim插件
Nov 28 Python
聊聊Python中的pypy
Jan 12 Python
利用Pandas 创建空的DataFrame方法
Apr 08 Python
Python实现的批量修改文件后缀名操作示例
Dec 07 Python
scrapy-redis的安装部署步骤讲解
Feb 27 Python
Python如何通过Flask-Mail发送电子邮件
Jan 29 Python
Python内置异常类型全面汇总
May 28 Python
python模块如何查看
Jun 16 Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
pytorch的梯度计算以及backward方法详解
Jan 10 #Python
Python如何获取Win7,Win10系统缩放大小
Jan 10 #Python
You might like
模仿OSO的论坛(三)
2006/10/09 PHP
php在页面中调用fckeditor编辑器的方法
2011/06/10 PHP
详解PHP+AJAX无刷新分页实现方法
2015/11/03 PHP
PHP使用http_build_query()构造URL字符串的方法
2016/04/02 PHP
详解php命令注入攻击
2019/04/06 PHP
javascript 事件查询综合 推荐收藏
2010/03/10 Javascript
jQuery-onload让第一次页面加载时图片是淡入方式显示
2012/05/23 Javascript
JavaScript对内存分配及管理机制详细解析
2013/11/11 Javascript
js 剪切板的用法(clipboardData.setData)与js match函数介绍
2013/11/19 Javascript
ExtJS4如何自动生成控制grid的列显示、隐藏的checkbox
2014/05/02 Javascript
javascript动态判断html元素并执行不同的操作
2014/06/16 Javascript
JavaScript实现彩虹文字效果的方法
2015/04/16 Javascript
js实现向右横向滑出的二级菜单效果
2015/08/27 Javascript
通过js获取上传的图片信息(临时保存路径,名称,大小)然后通过ajax传递给后端的方法
2015/10/01 Javascript
vue2.0开发实践总结之入门篇
2016/12/06 Javascript
Bootstrap select下拉联动(jQuery cxselect)
2017/01/04 Javascript
12个非常有用的JavaScript技巧
2017/05/17 Javascript
JS全角与半角转化实例(分享)
2017/07/04 Javascript
深入浅析JSONAPI在PHP中的应用
2017/12/24 Javascript
jQuery轻量级表单模型验证插件
2018/10/15 jQuery
layui 对table中的数据进行转义的实例
2019/09/12 Javascript
JavaScript cookie原理及使用实例
2020/05/08 Javascript
Python使用matplotlib简单绘图示例
2018/02/01 Python
Python批量提取PDF文件中文本的脚本
2018/03/14 Python
朴素贝叶斯分类算法原理与Python实现与使用方法案例
2018/06/26 Python
Python import与from import使用及区别介绍
2018/09/06 Python
使用Python横向合并excel文件的实例
2018/12/11 Python
Python 序列化和反序列化库 MarshMallow 的用法实例代码
2020/02/25 Python
如何表示python中的相对路径
2020/07/08 Python
Python tkinter界面实现历史天气查询的示例代码
2020/08/23 Python
香港化妆品经销商:我的公主
2016/08/05 全球购物
升职自荐书范文
2013/11/28 职场文书
工作作风整顿个人剖析材料
2014/10/11 职场文书
用php如何解决大文件分片上传问题
2021/07/07 PHP
《杜鹃的婚约》OP主题曲「凸凹」无字幕影像公开
2022/04/08 日漫
详解flex:1什么意思
2022/07/23 HTML / CSS