pytorch:实现简单的GAN示例(MNIST数据集)


Posted in Python onJanuary 10, 2020

我就废话不多说了,直接上代码吧!

# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
 
import torch
from torch import nn
from torch.autograd import Variable
 
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
 
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置画图的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
 
def show_images(images): # 定义画图工具
  images = np.reshape(images, [images.shape[0], -1])
  sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
  sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
 
  fig = plt.figure(figsize=(sqrtn, sqrtn))
  gs = gridspec.GridSpec(sqrtn, sqrtn)
  gs.update(wspace=0.05, hspace=0.05)
 
  for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))
  return 
  
def preprocess_img(x):
  x = tfs.ToTensor()(x)
  return (x - 0.5) / 0.5
 
def deprocess_img(x):
  return (x + 1.0) / 2.0
 
class ChunkSampler(sampler.Sampler): # 定义一个取样的函数
  """Samples elements sequentially from some offset. 
  Arguments:
    num_samples: # of desired datapoints
    start: offset where we should start selecting from
  """
  def __init__(self, num_samples, start=0):
    self.num_samples = num_samples
    self.start = start
 
  def __iter__(self):
    return iter(range(self.start, self.start + self.num_samples))
 
  def __len__(self):
    return self.num_samples
    
NUM_TRAIN = 50000
NUM_VAL = 5000
 
NOISE_DIM = 96
batch_size = 128
 
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
 
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
 
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可视化图片效果
show_images(imgs)
 
#判别网络
def discriminator():
  net = nn.Sequential(    
      nn.Linear(784, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 1)
    )
  return net
  
#生成网络
def generator(noise_dim=NOISE_DIM):  
  net = nn.Sequential(
    nn.Linear(noise_dim, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 784),
    nn.Tanh()
  )
  return net
  
#判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1
 
bce_loss = nn.BCEWithLogitsLoss()#交叉熵损失函数
 
def discriminator_loss(logits_real, logits_fake): # 判别器的 loss
  size = logits_real.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  false_labels = Variable(torch.zeros(size, 1)).float()
  loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
  return loss
  
def generator_loss(logits_fake): # 生成器的 loss 
  size = logits_fake.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  loss = bce_loss(logits_fake, true_labels)
  return loss
  
# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
  optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
  return optimizer
  
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, 
        noise_size=96, num_epochs=10):
  iter_count = 0
  for epoch in range(num_epochs):
    for x, _ in train_data:
      bs = x.shape[0]
      # 判别网络
      real_data = Variable(x).view(bs, -1) # 真实数据
      logits_real = D_net(real_data) # 判别网络得分
      
      sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均匀分布
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
      logits_fake = D_net(fake_images) # 判别网络得分
 
      d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 loss
      D_optimizer.zero_grad()
      d_total_error.backward()
      D_optimizer.step() # 优化判别网络
      
      # 生成网络
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
 
      gen_logits_fake = D_net(fake_images)
      g_error = generator_loss(gen_logits_fake) # 生成网络的 loss
      G_optimizer.zero_grad()
      g_error.backward()
      G_optimizer.step() # 优化生成网络
 
      if (iter_count % show_every == 0):
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
        imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
        show_images(imgs_numpy[0:16])
        plt.show()
        print()
      iter_count += 1
 
D = discriminator()
G = generator()
 
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
 
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

以上这篇pytorch:实现简单的GAN示例(MNIST数据集)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python列表操作实例
Jan 14 Python
详解python如何调用C/C++底层库与互相传值
Aug 10 Python
Python实现更改图片尺寸大小的方法(基于Pillow包)
Sep 19 Python
利用Python脚本生成sitemap.xml的实现方法
Jan 31 Python
Python爬虫获取图片并下载保存至本地的实例
Jun 01 Python
解决pandas使用read_csv()读取文件遇到的问题
Jun 15 Python
Python元组常见操作示例
Feb 19 Python
pandas实现将dataframe满足某一条件的值选出
Jun 12 Python
python实现复制文件到指定目录
Oct 16 Python
win10下安装Anaconda的教程(python环境+jupyter_notebook)
Oct 23 Python
Python 如何操作 SQLite 数据库
Aug 17 Python
Python基于opencv的简单图像轮廓形状识别(全网最简单最少代码)
Jan 28 Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
pytorch的梯度计算以及backward方法详解
Jan 10 #Python
Python如何获取Win7,Win10系统缩放大小
Jan 10 #Python
You might like
PHP中常用的字符串格式化函数总结
2014/11/19 PHP
php获取四位字母和数字的随机数的实现方法
2015/01/09 PHP
PHP配置把错误日志以邮件方式发送方法(Windows系统)
2015/06/23 PHP
PHP生成随机数的方法总结
2018/03/01 PHP
详解PHP队列的实现
2019/03/14 PHP
javascript中类的定义及其方式(《javascript高级程序设计》学习笔记)
2011/07/04 Javascript
前台js改变Session的值(用ajax实现)
2012/12/28 Javascript
js 输出内容到新窗口具体实现代码
2013/05/31 Javascript
通过隐藏iframe实现文件下载的js方法介绍
2014/02/26 Javascript
JQuery 在线引用及测试引用是否成功
2014/06/24 Javascript
Jquery中扩展方法extend使用技巧
2014/08/24 Javascript
JQuery跳出each循环的方法
2015/04/16 Javascript
微信小程序 Toast自定义实例详解
2017/01/20 Javascript
深入浅析JSONAPI在PHP中的应用
2017/12/24 Javascript
浅谈angular表单提交中ng-submit的默认使用方法
2018/09/30 Javascript
webpack结合express实现自动刷新的方法
2019/05/07 Javascript
图解NodeJS实现登录注册功能
2019/09/16 NodeJs
vue.js实现图书管理功能
2019/09/24 Javascript
javascript实现鼠标点击生成文字特效
2019/12/24 Javascript
基于canvasJS在PHP中制作动态图表
2020/05/30 Javascript
Element Rate 评分的使用方法
2020/07/27 Javascript
element-ui tree结构实现增删改自定义功能代码
2020/08/31 Javascript
[52:02]DOTA2-DPC中国联赛 正赛 Phoenix vs Dragon BO3 第二场 2月26日
2021/03/11 DOTA
跟老齐学Python之大话题小函数(1)
2014/10/10 Python
Python3.2中的字符串函数学习总结
2015/04/23 Python
对python3新增的byte类型详解
2018/12/04 Python
几行Python代码爬取3000+上市公司的信息
2019/01/24 Python
python  文件的基本操作 菜中菜功能的实例代码
2019/07/17 Python
简述 Python 的类和对象
2020/08/21 Python
优衣库英国官网:UNIQLO英国
2016/12/25 全球购物
高性能钓鱼服装:Huk Gear
2019/02/20 全球购物
财务会计自荐信范文
2014/02/21 职场文书
高中美术教师事迹材料
2014/08/22 职场文书
对公司的意见和建议
2015/06/04 职场文书
golang实现一个简单的websocket聊天室功能
2021/10/05 Golang
疑《守望先锋2》A测截图泄露 或将推出新模式、新界面
2022/04/03 其他游戏