VAE学习笔记

发布于:2025-08-05 ⋅ 阅读:(10) ⋅ 点赞:(0)

模型结构:

(m1,m2,m3)是数据经过encoder 得到的编码 

(σ1,σ2,σ3)是控制噪音干扰程度的编码,就是为随机噪音码(e1,e2,e3)分配权重

损失函数2:如果没有对σi 的限制 生成的图片会希望噪音对自身生成图片的干扰越小,于是分配给噪音的权重越小,这样只需要将(σ1,σ2,σ3)赋为接近负无穷大的值就好了,直观上也能看出来在σi=0处取最小

VAE原理:

首先VAE认为 所有数据都是由某个隐藏变量生成的 学会了这个隐藏变量的分布 就可以生成数据。

关键步骤:

Encoder:把输入数据压缩成隐藏变量的分布参数(均值和方差),直接输出固定值会导致生成能力变差 输出分布可以随机采样增加多样性。

重参数化技巧:解决直接采样不可导问题 改用以下方式 。

                                z = μ + σ * ε, 其中 ε ~ N(0, 1)

Decoder:把隐藏变量 z 还原成数据(如生成新图片)。

损失函数:

        重构损失以及KL散度,KL散度主要是限制σ不要跑偏,保证生成多样性。

基础代码实现:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torch.nn.functional as F
from torchvision.utils import save_image



class VAE(nn.Module):
    def __init__(self, input_size, latent_size):
        super(VAE, self).__init__()
        #编码器层
        self.fc1 = nn.Linear(input_size, 512)
        self.fc2 = nn.Linear(512, latent_size)
        self.fc3 = nn.Linear(512, latent_size)

        #解码器层
        self.fc4 = nn.Linear(latent_size, 512)
        self.fc5 = nn.Linear(512, input_size)

    def encode(self, x):
        x = F.relu(self.fc1(x)) #编码器的隐藏表示
        mu = self.fc2(x)
        logvar = self.fc3(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        z = F.relu(self.fc4(z)) #将潜在变量Z解码为重构图像
        return torch.sigmoid(self.fc5(z)) #将隐藏表示映射回输入图像大小 用sigmoid激活 产生重构图像

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        out = self.decode(z)
        return out , mu, logvar

def loss_function(recon_x, x, mu, logvar):
    MSE = F.mse_loss(recon_x, x.view(-1,input_size), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return MSE + KLD

if __name__ == '__main__':
    batch_size = 64
    epochs = 50
    sample_interval = 10
    learning_rate = 1e-3
    input_size = 784
    latent_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    train_dateset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dateset, batch_size=batch_size, shuffle=True)
    model = VAE(input_size, latent_size).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.to(device)
            data = data.view(-1,input_size)
            predict ,mu, logvar = model(data)
            loss = loss_function(predict, data, mu, logvar)
            train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss =train_loss / len(train_loader)
        print('Epoch [{}/{}], Loss: {:.2f}]'.format(epoch + 1, epochs, train_loss))

        if (epoch+1) % sample_interval == 0:
            torch.save(model.state_dict(), f'./VAE{epoch+1}.pth')
            model.eval()
            with torch.no_grad():
                pic_num=10
                sample = torch.randn(pic_num, latent_size).to(device)
                sample_img = model.decode(sample)
                save_image(sample_img.view(pic_num,1,28,28), './sample'+str(pic_num)+'.png' , nrow = int(pic_num/2))


网站公告

今日签到

点亮在社区的每一天
去签到