pytorch 实现变分自动编码器的操作


Posted in Python onMay 24, 2021

本来以为自动编码器是很简单的东西,但是也是看了好多资料仍然不太懂它的原理。先把代码记录下来,有时间好好研究。

这个例子是用MNIST数据集生成为例子

# -*- coding: utf-8 -*-
"""
Created on Fri Oct 12 11:42:19 2018
@author: www
""" 
import os 
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image 
im_tfs = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 标准化
])
 
train_set = MNIST('E:\data', transform=im_tfs)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
 
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
 
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20) # mean
        self.fc22 = nn.Linear(400, 20) # var
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)
 
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)
 
    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.FloatTensor(std.size()).normal_()
        if torch.cuda.is_available():
            eps = Variable(eps.cuda())
        else:
            eps = Variable(eps)
        return eps.mul(std).add_(mu)
 
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.tanh(self.fc4(h3))
 
    def forward(self, x):
        mu, logvar = self.encode(x) # 编码
        z = self.reparametrize(mu, logvar) # 重新参数化成正态分布
        return self.decode(z), mu, logvar # 解码,同时输出均值方差 
 
net = VAE() # 实例化网络
if torch.cuda.is_available():
    net = net.cuda()
    
x, _ = train_set[0]
x = x.view(x.shape[0], -1)
if torch.cuda.is_available():
    x = x.cuda()
x = Variable(x)
_, mu, var = net(x) 
print(mu)
 
#可以看到,对于输入,网络可以输出隐含变量的均值和方差,这里的均值方差还没有训练
 
#下面开始训练 
reconstruction_function = nn.MSELoss(size_average=False) 
def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD 
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
 
def to_img(x):
    '''
    定义一个函数将最后的结果转换回图片
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x
 
for e in range(100):
    for im, _ in train_data:
        im = im.view(im.shape[0], -1)
        im = Variable(im)
        if torch.cuda.is_available():
            im = im.cuda()
        recon_im, mu, logvar = net(im)
        loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 将 loss 平均
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
    if (e + 1) % 20 == 0:
        print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.item()))
        save = to_img(recon_im.cpu().data)
        if not os.path.exists('./vae_img'):
            os.mkdir('./vae_img')
        save_image(save, './vae_img/image_{}.png'.format(e + 1))

补充:PyTorch 深度学习快速入门——变分自动编码器

变分编码器是自动编码器的升级版本,其结构跟自动编码器是类似的,也由编码器和解码器构成。

回忆一下,自动编码器有个问题,就是并不能任意生成图片,因为我们没有办法自己去构造隐藏向量,需要通过一张图片输入编码我们才知道得到的隐含向量是什么,这时我们就可以通过变分自动编码器来解决这个问题。

其实原理特别简单,只需要在编码过程给它增加一些限制,迫使其生成的隐含向量能够粗略的遵循一个标准正态分布,这就是其与一般的自动编码器最大的不同。

这样我们生成一张新图片就很简单了,我们只需要给它一个标准正态分布的随机隐含向量,这样通过解码器就能够生成我们想要的图片,而不需要给它一张原始图片先编码。

一般来讲,我们通过 encoder 得到的隐含向量并不是一个标准的正态分布,为了衡量两种分布的相似程度,我们使用 KL divergence,利用其来表示隐含向量与标准正态分布之间差异的 loss,另外一个 loss 仍然使用生成图片与原图片的均方误差来表示。

KL divergence 的公式如下

pytorch 实现变分自动编码器的操作

重参数 为了避免计算 KL divergence 中的积分,我们使用重参数的技巧,不是每次产生一个隐含向量,而是生成两个向量,一个表示均值,一个表示标准差,这里我们默认编码之后的隐含向量服从一个正态分布的之后,就可以用一个标准正态分布先乘上标准差再加上均值来合成这个正态分布,最后 loss 就是希望这个生成的正态分布能够符合一个标准正态分布,也就是希望均值为 0,方差为 1

所以最后我们可以将我们的 loss 定义为下面的函数,由均方误差和 KL divergence 求和得到一个总的 loss

def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD

用 mnist 数据集来简单说明一下变分自动编码器

import os 
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image
 
im_tfs = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 标准化
])
 
train_set = MNIST('./mnist', transform=im_tfs)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
 
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
 
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20) # mean
        self.fc22 = nn.Linear(400, 20) # var
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)
 
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)
 
    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.FloatTensor(std.size()).normal_()
        if torch.cuda.is_available():
            eps = Variable(eps.cuda())
        else:
            eps = Variable(eps)
        return eps.mul(std).add_(mu)
 
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.tanh(self.fc4(h3))
 
    def forward(self, x):
        mu, logvar = self.encode(x) # 编码
        z = self.reparametrize(mu, logvar) # 重新参数化成正态分布
        return self.decode(z), mu, logvar # 解码,同时输出均值方差
 
net = VAE() # 实例化网络
if torch.cuda.is_available():
    net = net.cuda()
x, _ = train_set[0]
x = x.view(x.shape[0], -1)
if torch.cuda.is_available():
    x = x.cuda()
x = Variable(x)
_, mu, var = net(x) 
print(mu) 
 
Variable containing:  Columns 0 to 9  -0.0307 -0.1439 -0.0435  0.3472  0.0368 -0.0339  0.0274 -0.5608  0.0280  0.2742  Columns 10 to 19  -0.6221 -0.0894 -0.0933  0.4241  0.1611  0.3267  0.5755 -0.0237  0.2714 -0.2806 [torch.cuda.FloatTensor of size 1x20 (GPU 0)]

可以看到,对于输入,网络可以输出隐含变量的均值和方差,这里的均值方差还没有训练 下面开始训练

reconstruction_function = nn.MSELoss(size_average=False) 
def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD 
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
 
def to_img(x):
    '''
    定义一个函数将最后的结果转换回图片
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x
 
for e in range(100):
    for im, _ in train_data:
        im = im.view(im.shape[0], -1)
        im = Variable(im)
        if torch.cuda.is_available():
            im = im.cuda()
        recon_im, mu, logvar = net(im)
        loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 将 loss 平均
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
    if (e + 1) % 20 == 0:
        print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.data[0]))
        save = to_img(recon_im.cpu().data)
        if not os.path.exists('./vae_img'):
            os.mkdir('./vae_img')
        save_image(save, './vae_img/image_{}.png'.format(e + 1))
  
epoch: 20, Loss: 61.5803 epoch: 40, Loss: 62.9573 epoch: 60, Loss: 63.4285 epoch: 80, Loss: 64.7138 epoch: 100, Loss: 63.3343

变分自动编码器虽然比一般的自动编码器效果要好,而且也限制了其输出的编码 (code) 的概率分布,但是它仍然是通过直接计算生成图片和原始图片的均方误差来生成 loss,这个方式并不好,生成对抗网络中,我们会讲一讲这种方式计算 loss 的局限性,然后会介绍一种新的训练办法,就是通过生成对抗的训练方式来训练网络而不是直接比较两张图片的每个像素点的均方误差

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python文件和目录操作函数小结
Jul 11 Python
Python实现删除文件但保留指定文件
Jun 21 Python
Python实现学生成绩管理系统
Apr 05 Python
NumPy 如何生成多维数组的方法
Feb 05 Python
Python函数装饰器实现方法详解
Dec 22 Python
Python对ElasticSearch获取数据及操作
Apr 24 Python
3行Python代码实现图像照片抠图和换底色的方法
Oct 10 Python
python代码实现TSNE降维数据可视化教程
Feb 28 Python
python画环形图的方法
Mar 25 Python
Python识别花卉种类鉴定网络热门植物并自动整理分类
Apr 08 Python
人工智能深度学习OpenAI baselines的使用方法
May 20 Python
Python四款GUI图形界面库介绍
Jun 05 Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
教你利用python实现企业微信发送消息
You might like
小偷PHP+Html+缓存
2006/12/20 PHP
一个图形显示IP的PHP程序代码
2007/10/19 PHP
php中选择什么接口(mysql、mysqli)访问mysql
2013/02/06 PHP
适用于抽奖程序、随机广告的PHP概率算法实例
2014/04/09 PHP
php+jQuery+Ajax实现点赞效果的方法(附源码下载)
2020/07/21 PHP
PHP Filter过滤器全面解析
2016/08/09 PHP
可以把编码转换成 gb2312编码lib.UTF8toGB2312.js
2007/08/21 Javascript
js渐变显示渐变消失示例代码
2013/08/01 Javascript
让table变成exls的示例代码
2014/03/24 Javascript
javascript的理解及经典案例分析
2016/05/20 Javascript
基于jQuery ligerUI实现分页样式
2016/09/18 Javascript
Bootstrap 设置datetimepicker在屏幕上面弹出设置方法
2017/03/21 Javascript
自带气泡提示的vue校验插件(vue-verify-pop)
2017/04/07 Javascript
jQuery+vue.js实现的九宫格拼图游戏完整实例【附源码下载】
2017/09/12 jQuery
微信小程序实现图片预览功能
2018/01/31 Javascript
vue init webpack 建vue项目报错的解决方法
2018/09/29 Javascript
浅析JavaScript异步代码优化
2019/03/18 Javascript
Python中正则表达式的用法实例汇总
2014/08/18 Python
从请求到响应过程中django都做了哪些处理
2018/08/01 Python
Tensorflow实现神经网络拟合线性回归
2019/07/19 Python
python目标检测给图画框,bbox画到图上并保存案例
2020/03/10 Python
如何基于Python爬取隐秘的角落评论
2020/07/02 Python
matplotlib基础绘图命令之imshow的使用
2020/08/13 Python
Jupyter Notebook 安装配置与使用详解
2021/01/06 Python
水上运动奥特莱斯:Wasterports Outlet
2018/08/08 全球购物
巴西独家产品和现场演示购物网站:Shoptime
2019/07/11 全球购物
员工晚婚的请假条
2014/02/08 职场文书
2014社区三八妇女节活动总结
2014/03/01 职场文书
士力架广告词
2014/03/20 职场文书
《画杨桃》教学反思
2014/04/13 职场文书
干部作风整顿自我剖析材料和整改措施
2014/09/18 职场文书
财会专业大学生求职信
2014/09/26 职场文书
幼儿园感谢信
2015/01/21 职场文书
给男朋友的道歉短信
2015/05/12 职场文书
高中团支书竞选稿
2015/11/21 职场文书
教你使用TensorFlow2识别验证码
2021/06/11 Python