pytorch:实现简单的GAN示例(MNIST数据集)


Posted in Python onJanuary 10, 2020

我就废话不多说了,直接上代码吧!

# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
 
import torch
from torch import nn
from torch.autograd import Variable
 
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
 
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置画图的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
 
def show_images(images): # 定义画图工具
  images = np.reshape(images, [images.shape[0], -1])
  sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
  sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
 
  fig = plt.figure(figsize=(sqrtn, sqrtn))
  gs = gridspec.GridSpec(sqrtn, sqrtn)
  gs.update(wspace=0.05, hspace=0.05)
 
  for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))
  return 
  
def preprocess_img(x):
  x = tfs.ToTensor()(x)
  return (x - 0.5) / 0.5
 
def deprocess_img(x):
  return (x + 1.0) / 2.0
 
class ChunkSampler(sampler.Sampler): # 定义一个取样的函数
  """Samples elements sequentially from some offset. 
  Arguments:
    num_samples: # of desired datapoints
    start: offset where we should start selecting from
  """
  def __init__(self, num_samples, start=0):
    self.num_samples = num_samples
    self.start = start
 
  def __iter__(self):
    return iter(range(self.start, self.start + self.num_samples))
 
  def __len__(self):
    return self.num_samples
    
NUM_TRAIN = 50000
NUM_VAL = 5000
 
NOISE_DIM = 96
batch_size = 128
 
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
 
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
 
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可视化图片效果
show_images(imgs)
 
#判别网络
def discriminator():
  net = nn.Sequential(    
      nn.Linear(784, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 1)
    )
  return net
  
#生成网络
def generator(noise_dim=NOISE_DIM):  
  net = nn.Sequential(
    nn.Linear(noise_dim, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 784),
    nn.Tanh()
  )
  return net
  
#判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1
 
bce_loss = nn.BCEWithLogitsLoss()#交叉熵损失函数
 
def discriminator_loss(logits_real, logits_fake): # 判别器的 loss
  size = logits_real.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  false_labels = Variable(torch.zeros(size, 1)).float()
  loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
  return loss
  
def generator_loss(logits_fake): # 生成器的 loss 
  size = logits_fake.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  loss = bce_loss(logits_fake, true_labels)
  return loss
  
# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
  optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
  return optimizer
  
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, 
        noise_size=96, num_epochs=10):
  iter_count = 0
  for epoch in range(num_epochs):
    for x, _ in train_data:
      bs = x.shape[0]
      # 判别网络
      real_data = Variable(x).view(bs, -1) # 真实数据
      logits_real = D_net(real_data) # 判别网络得分
      
      sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均匀分布
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
      logits_fake = D_net(fake_images) # 判别网络得分
 
      d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 loss
      D_optimizer.zero_grad()
      d_total_error.backward()
      D_optimizer.step() # 优化判别网络
      
      # 生成网络
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
 
      gen_logits_fake = D_net(fake_images)
      g_error = generator_loss(gen_logits_fake) # 生成网络的 loss
      G_optimizer.zero_grad()
      g_error.backward()
      G_optimizer.step() # 优化生成网络
 
      if (iter_count % show_every == 0):
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
        imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
        show_images(imgs_numpy[0:16])
        plt.show()
        print()
      iter_count += 1
 
D = discriminator()
G = generator()
 
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
 
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

以上这篇pytorch:实现简单的GAN示例(MNIST数据集)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python cookielib 登录人人网的实现代码
Dec 19 Python
利用Python获取赶集网招聘信息前篇
Apr 18 Python
Django项目中用JS实现加载子页面并传值的方法
May 28 Python
Python3中urlencode和urldecode的用法详解
Jul 23 Python
Python中zip()函数的简单用法举例
Sep 02 Python
在Python中使用MongoEngine操作数据库教程实例
Dec 03 Python
Python3-异步进程回调函数(callback())介绍
May 02 Python
Anaconda3中的Jupyter notebook添加目录插件的实现
May 18 Python
python 3.8.3 安装配置图文教程
May 21 Python
PyCharm2020最新激活码+激活码补丁(亲测最新版PyCharm2020.2激活成功)
Nov 25 Python
python用分数表示矩阵的方法实例
Jan 11 Python
Python数据可视化之绘制柱状图和条形图
May 25 Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
pytorch的梯度计算以及backward方法详解
Jan 10 #Python
Python如何获取Win7,Win10系统缩放大小
Jan 10 #Python
You might like
php中关于codeigniter的xmlrpc的类在进行数据交换时的类型问题
2011/07/03 PHP
PHP调用.NET的WebService 简单实例
2015/03/27 PHP
2007/12/23更新创意无限,简单实用(javascript log)
2007/12/24 Javascript
jqPlot jquery的页面图表绘制工具
2009/07/25 Javascript
简洁短小的 JavaScript IE 浏览器判定代码
2010/03/21 Javascript
JS扩展方法实例分析
2015/04/15 Javascript
用js编写的简单的计算器代码程序
2015/08/04 Javascript
javascript实现显示和隐藏div方法汇总
2015/08/14 Javascript
js传值后台中文出现乱码的解决方法
2016/06/30 Javascript
JS锚点的设置与使用方法
2016/09/05 Javascript
JQuery手速测试小游戏实现思路详解
2016/09/20 Javascript
微信小程序 开发工具快捷键整理
2016/10/31 Javascript
如何快速上手Vuex
2017/02/14 Javascript
JS按钮闪烁功能的实现代码
2017/07/21 Javascript
Node.js使用cookie保持登录的方法
2018/05/11 Javascript
详解vue 数组和对象渲染问题
2018/09/21 Javascript
vue移动端html5页面根据屏幕适配的四种解决方法
2018/10/19 Javascript
vue props对象validator自定义函数实例
2019/11/13 Javascript
Vue+Vuex实现自动登录的知识点详解
2020/03/04 Javascript
[54:26]完美世界DOTA2联赛PWL S3 Forest vs Rebirth 第一场 12.10
2020/12/12 DOTA
在windows下快速搭建web.py开发框架方法
2016/04/22 Python
python实现读取大文件并逐行写入另外一个文件
2018/04/19 Python
python实现飞机大战微信小游戏
2020/03/21 Python
python装饰器简介---这一篇也许就够了(推荐)
2019/04/01 Python
python flask框架实现传数据到js的方法分析
2019/06/11 Python
Python通过yagmail实现发送邮件代码解析
2020/10/27 Python
Sentry错误日志监控使用方法解析
2020/11/12 Python
使用html5 canvas创建太空游戏的示例
2014/05/08 HTML / CSS
日本最大的购物网站:日本乐天市场(Rakuten Ichiba)
2020/11/04 全球购物
神路信息Java面试题目
2013/03/31 面试题
优良学风班申请材料
2014/02/13 职场文书
2014春晚主持词
2014/03/25 职场文书
银行反四风对照检查材料
2014/09/29 职场文书
2014年就业工作总结
2014/11/26 职场文书
于丹讲座视频观后感
2015/06/15 职场文书
描写九月优美句子(39条)
2019/09/11 职场文书