pytorch:实现简单的GAN示例(MNIST数据集)


Posted in Python onJanuary 10, 2020

我就废话不多说了,直接上代码吧!

# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
 
import torch
from torch import nn
from torch.autograd import Variable
 
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
 
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置画图的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
 
def show_images(images): # 定义画图工具
  images = np.reshape(images, [images.shape[0], -1])
  sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
  sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
 
  fig = plt.figure(figsize=(sqrtn, sqrtn))
  gs = gridspec.GridSpec(sqrtn, sqrtn)
  gs.update(wspace=0.05, hspace=0.05)
 
  for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))
  return 
  
def preprocess_img(x):
  x = tfs.ToTensor()(x)
  return (x - 0.5) / 0.5
 
def deprocess_img(x):
  return (x + 1.0) / 2.0
 
class ChunkSampler(sampler.Sampler): # 定义一个取样的函数
  """Samples elements sequentially from some offset. 
  Arguments:
    num_samples: # of desired datapoints
    start: offset where we should start selecting from
  """
  def __init__(self, num_samples, start=0):
    self.num_samples = num_samples
    self.start = start
 
  def __iter__(self):
    return iter(range(self.start, self.start + self.num_samples))
 
  def __len__(self):
    return self.num_samples
    
NUM_TRAIN = 50000
NUM_VAL = 5000
 
NOISE_DIM = 96
batch_size = 128
 
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
 
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
 
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可视化图片效果
show_images(imgs)
 
#判别网络
def discriminator():
  net = nn.Sequential(    
      nn.Linear(784, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 1)
    )
  return net
  
#生成网络
def generator(noise_dim=NOISE_DIM):  
  net = nn.Sequential(
    nn.Linear(noise_dim, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 784),
    nn.Tanh()
  )
  return net
  
#判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1
 
bce_loss = nn.BCEWithLogitsLoss()#交叉熵损失函数
 
def discriminator_loss(logits_real, logits_fake): # 判别器的 loss
  size = logits_real.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  false_labels = Variable(torch.zeros(size, 1)).float()
  loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
  return loss
  
def generator_loss(logits_fake): # 生成器的 loss 
  size = logits_fake.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  loss = bce_loss(logits_fake, true_labels)
  return loss
  
# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
  optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
  return optimizer
  
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, 
        noise_size=96, num_epochs=10):
  iter_count = 0
  for epoch in range(num_epochs):
    for x, _ in train_data:
      bs = x.shape[0]
      # 判别网络
      real_data = Variable(x).view(bs, -1) # 真实数据
      logits_real = D_net(real_data) # 判别网络得分
      
      sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均匀分布
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
      logits_fake = D_net(fake_images) # 判别网络得分
 
      d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 loss
      D_optimizer.zero_grad()
      d_total_error.backward()
      D_optimizer.step() # 优化判别网络
      
      # 生成网络
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
 
      gen_logits_fake = D_net(fake_images)
      g_error = generator_loss(gen_logits_fake) # 生成网络的 loss
      G_optimizer.zero_grad()
      g_error.backward()
      G_optimizer.step() # 优化生成网络
 
      if (iter_count % show_every == 0):
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
        imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
        show_images(imgs_numpy[0:16])
        plt.show()
        print()
      iter_count += 1
 
D = discriminator()
G = generator()
 
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
 
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

以上这篇pytorch:实现简单的GAN示例(MNIST数据集)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现对PPT文件进行截图操作的方法
Apr 28 Python
Python中Continue语句的用法的举例详解
May 14 Python
Python中MySQL数据迁移到MongoDB脚本的方法
Apr 28 Python
基于Python代码编辑器的选用(详解)
Sep 13 Python
Python爬虫通过替换http request header来欺骗浏览器实现登录功能
Jan 07 Python
python opencv人脸检测提取及保存方法
Aug 03 Python
对Python定时任务的启动和停止方法详解
Feb 19 Python
使用Python代码实现Linux中的ls遍历目录命令的实例代码
Sep 07 Python
Python如何实现强制数据类型转换
Nov 22 Python
如何使用Python自动生成报表并以邮件发送
Oct 15 Python
Python plt 利用subplot 实现在一张画布同时画多张图
Feb 26 Python
教你如何使用Python实现二叉树结构及三种遍历
Jun 18 Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
pytorch的梯度计算以及backward方法详解
Jan 10 #Python
Python如何获取Win7,Win10系统缩放大小
Jan 10 #Python
You might like
Codeigniter实现处理用户登录验证后的URL跳转
2014/06/12 PHP
Laravel 5.5基于内置的Auth模块实现前后台登陆详解
2017/12/21 PHP
PHP使用SMTP邮件服务器发送邮件示例
2018/08/28 PHP
ExtJS4 组件化编程,动态加载,面向对象,Direct
2011/05/12 Javascript
JavaScript的常见兼容问题及相关解决方法(chrome/IE/firefox)
2013/12/31 Javascript
JavaScript indexOf方法入门实例(计算指定字符在字符串中首次出现的位置)
2014/10/17 Javascript
JS表的模拟方法
2015/02/05 Javascript
JavaScript中使用Math.floor()方法对数字取整
2015/06/15 Javascript
详解AngularJS中自定义过滤器
2015/12/28 Javascript
js实现带三角符的手风琴效果
2017/03/01 Javascript
clipboard在vue中的使用的方法示例
2018/10/19 Javascript
如何基于JavaScript判断图片是否加载完成
2019/12/28 Javascript
解决ant Design中this.props.form.validateFields未执行的问题
2020/10/27 Javascript
[03:42]2014DOTA2西雅图国际邀请赛 Navi战队巡礼
2014/07/07 DOTA
[40:06]DOTA2亚洲邀请赛 4.3 突围赛 Liquid vs VGJ.T 第一场
2018/04/04 DOTA
python统计字符串中指定字符出现次数的方法
2015/04/04 Python
Python编程实战之Oracle数据库操作示例
2017/06/21 Python
python登录并爬取淘宝信息代码示例
2017/12/09 Python
Django 生成登陆验证码代码分享
2017/12/12 Python
python3.5绘制随机漫步图
2018/08/27 Python
python文字和unicode/ascll相互转换函数及简单加密解密实现代码
2019/08/12 Python
python 函数嵌套及多函数共同运行知识点讲解
2020/03/03 Python
python 数据类型强制转换的总结
2021/01/25 Python
New Balance英国官方网站:始于1906年,百年慢跑品牌
2016/12/07 全球购物
英国家庭家具、照明和花园家具购物网站:Furniture123
2018/12/31 全球购物
秋季运动会稿件
2014/01/30 职场文书
总经理司机岗位职责
2014/02/06 职场文书
小学班主任寄语大全
2014/04/04 职场文书
2014年入党积极分子党校培训心得体会
2014/07/08 职场文书
中国梦读书活动总结
2014/07/10 职场文书
2014年个人教学工作总结
2014/12/09 职场文书
学习雷锋精神活动总结
2015/02/06 职场文书
酒店厨房管理制度
2015/08/06 职场文书
党组织关系的介绍信模板
2019/06/21 职场文书
日本读研:怎样写好一篇日本研究计划书?
2019/07/15 职场文书
三好学生竞选稿范文
2019/08/21 职场文书