PyTorch Generative Adversarial Network


# Generative Adversarial Network (GAN)
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyperparameters
latent_size = 64    # dimensionality of the latent space
hidden_size = 256
image_size = 784    # 28 * 28 MNIST images, flattened
num_epochs = 200
batch_size = 100
sample_dir = 'samples'
# Create the sample directory if it does not exist
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
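# Equivalent one-liner: os.makedirs(sample_dir, exist_ok=True)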
# Image processing
# transform = transforms.Compose([
#                 transforms.ToTensor(),
#                 transforms.Normalize(mean=(0.5, 0.5, 0.5),   # 3 for RGB channels
#                                      std=(0.5, 0.5, 0.5))])
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5],   # 1 for greyscale channels
                                     std=[0.5])])
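# With mean=0.5 and std=0.5 the Normalize transform computes (x - 0.5) / 0.5,
# mapping pixel values from [0, 1] to [-1, 1]; this matches the Tanh output
# range of the generator defined below.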
# MNIST dataset
mnist = torchvision.datasets.MNIST(root='./datasets',
                                   train=True,
                                   transform=transform,
                                   download=True)
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size,
                                          shuffle=True)
# Inspect one batch: image shape, value range after normalization, and the labels
for images, labels in data_loader:
    print(images.shape, images.max(), images.min(), torch.unique(labels))
    break
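# Expected output: torch.Size([100, 1, 28, 28]), a max of 1.0 and a min of -1.0
# (pixels are rescaled to [-1, 1] by the transform), and (typically) all ten
# labels 0-9.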
# Discriminator
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())
# Generator
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())
# Device setting
D = D.to(device)
G = G.to(device)
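# Optional sanity check (illustrative): push a few latent vectors through G and
# D and confirm the shapes and output ranges line up with the loss used below.
with torch.no_grad():
    _z = torch.randn(4, latent_size).to(device)
    _fake = G(_z)      # shape (4, 784), values in (-1, 1) from Tanh
    _score = D(_fake)  # shape (4, 1), values in (0, 1) from Sigmoid
    print(_fake.shape, _score.shape)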
# Binary cross-entropy loss and optimizers
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
def denorm(x):
    # Map values from [-1, 1] back to [0, 1] for saving as images
    out = (x + 1) / 2
    return out.clamp(0, 1)  # clamp to [0, 1]
def reset_grad():
    # Clear the gradients accumulated in both optimizers
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()
# Quick notebook-style shape checks (wrapped in print so they also show when
# run as a script)
print(torch.zeros(4, 1).shape)   # torch.Size([4, 1])
print(torch.randn(4, 50).shape)  # torch.Size([4, 50])
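# Illustrative check of denorm: it maps the generator's Tanh range [-1, 1] back
# to [0, 1] so that outputs can be written with save_image.
print(denorm(torch.tensor([-1.0, 0.0, 1.0])))  # tensor([0.0000, 0.5000, 1.0000])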
# A GAN has a generator and a discriminator. The generator takes a latent vector
# as input and outputs an image as a flattened vector. During training, the
# discriminator scores real and generated images; this is binary classification
# where 1 means real and 0 means fake. The discriminator loss is therefore the
# sum of two terms: the loss of the real images against all-ones labels and the
# loss of the generated images against all-zeros labels; backpropagating it
# updates the discriminator's parameters. The discriminator learns to tell real
# from generated, while the generator tries to make its images look real. How is
# "looking real" measured? By feeding the generated images to the discriminator:
# the higher the score, the more realistic the image. So the generator loss is
# the error between all-ones labels and the discriminator's scores on the fakes.
# Note in particular: when the discriminator loss is backpropagated, only the
# discriminator's parameters are updated, and when the generator loss is
# backpropagated, only the generator's parameters are updated.
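# Note on the loop below: d_loss.backward() also propagates gradients into G,
# but reset_grad() clears them before g_optimizer.step(), so in effect only D
# is updated by d_optimizer. A common alternative (not used here) is to detach
# the generated batch when computing the discriminator's fake loss, e.g.
#     d_loss_fake = criterion(D(fake_images.detach()), fake_labels)
# which stops gradients from flowing into the generator at that step.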
# Training
total_step = len(data_loader)  # number of batches per epoch
for epoch in range(num_epochs):
    # Iterate over the batches
    for i, (images, _) in enumerate(data_loader):
        images = images.reshape(batch_size, -1).to(device)
        # Create the labels used later as BCE loss targets (all ones = real, all zeros = fake)
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        # Train the discriminator
        outputs = D(images)                            # discriminator scores on the real images
        d_loss_real = criterion(outputs, real_labels)  # loss on the real images
        real_score = outputs                           # scores on the real images
        # z is a batch of random latent vectors (each image is generated from a
        # vector of size latent_size)
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)                             # generate a batch of fake images
        outputs = D(fake_images)                       # discriminator scores on the generated images
        d_loss_fake = criterion(outputs, fake_labels)  # loss on the generated (fake) images
        fake_score = outputs
        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake             # total discriminator loss
        reset_grad()
        # Backpropagate the discriminator loss
        d_loss.backward()
        # Update the parameters from the gradients
        d_optimizer.step()
        # Train the generator
        # Sample a fresh batch of random latent vectors
        z = torch.randn(batch_size, latent_size).to(device)
        # Generate images with the generator
        fake_images = G(z)
        outputs = D(fake_images)  # discriminator scores on the generated images
        # The generator wants its images to look real, i.e. to minimize the error
        # between the all-ones labels and the discriminator's outputs
        g_loss = criterion(outputs, real_labels)
        # Clear old gradients, backpropagate the generator loss, and update the
        # generator's parameters with g_optimizer
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        # Print a log every 200 batches
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch+1, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))
    # The real images only need to be saved once
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
    # Save a batch of generated images at the end of every epoch
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
# Save the model checkpoints (write each model's parameter state_dict to disk)
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')
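# Sketch of reusing the trained generator later (assumes G is rebuilt with the
# same architecture as above before loading the checkpoint):
#     G.load_state_dict(torch.load('G.ckpt'))
#     G.eval()
#     with torch.no_grad():
#         z = torch.randn(16, latent_size).to(device)
#         samples = denorm(G(z)).reshape(-1, 1, 28, 28)
#         save_image(samples, os.path.join(sample_dir, 'loaded_samples.png'))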

 

Above are the real images; below are the images generated after 200 epochs.

