pytorch:实现简单的GAN示例(MNIST数据集)


Posted in Python onJanuary 10, 2020

我就废话不多说了,直接上代码吧!

# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
 
import torch
from torch import nn
from torch.autograd import Variable
 
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
 
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置画图的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
 
def show_images(images): # 定义画图工具
  images = np.reshape(images, [images.shape[0], -1])
  sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
  sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
 
  fig = plt.figure(figsize=(sqrtn, sqrtn))
  gs = gridspec.GridSpec(sqrtn, sqrtn)
  gs.update(wspace=0.05, hspace=0.05)
 
  for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))
  return 
  
def preprocess_img(x):
  x = tfs.ToTensor()(x)
  return (x - 0.5) / 0.5
 
def deprocess_img(x):
  return (x + 1.0) / 2.0
 
class ChunkSampler(sampler.Sampler): # 定义一个取样的函数
  """Samples elements sequentially from some offset. 
  Arguments:
    num_samples: # of desired datapoints
    start: offset where we should start selecting from
  """
  def __init__(self, num_samples, start=0):
    self.num_samples = num_samples
    self.start = start
 
  def __iter__(self):
    return iter(range(self.start, self.start + self.num_samples))
 
  def __len__(self):
    return self.num_samples
    
NUM_TRAIN = 50000
NUM_VAL = 5000
 
NOISE_DIM = 96
batch_size = 128
 
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
 
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
 
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可视化图片效果
show_images(imgs)
 
#判别网络
def discriminator():
  net = nn.Sequential(    
      nn.Linear(784, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 1)
    )
  return net
  
#生成网络
def generator(noise_dim=NOISE_DIM):  
  net = nn.Sequential(
    nn.Linear(noise_dim, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 784),
    nn.Tanh()
  )
  return net
  
#判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1
 
bce_loss = nn.BCEWithLogitsLoss()#交叉熵损失函数
 
def discriminator_loss(logits_real, logits_fake): # 判别器的 loss
  size = logits_real.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  false_labels = Variable(torch.zeros(size, 1)).float()
  loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
  return loss
  
def generator_loss(logits_fake): # 生成器的 loss 
  size = logits_fake.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  loss = bce_loss(logits_fake, true_labels)
  return loss
  
# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
  optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
  return optimizer
  
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, 
        noise_size=96, num_epochs=10):
  iter_count = 0
  for epoch in range(num_epochs):
    for x, _ in train_data:
      bs = x.shape[0]
      # 判别网络
      real_data = Variable(x).view(bs, -1) # 真实数据
      logits_real = D_net(real_data) # 判别网络得分
      
      sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均匀分布
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
      logits_fake = D_net(fake_images) # 判别网络得分
 
      d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 loss
      D_optimizer.zero_grad()
      d_total_error.backward()
      D_optimizer.step() # 优化判别网络
      
      # 生成网络
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
 
      gen_logits_fake = D_net(fake_images)
      g_error = generator_loss(gen_logits_fake) # 生成网络的 loss
      G_optimizer.zero_grad()
      g_error.backward()
      G_optimizer.step() # 优化生成网络
 
      if (iter_count % show_every == 0):
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
        imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
        show_images(imgs_numpy[0:16])
        plt.show()
        print()
      iter_count += 1
 
D = discriminator()
G = generator()
 
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
 
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

以上这篇pytorch:实现简单的GAN示例(MNIST数据集)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python求解水仙花数的方法
May 11 Python
pandas 两列时间相减换算为秒的方法
Apr 20 Python
Python将文本去空格并保存到txt文件中的实例
Jul 24 Python
使用memory_profiler监测python代码运行时内存消耗方法
Dec 03 Python
python使用selenium登录QQ邮箱(附带滑动解锁)
Jan 23 Python
python实现广度优先搜索过程解析
Oct 19 Python
记一次pyinstaller打包pygame项目为exe的过程(带图片)
Mar 02 Python
python GUI库图形界面开发之PyQt5布局控件QGridLayout详细使用方法与实例
Mar 06 Python
Python气泡提示与标签的实现
Apr 01 Python
opencv 图像腐蚀和图像膨胀的实现
Jul 07 Python
Python如何操作docker redis过程解析
Aug 10 Python
python实现计算图形面积
Feb 22 Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
pytorch的梯度计算以及backward方法详解
Jan 10 #Python
Python如何获取Win7,Win10系统缩放大小
Jan 10 #Python
You might like
一个可以找出源代码中所有中文的工具
2006/10/25 PHP
WINXP下apache+php4+mysql
2006/11/25 PHP
php下载文件的代码示例
2012/06/29 PHP
phpQuery占用内存过多的处理方法
2013/11/13 PHP
分享一个Laravel好用的Cache宏
2015/03/02 PHP
PHP、Java des加密解密实例
2015/04/27 PHP
不同浏览器的怪癖小结
2010/07/11 Javascript
微信中一些常用的js方法汇总
2015/03/12 Javascript
jquery中attr和prop的区别分析
2015/03/16 Javascript
JavaScript 事件对象介绍
2015/04/13 Javascript
angularJs中orderBy筛选以及filter过滤数据的方法
2018/09/30 Javascript
解决vue.js提交数组时出现数组下标的问题
2019/11/05 Javascript
vue+render+jsx实现可编辑动态多级表头table的实例代码
2020/04/01 Javascript
JavaScript常用工具函数汇总(浏览器环境)
2020/09/17 Javascript
[05:20]卡尔工作室_DOTA2新手教学_DOTA2超强新手功能
2013/04/22 DOTA
Python中无限元素列表的实现方法
2014/08/18 Python
Python调用C语言的方法【基于ctypes模块】
2018/01/22 Python
Python实现检测文件MD5值的方法示例
2018/04/11 Python
详解python分布式进程
2018/10/08 Python
python3实现逐字输出的方法
2019/01/23 Python
python实现简单日期工具类
2019/04/24 Python
Python3.5 Json与pickle实现数据序列化与反序列化操作示例
2019/04/29 Python
关于阿里云oss获取sts凭证 app直传 python的实例
2019/08/20 Python
python3 selenium自动化 frame表单嵌套的切换方法
2019/08/23 Python
windows下Pycharm安装opencv的多种方法
2020/03/05 Python
python wsgiref源码解析
2021/02/06 Python
用CSS禁用输入法(CSS3 UI规范)实例解析
2012/12/04 HTML / CSS
美国钻石商店:Zales
2016/11/20 全球购物
杭州龙健科技笔试题.net部分笔试题
2016/01/24 面试题
事业单位公务员的职业生涯规划
2014/01/15 职场文书
婚庆司仪主持词
2014/03/15 职场文书
优秀团支部申报材料
2014/12/26 职场文书
预备党员自我评价范文
2015/03/04 职场文书
教师节倡议书2015
2015/04/27 职场文书
三方合作意向书范本
2015/05/09 职场文书
python_tkinter弹出对话框创建
2022/03/20 Python