pytorch 实现变分自动编码器的操作


Posted in Python onMay 24, 2021

本来以为自动编码器是很简单的东西,但是也是看了好多资料仍然不太懂它的原理。先把代码记录下来,有时间好好研究。

这个例子是用MNIST数据集生成为例子

# -*- coding: utf-8 -*-
"""
Created on Fri Oct 12 11:42:19 2018
@author: www
""" 
import os 
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image 
im_tfs = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 标准化
])
 
train_set = MNIST('E:\data', transform=im_tfs)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
 
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
 
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20) # mean
        self.fc22 = nn.Linear(400, 20) # var
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)
 
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)
 
    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.FloatTensor(std.size()).normal_()
        if torch.cuda.is_available():
            eps = Variable(eps.cuda())
        else:
            eps = Variable(eps)
        return eps.mul(std).add_(mu)
 
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.tanh(self.fc4(h3))
 
    def forward(self, x):
        mu, logvar = self.encode(x) # 编码
        z = self.reparametrize(mu, logvar) # 重新参数化成正态分布
        return self.decode(z), mu, logvar # 解码,同时输出均值方差 
 
net = VAE() # 实例化网络
if torch.cuda.is_available():
    net = net.cuda()
    
x, _ = train_set[0]
x = x.view(x.shape[0], -1)
if torch.cuda.is_available():
    x = x.cuda()
x = Variable(x)
_, mu, var = net(x) 
print(mu)
 
#可以看到,对于输入,网络可以输出隐含变量的均值和方差,这里的均值方差还没有训练
 
#下面开始训练 
reconstruction_function = nn.MSELoss(size_average=False) 
def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD 
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
 
def to_img(x):
    '''
    定义一个函数将最后的结果转换回图片
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x
 
for e in range(100):
    for im, _ in train_data:
        im = im.view(im.shape[0], -1)
        im = Variable(im)
        if torch.cuda.is_available():
            im = im.cuda()
        recon_im, mu, logvar = net(im)
        loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 将 loss 平均
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
    if (e + 1) % 20 == 0:
        print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.item()))
        save = to_img(recon_im.cpu().data)
        if not os.path.exists('./vae_img'):
            os.mkdir('./vae_img')
        save_image(save, './vae_img/image_{}.png'.format(e + 1))

补充:PyTorch 深度学习快速入门——变分自动编码器

变分编码器是自动编码器的升级版本,其结构跟自动编码器是类似的,也由编码器和解码器构成。

回忆一下,自动编码器有个问题,就是并不能任意生成图片,因为我们没有办法自己去构造隐藏向量,需要通过一张图片输入编码我们才知道得到的隐含向量是什么,这时我们就可以通过变分自动编码器来解决这个问题。

其实原理特别简单,只需要在编码过程给它增加一些限制,迫使其生成的隐含向量能够粗略的遵循一个标准正态分布,这就是其与一般的自动编码器最大的不同。

这样我们生成一张新图片就很简单了,我们只需要给它一个标准正态分布的随机隐含向量,这样通过解码器就能够生成我们想要的图片,而不需要给它一张原始图片先编码。

一般来讲,我们通过 encoder 得到的隐含向量并不是一个标准的正态分布,为了衡量两种分布的相似程度,我们使用 KL divergence,利用其来表示隐含向量与标准正态分布之间差异的 loss,另外一个 loss 仍然使用生成图片与原图片的均方误差来表示。

KL divergence 的公式如下

pytorch 实现变分自动编码器的操作

重参数 为了避免计算 KL divergence 中的积分,我们使用重参数的技巧,不是每次产生一个隐含向量,而是生成两个向量,一个表示均值,一个表示标准差,这里我们默认编码之后的隐含向量服从一个正态分布的之后,就可以用一个标准正态分布先乘上标准差再加上均值来合成这个正态分布,最后 loss 就是希望这个生成的正态分布能够符合一个标准正态分布,也就是希望均值为 0,方差为 1

所以最后我们可以将我们的 loss 定义为下面的函数,由均方误差和 KL divergence 求和得到一个总的 loss

def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD

用 mnist 数据集来简单说明一下变分自动编码器

import os 
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image
 
im_tfs = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 标准化
])
 
train_set = MNIST('./mnist', transform=im_tfs)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
 
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
 
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20) # mean
        self.fc22 = nn.Linear(400, 20) # var
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)
 
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)
 
    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.FloatTensor(std.size()).normal_()
        if torch.cuda.is_available():
            eps = Variable(eps.cuda())
        else:
            eps = Variable(eps)
        return eps.mul(std).add_(mu)
 
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.tanh(self.fc4(h3))
 
    def forward(self, x):
        mu, logvar = self.encode(x) # 编码
        z = self.reparametrize(mu, logvar) # 重新参数化成正态分布
        return self.decode(z), mu, logvar # 解码,同时输出均值方差
 
net = VAE() # 实例化网络
if torch.cuda.is_available():
    net = net.cuda()
x, _ = train_set[0]
x = x.view(x.shape[0], -1)
if torch.cuda.is_available():
    x = x.cuda()
x = Variable(x)
_, mu, var = net(x) 
print(mu) 
 
Variable containing:  Columns 0 to 9  -0.0307 -0.1439 -0.0435  0.3472  0.0368 -0.0339  0.0274 -0.5608  0.0280  0.2742  Columns 10 to 19  -0.6221 -0.0894 -0.0933  0.4241  0.1611  0.3267  0.5755 -0.0237  0.2714 -0.2806 [torch.cuda.FloatTensor of size 1x20 (GPU 0)]

可以看到,对于输入,网络可以输出隐含变量的均值和方差,这里的均值方差还没有训练 下面开始训练

reconstruction_function = nn.MSELoss(size_average=False) 
def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD 
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
 
def to_img(x):
    '''
    定义一个函数将最后的结果转换回图片
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x
 
for e in range(100):
    for im, _ in train_data:
        im = im.view(im.shape[0], -1)
        im = Variable(im)
        if torch.cuda.is_available():
            im = im.cuda()
        recon_im, mu, logvar = net(im)
        loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 将 loss 平均
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
    if (e + 1) % 20 == 0:
        print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.data[0]))
        save = to_img(recon_im.cpu().data)
        if not os.path.exists('./vae_img'):
            os.mkdir('./vae_img')
        save_image(save, './vae_img/image_{}.png'.format(e + 1))
  
epoch: 20, Loss: 61.5803 epoch: 40, Loss: 62.9573 epoch: 60, Loss: 63.4285 epoch: 80, Loss: 64.7138 epoch: 100, Loss: 63.3343

变分自动编码器虽然比一般的自动编码器效果要好,而且也限制了其输出的编码 (code) 的概率分布,但是它仍然是通过直接计算生成图片和原始图片的均方误差来生成 loss,这个方式并不好,生成对抗网络中,我们会讲一讲这种方式计算 loss 的局限性,然后会介绍一种新的训练办法,就是通过生成对抗的训练方式来训练网络而不是直接比较两张图片的每个像素点的均方误差

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python列表推导式的使用方法
Nov 21 Python
Python 获取新浪微博的最新公共微博实例分享
Jul 03 Python
Python的Twisted框架上手前所必须了解的异步编程思想
May 25 Python
python中json格式数据输出的简单实现方法
Oct 31 Python
Python实现的桶排序算法示例
Nov 29 Python
Python 读取指定文件夹下的所有图像方法
Apr 27 Python
Python实现多属性排序的方法
Dec 05 Python
Python +Selenium解决图片验证码登录或注册问题(推荐)
Feb 09 Python
python加密解密库cryptography使用openSSL生成的密匙加密解密
Feb 11 Python
哈工大自然语言处理工具箱之ltp在windows10下的安装使用教程
May 07 Python
PyQt5.6+pycharm配置以及pyinstaller生成exe(小白教程)
Jun 02 Python
Opencv中cv2.floodFill算法的使用
Jun 18 Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
教你利用python实现企业微信发送消息
You might like
PHP分页显示制作详细讲解
2006/10/09 PHP
ThinkPHP惯例配置文件详解
2014/07/14 PHP
启用Csrf后POST数据时出现的400错误
2015/07/05 PHP
PHP文件上传操作实例详解
2016/09/27 PHP
thinkPHP自动验证机制详解
2016/12/05 PHP
js资料toString 方法
2007/03/13 Javascript
防止jQuery ajax Load使用缓存的方法小结
2014/02/22 Javascript
使用jQuery实现更改默认alert框体
2015/04/13 Javascript
JS基于Mootools实现的个性菜单效果代码
2015/10/21 Javascript
JS/jQ实现免费获取手机验证码倒计时效果
2016/06/13 Javascript
Bootstrap3 多选和单选框(checkbox)
2016/12/29 Javascript
js es6系列教程 - 新的类语法实战选项卡(详解)
2017/09/02 Javascript
解析Vue2 dist 目录下各个文件的区别
2017/11/22 Javascript
vue项目动态设置页面title及是否缓存页面的问题
2018/11/08 Javascript
VUE中V-IF条件判断改变元素的样式操作
2020/08/09 Javascript
Python管理Windows服务小脚本
2018/03/12 Python
Python实现快速计算词频功能示例
2018/06/25 Python
用Python实现读写锁的示例代码
2018/11/05 Python
Python学习笔记之抓取某只基金历史净值数据实战案例
2019/06/03 Python
PyQt5 QTable插入图片并动态更新的实例
2019/06/18 Python
Python Django 页面上展示固定的页码数实现代码
2019/08/21 Python
实现Python与STM32通信方式
2019/12/18 Python
Jupyter Notebook远程登录及密码设置操作
2020/04/10 Python
解决python调用自己文件函数/执行函数找不到包问题
2020/06/01 Python
python 如何实现遗传算法
2020/09/22 Python
浅谈盘点5种基于Python生成的个性化语音方法
2021/02/05 Python
美国内衣第一品牌:Hanes(恒适)
2016/07/29 全球购物
BISSELL官网:北美吸尘器第一品牌
2019/03/14 全球购物
俄罗斯外国汽车和国产汽车配件网上商店:Движком
2020/04/19 全球购物
优秀员工推荐信
2014/05/10 职场文书
质量管理标语
2014/06/12 职场文书
“四风”问题整改措施和努力方向
2014/09/20 职场文书
民警群众路线教育实践活动对照检查材料
2014/10/04 职场文书
优秀班集体事迹材料
2014/12/25 职场文书
卡特教练观后感
2015/06/08 职场文书
法制教育主题班会
2015/08/13 职场文书