pytorch:实现简单的GAN示例(MNIST数据集)


Posted in Python onJanuary 10, 2020

我就废话不多说了,直接上代码吧!

# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
 
import torch
from torch import nn
from torch.autograd import Variable
 
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
 
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置画图的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
 
def show_images(images): # 定义画图工具
  images = np.reshape(images, [images.shape[0], -1])
  sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
  sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
 
  fig = plt.figure(figsize=(sqrtn, sqrtn))
  gs = gridspec.GridSpec(sqrtn, sqrtn)
  gs.update(wspace=0.05, hspace=0.05)
 
  for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))
  return 
  
def preprocess_img(x):
  x = tfs.ToTensor()(x)
  return (x - 0.5) / 0.5
 
def deprocess_img(x):
  return (x + 1.0) / 2.0
 
class ChunkSampler(sampler.Sampler): # 定义一个取样的函数
  """Samples elements sequentially from some offset. 
  Arguments:
    num_samples: # of desired datapoints
    start: offset where we should start selecting from
  """
  def __init__(self, num_samples, start=0):
    self.num_samples = num_samples
    self.start = start
 
  def __iter__(self):
    return iter(range(self.start, self.start + self.num_samples))
 
  def __len__(self):
    return self.num_samples
    
NUM_TRAIN = 50000
NUM_VAL = 5000
 
NOISE_DIM = 96
batch_size = 128
 
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
 
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
 
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可视化图片效果
show_images(imgs)
 
#判别网络
def discriminator():
  net = nn.Sequential(    
      nn.Linear(784, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 1)
    )
  return net
  
#生成网络
def generator(noise_dim=NOISE_DIM):  
  net = nn.Sequential(
    nn.Linear(noise_dim, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 784),
    nn.Tanh()
  )
  return net
  
#判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1
 
bce_loss = nn.BCEWithLogitsLoss()#交叉熵损失函数
 
def discriminator_loss(logits_real, logits_fake): # 判别器的 loss
  size = logits_real.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  false_labels = Variable(torch.zeros(size, 1)).float()
  loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
  return loss
  
def generator_loss(logits_fake): # 生成器的 loss 
  size = logits_fake.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  loss = bce_loss(logits_fake, true_labels)
  return loss
  
# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
  optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
  return optimizer
  
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, 
        noise_size=96, num_epochs=10):
  iter_count = 0
  for epoch in range(num_epochs):
    for x, _ in train_data:
      bs = x.shape[0]
      # 判别网络
      real_data = Variable(x).view(bs, -1) # 真实数据
      logits_real = D_net(real_data) # 判别网络得分
      
      sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均匀分布
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
      logits_fake = D_net(fake_images) # 判别网络得分
 
      d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 loss
      D_optimizer.zero_grad()
      d_total_error.backward()
      D_optimizer.step() # 优化判别网络
      
      # 生成网络
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
 
      gen_logits_fake = D_net(fake_images)
      g_error = generator_loss(gen_logits_fake) # 生成网络的 loss
      G_optimizer.zero_grad()
      g_error.backward()
      G_optimizer.step() # 优化生成网络
 
      if (iter_count % show_every == 0):
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
        imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
        show_images(imgs_numpy[0:16])
        plt.show()
        print()
      iter_count += 1
 
D = discriminator()
G = generator()
 
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
 
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

以上这篇pytorch:实现简单的GAN示例(MNIST数据集)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python遍历 truple list dictionary的几种方法总结
Sep 11 Python
pandas object格式转float64格式的方法
Apr 10 Python
pandas的唯一值、值计数以及成员资格的示例
Jul 25 Python
python  Django中的apps.py的目的是什么
Oct 15 Python
解决Pycharm 导入其他文件夹源码的2种方法
Feb 12 Python
Python单链表原理与实现方法详解
Feb 22 Python
pycharm实现print输出保存到txt文件
Jun 01 Python
Pytorch生成随机数Tensor的方法汇总
Sep 09 Python
使用python操作lmdb对数据读取的实例
Dec 11 Python
python wsgiref源码解析
Feb 06 Python
opencv用VS2013调试时用Image Watch插件查看图片
Jul 26 Python
Python使用mitmproxy工具监控手机 下载手机小视频
Apr 18 Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
pytorch的梯度计算以及backward方法详解
Jan 10 #Python
Python如何获取Win7,Win10系统缩放大小
Jan 10 #Python
You might like
怎样在UNIX系统下安装MySQL
2006/10/09 PHP
SWFUpload与CI不能正确上传识别文件MIME类型解决方法分享
2011/04/18 PHP
PHP缓存技术的多种方法小结
2012/08/14 PHP
php获取数组长度的方法(有实例)
2013/10/27 PHP
PHP翻页跳转功能实现方法
2020/11/30 PHP
javascript静态的url如何传递
2007/05/03 Javascript
jQuery Tips 为AJAX回调函数传递额外参数的方法
2010/12/28 Javascript
jquery 定位input元素的几种方法小结
2013/07/28 Javascript
JavaScript基础语法、dom操作树及document对象
2014/12/02 Javascript
Jquery网页内滑动缓冲导航的实现代码
2015/04/05 Javascript
JS中获取函数调用链所有参数的方法
2015/05/07 Javascript
JS+CSS实现表格高亮的方法
2015/08/05 Javascript
jQuery实现鼠标滑过链接控制图片的滑动展开与隐藏效果
2015/10/28 Javascript
微信小程序自定义toast弹窗效果的实现代码
2018/11/15 Javascript
Node.js 多进程处理CPU密集任务的实现
2019/05/26 Javascript
[01:02:18]VGJ.S vs infamous Supermajor 败者组 BO3 第一场 6.4
2018/06/05 DOTA
Python的多态性实例分析
2015/07/07 Python
对python-3-print重定向输出的几种方法总结
2018/05/11 Python
教你利用Python玩转histogram直方图的五种方法
2018/07/30 Python
python 获取utc时间转化为本地时间的方法
2018/12/31 Python
Django实现发送邮件功能
2019/07/18 Python
python爬虫selenium和phantomJs使用方法解析
2019/08/08 Python
Python实现桌面翻译工具【新手必学】
2020/02/12 Python
python统计字符串中字母出现次数代码实例
2020/03/02 Python
通过python调用adb命令对App进行性能测试方式
2020/04/23 Python
详解CSS3中nth-child与nth-of-type的区别
2017/01/05 HTML / CSS
英国高街品牌:Miss Selfridge(塞尔弗里奇小姐)
2016/09/21 全球购物
能源工程专业应届生求职信
2014/03/01 职场文书
面试自我评价范文
2014/09/17 职场文书
店铺转让协议书
2015/01/29 职场文书
高中生自我评价范文2015
2015/03/03 职场文书
因公司原因离职的辞职信范文
2015/05/12 职场文书
小学生反邪教心得体会
2016/01/15 职场文书
《自己去吧》教学反思
2016/02/16 职场文书
创业计划书之o2o水果店
2019/08/30 职场文书
使用canvas仿Echarts实现金字塔图的实例代码
2021/11/11 HTML / CSS