PyTorch Implementation of a Variational Autoencoder (VAE)

Published: 2025-08-14

1. Importing the Third-Party Libraries

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # use a font that supports Chinese characters
plt.rcParams['axes.unicode_minus'] = False   # display minus signs correctly
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset,DataLoader

2. Preparing the Handwritten Digit Dataset

# Handwritten digit dataset
class MINISTDataset(Dataset):
    def __init__(self,files,root_dir,transform=None):
        self.files=files
        self.root_dir=root_dir
        self.transform=transform
        self.labels=[]
        for f in files:
            parts=f.split("_")
            p=parts[2].split(".")[0]
            self.labels.append(int(p))

    def __len__(self):
        return len(self.files)

    def __getitem__(self,idx):
        img_path=os.path.join(self.root_dir,self.files[idx])
        img=Image.open(img_path).convert("L")

        if self.transform:
            img=self.transform(img)

        label=self.labels[idx]
        return img,label
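
The label is parsed from the file name in __init__, so the dataset assumes an underscore-separated naming scheme in which the digit label is the third field before the extension. A minimal illustration with a hypothetical file name (the actual names in minist_train may differ):

# Hypothetical example of the assumed naming scheme: <prefix>_<index>_<label>.jpg
f = "mnist_00001_7.jpg"
label = int(f.split("_")[2].split(".")[0])  # -> 7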

3. PyTorch Code for the VAE Model

# Encoder
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=nn.Sequential(
            nn.Conv2d(1,10,kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.conv2=nn.Sequential(
            nn.Conv2d(10,20,kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.fc1=nn.Linear(320,160)
        self.fc21=nn.Linear(160,80)  # mean of the latent Gaussian
        self.fc22=nn.Linear(160,80)  # log-variance of the latent Gaussian
        self.relu=nn.ReLU()

    def forward(self,x):
        batch_size=x.size(0)
        x=self.conv1(x)
        x=self.conv2(x)
        x=x.view(batch_size,-1)
        h=self.relu(self.fc1(x))
        mu=self.fc21(h)
        log_var=self.fc22(h)
        return mu,log_var
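
The 320 input features of fc1 follow from the convolutional stack: a 28×28 input becomes 24×24 after the first 5×5 convolution, 12×12 after pooling, 8×8 after the second convolution, and 4×4 after the second pooling, so each image is flattened to 20 × 4 × 4 = 320 values. A quick sanity check of those shapes, assuming the Encoder class above:

# Illustrative shape check for the encoder defined above
enc = Encoder()
x = torch.randn(2, 1, 28, 28)
feat = enc.conv2(enc.conv1(x))
print(feat.shape)               # torch.Size([2, 20, 4, 4]) -> 20*4*4 = 320 after flattening
mu, log_var = enc(x)
print(mu.shape, log_var.shape)  # both torch.Size([2, 80])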

# Decoder
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.main=nn.Sequential(
            nn.Linear(80,160),
            nn.ReLU(),
            nn.Linear(160,320),
            nn.ReLU(),
            nn.Linear(320,28*28),
            nn.Sigmoid()
        )

    def forward(self,z):
        return self.main(z)
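
The decoder maps an 80-dimensional latent vector back to a flattened 28×28 image; the final Sigmoid keeps every pixel in [0, 1], which is what the BCELoss used in the main program expects. A minimal check, assuming the Decoder class above:

# Illustrative check: a random latent decodes to a flattened 28x28 image with values in [0, 1]
dec = Decoder()
out = dec(torch.randn(4, 80))
print(out.shape)                                      # torch.Size([4, 784])
print(out.min().item() >= 0, out.max().item() <= 1)   # True True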

# Variational autoencoder
class VAE(nn.Module):
    def __init__(self,encoder,decoder):
        super().__init__()
        self.encoder=encoder
        self.decoder=decoder

    # Reparameterization trick
    def reparameterize(self,mu,log_var):
        std=torch.exp(0.5*log_var)  # standard deviation recovered from the log-variance
        eps=torch.randn_like(std)   # noise sampled from a standard normal distribution
        z=mu+eps*std  # reparameterized sample
        return z

    def forward(self,x):
        mu,log_var=self.encoder(x)
        z=self.reparameterize(mu,log_var)
        return self.decoder(z),mu,log_var
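
The reparameterization z = mu + sigma * eps, with eps ~ N(0, I), keeps the sampling step differentiable with respect to mu and log_var. The training objective combines a reconstruction loss with the KL divergence between the approximate posterior N(mu, sigma^2) and the standard normal prior, which for a diagonal Gaussian has the closed form

    KL( N(mu, sigma^2) || N(0, I) ) = -0.5 * sum_j ( 1 + log sigma_j^2 - mu_j^2 - sigma_j^2 )

This is the expression the main program evaluates from mu and logvar; it averages the quantity over the batch and latent dimensions with torch.mean instead of summing, and weights the term by 0.1.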

4. Main Program

if __name__=="__main__":

    # Resize the images and convert them to tensors with values in [0, 1]
    transform=transforms.Compose([
        transforms.Resize((28,28)),
        transforms.ToTensor()
    ])

    # Paths
    base_dir='C:\\Users\\Administrator\\PycharmProjects\\CNN'
    train_dir=os.path.join(base_dir,"minist_train")

    # Collect the image file names in the folder
    train_files=[f for f in os.listdir(train_dir) if f.endswith('.jpg')]

    # Build the dataset and the data loader
    train_dataset=MINISTDataset(train_files,train_dir,transform=transform)
    train_loader=DataLoader(train_dataset,batch_size=64,shuffle=True)

    # Hyperparameters
    num_epochs=50
    lr=0.001

    # Model initialization
    encoder=Encoder()
    decoder=Decoder()
    vae=VAE(encoder,decoder)
    criterion=nn.BCELoss()
    optimizer=optim.Adam(vae.parameters(),lr=lr,betas=(0.5,0.999))

    # Record the average loss of each epoch
    epoch_loss=[]

    for epoch in range(num_epochs):
        total_loss=0.0

        for data in train_loader:
            images,_=data
            #images=images.view(images.size(0),-1)

            optimizer.zero_grad()

            outputs,mu,logvar=vae(images)

            # Reconstruction loss and KL divergence
            reconstruction_loss=criterion(outputs,images.view(images.size(0),-1))
            kl_divergence=-0.5*torch.mean(1+logvar-mu.pow(2)-logvar.exp())

            loss=reconstruction_loss+0.1*kl_divergence

            loss.backward()
            optimizer.step()

            total_loss+=loss.item()

        avg_loss=total_loss/len(train_loader)
        epoch_loss.append(avg_loss)

        print("Epoch",epoch,"  Loss:",avg_loss)

        # Generate and save sample images every 5 epochs
        with torch.no_grad():
            if (epoch+1)%5==0:
                z=torch.randn(9,80)
                plt.figure(figsize=(9,9))
                for i in range(9):
                    plt.subplot(3,3,i+1)
                    plt.imshow(decoder(z[i]).view(28,28),cmap="gray")
                    plt.axis("off")
                name=f"vae_gen_img_{epoch}.jpg"
                gen_name=os.path.join("C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_img",name)
                plt.savefig(gen_name,dpi=300)
                plt.close()

    # Plot the loss curve
    plt.figure(figsize=(12,6))
    plt.plot(epoch_loss,color="tomato")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("损失函数曲线图")
    plt.legend()
    plt.grid()
    plt.savefig("C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_loss.jpg")
    plt.close()
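
After training, the weights can be saved so that new digits can be generated later without retraining. A minimal sketch, assuming the vae object from the loop above and a hypothetical output file name:

# Hypothetical example: persist the trained weights, then reload them and sample new digits
torch.save(vae.state_dict(), "vae_mnist.pth")

vae_restored = VAE(Encoder(), Decoder())
vae_restored.load_state_dict(torch.load("vae_mnist.pth"))
vae_restored.eval()
with torch.no_grad():
    samples = vae_restored.decoder(torch.randn(9, 80)).view(-1, 28, 28)  # nine 28x28 images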

5. Results

5.1 Loss Curve

5.2 Generated Images

Only a subset of the generated images is shown here.

vae_gen_img_4.jpg

vae_gen_img_29.jpg

vae_gen_img_49.jpg

6. Complete VAE Code

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # use a font that supports Chinese characters
plt.rcParams['axes.unicode_minus'] = False   # display minus signs correctly
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset,DataLoader

# Handwritten digit dataset
class MINISTDataset(Dataset):
    def __init__(self,files,root_dir,transform=None):
        self.files=files
        self.root_dir=root_dir
        self.transform=transform
        self.labels=[]
        for f in files:
            parts=f.split("_")
            p=parts[2].split(".")[0]
            self.labels.append(int(p))

    def __len__(self):
        return len(self.files)

    def __getitem__(self,idx):
        img_path=os.path.join(self.root_dir,self.files[idx])
        img=Image.open(img_path).convert("L")

        if self.transform:
            img=self.transform(img)

        label=self.labels[idx]
        return img,label

# Encoder
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=nn.Sequential(
            nn.Conv2d(1,10,kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.conv2=nn.Sequential(
            nn.Conv2d(10,20,kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.fc1=nn.Linear(320,160)
        self.fc21=nn.Linear(160,80)  # mean of the latent Gaussian
        self.fc22=nn.Linear(160,80)  # log-variance of the latent Gaussian
        self.relu=nn.ReLU()

    def forward(self,x):
        batch_size=x.size(0)
        x=self.conv1(x)
        x=self.conv2(x)
        x=x.view(batch_size,-1)
        h=self.relu(self.fc1(x))
        mu=self.fc21(h)
        log_var=self.fc22(h)
        return mu,log_var

# Decoder
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.main=nn.Sequential(
            nn.Linear(80,160),
            nn.ReLU(),
            nn.Linear(160,320),
            nn.ReLU(),
            nn.Linear(320,28*28),
            nn.Sigmoid()
        )

    def forward(self,z):
        return self.main(z)

# Variational autoencoder
class VAE(nn.Module):
    def __init__(self,encoder,decoder):
        super().__init__()
        self.encoder=encoder
        self.decoder=decoder

    # Reparameterization trick
    def reparameterize(self,mu,log_var):
        std=torch.exp(0.5*log_var)  # standard deviation recovered from the log-variance
        eps=torch.randn_like(std)   # noise sampled from a standard normal distribution
        z=mu+eps*std  # reparameterized sample
        return z

    def forward(self,x):
        mu,log_var=self.encoder(x)
        z=self.reparameterize(mu,log_var)
        return self.decoder(z),mu,log_var

if __name__=="__main__":

    # Resize the images and convert them to tensors with values in [0, 1]
    transform=transforms.Compose([
        transforms.Resize((28,28)),
        transforms.ToTensor()
    ])

    # Paths
    base_dir='C:\\Users\\Administrator\\PycharmProjects\\CNN'
    train_dir=os.path.join(base_dir,"minist_train")

    # Collect the image file names in the folder
    train_files=[f for f in os.listdir(train_dir) if f.endswith('.jpg')]

    # Build the dataset and the data loader
    train_dataset=MINISTDataset(train_files,train_dir,transform=transform)
    train_loader=DataLoader(train_dataset,batch_size=64,shuffle=True)

    # Hyperparameters
    num_epochs=50
    lr=0.001

    # Model initialization
    encoder=Encoder()
    decoder=Decoder()
    vae=VAE(encoder,decoder)
    criterion=nn.BCELoss()
    optimizer=optim.Adam(vae.parameters(),lr=lr,betas=(0.5,0.999))

    # Record the average loss of each epoch
    epoch_loss=[]

    for epoch in range(num_epochs):
        total_loss=0.0

        for data in train_loader:
            images,_=data
            #images=images.view(images.size(0),-1)

            optimizer.zero_grad()

            outputs,mu,logvar=vae(images)

            # Reconstruction loss and KL divergence
            reconstruction_loss=criterion(outputs,images.view(images.size(0),-1))
            kl_divergence=-0.5*torch.mean(1+logvar-mu.pow(2)-logvar.exp())

            loss=reconstruction_loss+0.1*kl_divergence

            loss.backward()
            optimizer.step()

            total_loss+=loss.item()

        avg_loss=total_loss/len(train_loader)
        epoch_loss.append(avg_loss)

        print("Epoch",epoch,"  Loss:",avg_loss)

        # Generate and save sample images every 5 epochs
        with torch.no_grad():
            if (epoch+1)%5==0:
                z=torch.randn(9,80)
                plt.figure(figsize=(9,9))
                for i in range(9):
                    plt.subplot(3,3,i+1)
                    plt.imshow(decoder(z[i]).view(28,28),cmap="gray")
                    plt.axis("off")
                name=f"vae_gen_img_{epoch}.jpg"
                gen_name=os.path.join("C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_img",name)
                plt.savefig(gen_name,dpi=300)
                plt.close()

    # Plot the loss curve
    plt.figure(figsize=(12,6))
    plt.plot(epoch_loss,color="tomato")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("损失函数曲线图")
    plt.legend()
    plt.grid()
    plt.savefig("C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_loss.jpg")
    plt.close()

