pytorch:实现简单的GAN示例(MNIST数据集)


Posted in Python onJanuary 10, 2020

我就废话不多说了,直接上代码吧!

# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
 
import torch
from torch import nn
from torch.autograd import Variable
 
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
 
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置画图的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
 
def show_images(images): # 定义画图工具
  images = np.reshape(images, [images.shape[0], -1])
  sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
  sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
 
  fig = plt.figure(figsize=(sqrtn, sqrtn))
  gs = gridspec.GridSpec(sqrtn, sqrtn)
  gs.update(wspace=0.05, hspace=0.05)
 
  for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))
  return 
  
def preprocess_img(x):
  x = tfs.ToTensor()(x)
  return (x - 0.5) / 0.5
 
def deprocess_img(x):
  return (x + 1.0) / 2.0
 
class ChunkSampler(sampler.Sampler): # 定义一个取样的函数
  """Samples elements sequentially from some offset. 
  Arguments:
    num_samples: # of desired datapoints
    start: offset where we should start selecting from
  """
  def __init__(self, num_samples, start=0):
    self.num_samples = num_samples
    self.start = start
 
  def __iter__(self):
    return iter(range(self.start, self.start + self.num_samples))
 
  def __len__(self):
    return self.num_samples
    
NUM_TRAIN = 50000
NUM_VAL = 5000
 
NOISE_DIM = 96
batch_size = 128
 
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
 
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
 
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可视化图片效果
show_images(imgs)
 
#判别网络
def discriminator():
  net = nn.Sequential(    
      nn.Linear(784, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 1)
    )
  return net
  
#生成网络
def generator(noise_dim=NOISE_DIM):  
  net = nn.Sequential(
    nn.Linear(noise_dim, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 784),
    nn.Tanh()
  )
  return net
  
#判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1
 
bce_loss = nn.BCEWithLogitsLoss()#交叉熵损失函数
 
def discriminator_loss(logits_real, logits_fake): # 判别器的 loss
  size = logits_real.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  false_labels = Variable(torch.zeros(size, 1)).float()
  loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
  return loss
  
def generator_loss(logits_fake): # 生成器的 loss 
  size = logits_fake.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  loss = bce_loss(logits_fake, true_labels)
  return loss
  
# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
  optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
  return optimizer
  
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, 
        noise_size=96, num_epochs=10):
  iter_count = 0
  for epoch in range(num_epochs):
    for x, _ in train_data:
      bs = x.shape[0]
      # 判别网络
      real_data = Variable(x).view(bs, -1) # 真实数据
      logits_real = D_net(real_data) # 判别网络得分
      
      sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均匀分布
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
      logits_fake = D_net(fake_images) # 判别网络得分
 
      d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 loss
      D_optimizer.zero_grad()
      d_total_error.backward()
      D_optimizer.step() # 优化判别网络
      
      # 生成网络
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
 
      gen_logits_fake = D_net(fake_images)
      g_error = generator_loss(gen_logits_fake) # 生成网络的 loss
      G_optimizer.zero_grad()
      g_error.backward()
      G_optimizer.step() # 优化生成网络
 
      if (iter_count % show_every == 0):
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
        imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
        show_images(imgs_numpy[0:16])
        plt.show()
        print()
      iter_count += 1
 
D = discriminator()
G = generator()
 
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
 
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

以上这篇pytorch:实现简单的GAN示例(MNIST数据集)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 远程统计文件代码分享
May 14 Python
使用Python脚本将文字转换为图片的实例分享
Aug 29 Python
12步教你理解Python装饰器
Feb 25 Python
Python 爬取携程所有机票的实例代码
Jun 11 Python
python os.listdir按文件存取时间顺序列出目录的实例
Oct 21 Python
Python面向对象之继承和多态用法分析
Jun 08 Python
python原类、类的创建过程与方法详解
Jul 19 Python
Python企业编码生成系统之主程序模块设计详解
Jul 26 Python
一行Python代码过滤标点符号等特殊字符
Aug 12 Python
Golang GBK转UTF-8的例子
Aug 26 Python
python爬虫开发之urllib模块详细使用方法与实例全解
Mar 09 Python
pytorch下的unsqueeze和squeeze的用法说明
Feb 06 Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
pytorch的梯度计算以及backward方法详解
Jan 10 #Python
Python如何获取Win7,Win10系统缩放大小
Jan 10 #Python
You might like
用PHP制作静态网站的模板框架(二)
2006/10/09 PHP
PHP+iframe模拟Ajax上传文件功能示例
2019/07/02 PHP
JavaScript入门教程(12) js对象化编程
2009/01/31 Javascript
JS判断页面是否出现滚动条的方法
2015/07/17 Javascript
基于jQuery和CSS3制作响应式水平时间轴附源码下载
2015/12/20 Javascript
轻松搞定jQuery.noConflict()
2016/02/15 Javascript
学习Javascript面向对象编程之封装
2016/02/23 Javascript
jquery插件格式实例分析
2016/06/16 Javascript
使用JSON作为函数的参数的优缺点
2016/10/27 Javascript
Ajax异步获取html数据中包含js方法无效的解决方法
2017/02/20 Javascript
Node.js使用gm拼装sprite图片
2017/07/04 Javascript
vue + element-ui实现简洁的导入导出功能
2017/12/22 Javascript
小程序日历控件使用方法详解
2018/12/29 Javascript
jQuery实现表格的增、删、改操作示例
2019/01/27 jQuery
JS数组扁平化(flat)方法总结详解
2019/06/24 Javascript
vue+moment实现倒计时效果
2019/08/26 Javascript
vue.js循环radio的实例
2019/11/07 Javascript
JavaScript多种图形实现代码实例
2020/06/28 Javascript
详解为什么Vue中的v-if和v-for不建议一起用
2021/01/13 Vue.js
[54:27]TNC vs Serenity 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
Python中的map、reduce和filter浅析
2014/04/26 Python
简单介绍Python中的几种数据类型
2016/01/02 Python
Python中static相关知识小结
2018/01/02 Python
Flask框架中request、请求钩子、上下文用法分析
2019/07/23 Python
python 提取文件指定列的方法示例
2019/08/07 Python
Python 使用 docopt 解析json参数文件过程讲解
2019/08/13 Python
Python利用PyExecJS库执行JS函数的案例分析
2019/12/18 Python
Python3+selenium配置常见报错解决方案
2020/08/28 Python
Python HTMLTestRunner如何下载生成报告
2020/09/04 Python
Myprotein加拿大官网:欧洲第一的运动营养品牌
2018/01/06 全球购物
电脑教师的自我评价
2013/12/18 职场文书
计算机维护专业推荐信
2014/02/27 职场文书
淘宝客服工作职责
2014/07/11 职场文书
二年级上册数学教学计划
2015/01/20 职场文书
Python中生成随机数据安全性、多功能性、用途和速度方面进行比较
2022/04/14 Python
python 镜像环境搭建总结
2022/09/23 Python