生成对抗网络(GAN)原理

发布于:2025-05-20 ⋅ 阅读:(18) ⋅ 点赞:(0)

介绍

示例代码

生成对抗网络(Generative Adversarial Network,GAN)是由 Ian Goodfellow 等人在 2014 年提出的一种深度生成模型。它通过两个神经网络之间的博弈(对抗)过程,学习数据的生成分布,从而生成以假乱真的数据(如图像、语音等)。GAN 是近年来生成模型领域的重要突破,广泛应用于图像生成、风格迁移、图像修复等任务中。


一、GAN 的基本结构

GAN 主要由两个部分组成:

1. 生成器(Generator,记作 G)

  • 目标:生成尽可能真实的数据,欺骗判别器。
  • 输入:随机噪声向量(一般从正态分布或均匀分布中采样)
  • 输出:“伪造”的样本,尽可能与真实样本相似。

2. 判别器(Discriminator,记作 D)

  • 目标:判断输入数据是真实的样本还是生成器生成的伪造样本。
  • 输入:真实样本或生成样本
  • 输出:一个概率值,表示输入是“真实”的概率。

二、对抗过程(博弈思想)

GAN 的训练过程是一个零和博弈(min-max game):

  • 生成器试图最小化判别器对生成样本的识别能力;
  • 判别器试图最大化识别真实样本与生成样本的能力。

这个过程可以表示为一个最优化问题:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

其中:

  • p d a t a ( x ) p_{data}(x) pdata(x) 是真实数据的分布;
  • p z ( z ) p_z(z) pz(z) 是生成器输入噪声的分布(如高斯分布);
  • D ( x ) D(x) D(x) 是判别器输出 x 是真实数据的概率;
  • G ( z ) G(z) G(z) 是生成器输出的伪造样本。

三、训练过程

  1. 固定生成器 G,训练判别器 D

    • 给 D 一部分真实样本(标签为 1);
    • 给 D 一部分 G 生成的样本(标签为 0);
    • 通过交叉熵损失训练 D,使其能区分真假样本。
  2. 固定判别器 D,训练生成器 G

    • 通过 G 生成假样本;
    • D 会判断其为假;
    • G 的目标是欺骗 D,即最大化 D ( G ( z ) ) D(G(z)) D(G(z)),让 D 判错;
    • 通常优化的是 log ⁡ D ( G ( z ) ) \log D(G(z)) logD(G(z)) 的反函数,例如 log ⁡ ( 1 − D ( G ( z ) ) ) \log(1 - D(G(z))) log(1D(G(z))) 或更稳定的变体(如使用 feature matching 或 Wasserstein loss)。
  3. 交替训练 D 和 G,直到生成器生成的样本无法被判别器区分为假(判别器输出接近 0.5)。


四、存在的问题与改进方向

1. 模式崩溃(Mode Collapse)

生成器只学会生成一小部分模式样本,导致多样性丢失。

2. 训练不稳定

D 和 G 的能力不均衡、学习率不合适等因素可能导致 GAN 训练震荡或失败。

3. 衡量指标困难

GAN 的损失函数不能很好地反映生成质量,因此通常使用 FID、IS 等指标辅助评估。


五、GAN 的改进与变种

为了克服原始 GAN 的不足,研究人员提出了许多变种:

名称 简介
DCGAN 使用卷积神经网络的 GAN,适合图像数据生成
WGAN Wasserstein GAN,引入 Wasserstein 距离,解决训练不稳定问题
WGAN-GP 在 WGAN 基础上加上梯度惩罚项,提高训练稳定性
CGAN 条件 GAN,可以控制生成样本的类别(如生成特定数字)
CycleGAN 用于图像风格转换(如马<->斑马、夏天<->冬天)
StyleGAN 高质量人脸图像生成的里程碑,支持精细控制生成风格
BigGAN 更大规模、更高分辨率的图像生成模型

六、GAN 的应用

  1. 图像生成(人脸、卡通、合成图像)
  2. 图像风格转换(如照片转油画)
  3. 图像修复与超分辨率重建
  4. 数据增强(医学图像、遥感等领域)
  5. 音频合成、文本生成(结合 RNN、Transformer)
  6. 深度伪造(Deepfake)

下面是一个使用 PyTorch 实现的 DCGAN(Deep Convolutional GAN)完整示例,用于在 MNISTCIFAR-10 数据集上生成图像。与普通 GAN 相比,DCGAN 使用卷积神经网络结构,更适合图像处理,生成效果更清晰、真实。


✅ DCGAN 与普通 GAN 的区别

模块 普通 GAN DCGAN
生成器结构 全连接层(MLP) 转置卷积(ConvTranspose2d)
判别器结构 全连接层 卷积层(Conv2d)
激活函数 ReLU + Sigmoid LeakyReLU + Sigmoid / No sigmoid
正则化 无 / BatchNorm1d BatchNorm2d

✅ DCGAN 示例(基于 MNIST)

🔧 安装依赖

pip install torch torchvision matplotlib

🧠 DCGAN 架构代码(Generator + Discriminator)

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import os

# 超参数
img_size = 64
channels = 1  # 改为 3 可用于 CIFAR-10
latent_dim = 100
batch_size = 128
lr = 0.0002
epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 创建目录
os.makedirs("dcgan_images", exist_ok=True)

# 数据预处理(MNIST 被 resize 成 64x64 )
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True
)

# 生成器(使用转置卷积)
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            # 输入是 latent_dim 向量,输出 1024
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        z = z.view(z.size(0), latent_dim, 1, 1)
        return self.model(z)

# 判别器(使用卷积)
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(channels, 128, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x).view(-1, 1).squeeze(1)

# 初始化模型
G = Generator().to(device)
D = Discriminator().to(device)

# 损失和优化器
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

# 训练 DCGAN
for epoch in range(epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)
        b_size = real_imgs.size(0)

        # 标签
        valid = torch.ones(b_size, device=device)
        fake = torch.zeros(b_size, device=device)

        # ========== 训练判别器 ==========
        optimizer_D.zero_grad()
        real_loss = criterion(D(real_imgs), valid)

        z = torch.randn(b_size, latent_dim, device=device)
        gen_imgs = G(z)
        fake_loss = criterion(D(gen_imgs.detach()), fake)

        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # ========== 训练生成器 ==========
        optimizer_G.zero_grad()
        g_loss = criterion(D(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        if i % 100 == 0:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] "
                  f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

    # 保存生成图像
    with torch.no_grad():
        z = torch.randn(64, latent_dim, device=device)
        gen_imgs = G(z)
        grid = make_grid(gen_imgs, nrow=8, normalize=True)
        save_image(grid, f"dcgan_images/{epoch:03d}.png")

print("DCGAN 训练完成,图像保存在 dcgan_images 文件夹中。")

🧪 使用说明

  • 若想改用 彩色图像(如 CIFAR-10),需:

    • channels = 3
    • 使用 datasets.CIFAR10 替代 MNIST
    • 修改 transforms.Normalize([0.5]*3, [0.5]*3)

网站公告

今日签到

点亮在社区的每一天
去签到