pytorch 实现变分自动编码器的操作


Posted in Python onMay 24, 2021

本来以为自动编码器是很简单的东西,但是也是看了好多资料仍然不太懂它的原理。先把代码记录下来,有时间好好研究。

这个例子是用MNIST数据集生成为例子

# -*- coding: utf-8 -*-
"""
Created on Fri Oct 12 11:42:19 2018
@author: www
""" 
import os 
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image 
im_tfs = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 标准化
])
 
train_set = MNIST('E:\data', transform=im_tfs)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
 
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
 
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20) # mean
        self.fc22 = nn.Linear(400, 20) # var
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)
 
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)
 
    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.FloatTensor(std.size()).normal_()
        if torch.cuda.is_available():
            eps = Variable(eps.cuda())
        else:
            eps = Variable(eps)
        return eps.mul(std).add_(mu)
 
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.tanh(self.fc4(h3))
 
    def forward(self, x):
        mu, logvar = self.encode(x) # 编码
        z = self.reparametrize(mu, logvar) # 重新参数化成正态分布
        return self.decode(z), mu, logvar # 解码,同时输出均值方差 
 
net = VAE() # 实例化网络
if torch.cuda.is_available():
    net = net.cuda()
    
x, _ = train_set[0]
x = x.view(x.shape[0], -1)
if torch.cuda.is_available():
    x = x.cuda()
x = Variable(x)
_, mu, var = net(x) 
print(mu)
 
#可以看到,对于输入,网络可以输出隐含变量的均值和方差,这里的均值方差还没有训练
 
#下面开始训练 
reconstruction_function = nn.MSELoss(size_average=False) 
def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD 
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
 
def to_img(x):
    '''
    定义一个函数将最后的结果转换回图片
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x
 
for e in range(100):
    for im, _ in train_data:
        im = im.view(im.shape[0], -1)
        im = Variable(im)
        if torch.cuda.is_available():
            im = im.cuda()
        recon_im, mu, logvar = net(im)
        loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 将 loss 平均
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
    if (e + 1) % 20 == 0:
        print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.item()))
        save = to_img(recon_im.cpu().data)
        if not os.path.exists('./vae_img'):
            os.mkdir('./vae_img')
        save_image(save, './vae_img/image_{}.png'.format(e + 1))

补充:PyTorch 深度学习快速入门——变分自动编码器

变分编码器是自动编码器的升级版本,其结构跟自动编码器是类似的,也由编码器和解码器构成。

回忆一下,自动编码器有个问题,就是并不能任意生成图片,因为我们没有办法自己去构造隐藏向量,需要通过一张图片输入编码我们才知道得到的隐含向量是什么,这时我们就可以通过变分自动编码器来解决这个问题。

其实原理特别简单,只需要在编码过程给它增加一些限制,迫使其生成的隐含向量能够粗略的遵循一个标准正态分布,这就是其与一般的自动编码器最大的不同。

这样我们生成一张新图片就很简单了,我们只需要给它一个标准正态分布的随机隐含向量,这样通过解码器就能够生成我们想要的图片,而不需要给它一张原始图片先编码。

一般来讲,我们通过 encoder 得到的隐含向量并不是一个标准的正态分布,为了衡量两种分布的相似程度,我们使用 KL divergence,利用其来表示隐含向量与标准正态分布之间差异的 loss,另外一个 loss 仍然使用生成图片与原图片的均方误差来表示。

KL divergence 的公式如下

pytorch 实现变分自动编码器的操作

重参数 为了避免计算 KL divergence 中的积分,我们使用重参数的技巧,不是每次产生一个隐含向量,而是生成两个向量,一个表示均值,一个表示标准差,这里我们默认编码之后的隐含向量服从一个正态分布的之后,就可以用一个标准正态分布先乘上标准差再加上均值来合成这个正态分布,最后 loss 就是希望这个生成的正态分布能够符合一个标准正态分布,也就是希望均值为 0,方差为 1

所以最后我们可以将我们的 loss 定义为下面的函数,由均方误差和 KL divergence 求和得到一个总的 loss

def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD

用 mnist 数据集来简单说明一下变分自动编码器

import os 
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image
 
im_tfs = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 标准化
])
 
train_set = MNIST('./mnist', transform=im_tfs)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
 
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
 
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20) # mean
        self.fc22 = nn.Linear(400, 20) # var
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)
 
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)
 
    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.FloatTensor(std.size()).normal_()
        if torch.cuda.is_available():
            eps = Variable(eps.cuda())
        else:
            eps = Variable(eps)
        return eps.mul(std).add_(mu)
 
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.tanh(self.fc4(h3))
 
    def forward(self, x):
        mu, logvar = self.encode(x) # 编码
        z = self.reparametrize(mu, logvar) # 重新参数化成正态分布
        return self.decode(z), mu, logvar # 解码,同时输出均值方差
 
net = VAE() # 实例化网络
if torch.cuda.is_available():
    net = net.cuda()
x, _ = train_set[0]
x = x.view(x.shape[0], -1)
if torch.cuda.is_available():
    x = x.cuda()
x = Variable(x)
_, mu, var = net(x) 
print(mu) 
 
Variable containing:  Columns 0 to 9  -0.0307 -0.1439 -0.0435  0.3472  0.0368 -0.0339  0.0274 -0.5608  0.0280  0.2742  Columns 10 to 19  -0.6221 -0.0894 -0.0933  0.4241  0.1611  0.3267  0.5755 -0.0237  0.2714 -0.2806 [torch.cuda.FloatTensor of size 1x20 (GPU 0)]

可以看到,对于输入,网络可以输出隐含变量的均值和方差,这里的均值方差还没有训练 下面开始训练

reconstruction_function = nn.MSELoss(size_average=False) 
def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD 
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
 
def to_img(x):
    '''
    定义一个函数将最后的结果转换回图片
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x
 
for e in range(100):
    for im, _ in train_data:
        im = im.view(im.shape[0], -1)
        im = Variable(im)
        if torch.cuda.is_available():
            im = im.cuda()
        recon_im, mu, logvar = net(im)
        loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 将 loss 平均
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
    if (e + 1) % 20 == 0:
        print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.data[0]))
        save = to_img(recon_im.cpu().data)
        if not os.path.exists('./vae_img'):
            os.mkdir('./vae_img')
        save_image(save, './vae_img/image_{}.png'.format(e + 1))
  
epoch: 20, Loss: 61.5803 epoch: 40, Loss: 62.9573 epoch: 60, Loss: 63.4285 epoch: 80, Loss: 64.7138 epoch: 100, Loss: 63.3343

变分自动编码器虽然比一般的自动编码器效果要好,而且也限制了其输出的编码 (code) 的概率分布,但是它仍然是通过直接计算生成图片和原始图片的均方误差来生成 loss,这个方式并不好,生成对抗网络中,我们会讲一讲这种方式计算 loss 的局限性,然后会介绍一种新的训练办法,就是通过生成对抗的训练方式来训练网络而不是直接比较两张图片的每个像素点的均方误差

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python迭代用法实例教程
Sep 08 Python
Python的ORM框架中SQLAlchemy库的查询操作的教程
Apr 25 Python
利用python模拟实现POST请求提交图片的方法
Jul 25 Python
Python初学者常见错误详解
Jul 02 Python
Pycharm使用之设置代码字体大小和颜色主题的教程
Jul 12 Python
Python人工智能之路 之PyAudio 实现录音 自动化交互实现问答
Aug 13 Python
python实现的登录与提交表单数据功能示例
Sep 25 Python
快速解决jupyter启动卡死的问题
Apr 10 Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 Python
python实现人工蜂群算法
Sep 18 Python
基于Python模拟浏览器发送http请求
Nov 06 Python
最新pycharm安装教程
Nov 18 Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
教你利用python实现企业微信发送消息
You might like
PHP 中dirname(_file_)讲解
2007/03/18 PHP
php 远程关机操作的代码
2008/12/05 PHP
php 归并排序 数组交集
2011/05/10 PHP
两种php给图片加水印的实现代码
2020/04/18 PHP
thinkPHP删除前弹出确认框的简单实现方法
2016/05/16 PHP
Javascript String.replace的妙用
2009/09/08 Javascript
JavaScrip实现PHP print_r的数功能(三种方法)
2013/11/12 Javascript
用Javascript获取页面元素的具体位置
2013/12/09 Javascript
原生 JS Ajax,GET和POST 请求实例代码
2016/06/08 Javascript
jQuery如何解决IE输入框不能输入的问题
2016/10/08 Javascript
不使用script导入js文件的几种方法
2016/10/27 Javascript
Nodejs实现多房间简易聊天室功能
2017/06/20 NodeJs
js弹性势能动画之抛物线运动实例详解
2017/07/27 Javascript
通过webpack引入第三方库的方法
2018/07/20 Javascript
如何使用electron-builder及electron-updater给项目配置自动更新
2018/12/24 Javascript
JS实现可控制的进度条
2020/03/25 Javascript
node.js +mongdb实现登录功能
2020/06/18 Javascript
使用Python操作MySQL的一些基本方法
2015/08/16 Python
python删除文本中行数标签的方法
2018/05/31 Python
Python地图绘制实操详解
2019/03/04 Python
Python中断多重循环的思路总结
2019/10/04 Python
Python如何操作office实现自动化及win32com.client的运用
2020/04/01 Python
python中tab键是什么意思
2020/06/18 Python
详解pyqt5的UI中嵌入matplotlib图形并实时刷新(挖坑和填坑)
2020/08/07 Python
联想墨西哥官方网站:Lenovo墨西哥
2016/08/17 全球购物
南京迈特望C/C++面试题
2012/07/09 面试题
用Java语言将一个键盘输入的数字转化成中文输出
2013/01/25 面试题
淘宝客服专员岗位职责
2014/04/11 职场文书
财务会计专业自荐书
2014/06/30 职场文书
2014年科室工作总结范文
2014/12/19 职场文书
党员转正申请报告
2015/05/15 职场文书
小型婚礼主持词
2015/06/30 职场文书
初中数学课堂教学反思
2016/02/17 职场文书
创业计划书之蛋糕店
2019/08/29 职场文书
MySQL Innodb索引机制详细介绍
2021/11/23 MySQL
笔记本自带的win11如何跳过联网激活?
2022/04/20 数码科技