Training a GAN on the MNIST Dataset

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.dataloader import DataLoader
import matplotlib.pyplot as plt

def gen_img_plot(model, test_input):
    # Generate images from noise and rescale from [-1, 1] (Tanh range) to [0, 1] for display
    prediction = np.squeeze(model(test_input).detach().cpu().numpy()[:16])
    plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow((prediction[i] + 1) / 2, cmap='gray')
        plt.axis('off')
    plt.show()


transform = transforms.Compose([transforms.Resize((28, 28)),
                                transforms.ToTensor(),
                                transforms.Normalize([0.5], [0.5])])  # MNIST is grayscale: one mean, one std

dataset_train = datasets.MNIST(root='./DATA', train=True, download=True, transform=transform)

dataset_test = datasets.MNIST(root='./DATA', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset_train,batch_size=64,shuffle=True)
test_loader = DataLoader(dataset_test,batch_size=64,shuffle=False)
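
It's worth checking that one batch has the expected shape and that Normalize([0.5], [0.5]) maps pixels into [-1, 1], matching the Tanh output of the generator defined below:

# Sanity check on a single batch
imgs, labels = next(iter(train_loader))
print(imgs.shape, imgs.min().item(), imgs.max().item())
# Expected: torch.Size([64, 1, 28, 28]) -1.0 1.0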


class Generator(nn.Module):
    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            # Input: [batch, 100, 1, 1] (latent vector reshaped to a 1x1 feature map)
            nn.ConvTranspose2d(latent_dim, 32, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            # [batch, 32, 4, 4]
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            # [batch, 16, 8, 8]
            nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            # [batch, 8, 16, 16]
            nn.ConvTranspose2d(8, 1, kernel_size=4, stride=2, padding=3, bias=False),
            nn.Tanh()
            # Output: [batch, 1, 28, 28]
        )

    def forward(self, z):
        return self.model(z.view(z.size(0), z.size(1), 1, 1))

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            # Input: [batch, 1, 28, 28]
            nn.Conv2d(1, 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # [batch, 4, 14, 14]
            nn.Conv2d(4, 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(8),
            nn.LeakyReLU(0.2, inplace=True),
            # [batch, 8, 7, 7]
            nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=True),
            # [batch, 16, 4, 4]
            nn.Conv2d(16, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img).view(-1, 1).squeeze(1)

generator = Generator()
discriminator = Discriminator()
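
A quick forward pass verifies the shape annotations in both networks:

# Shape check: noise -> [4, 1, 28, 28] image -> one score per sample
z = torch.randn(4, 100)
fake = generator(z)
print(fake.shape)                  # torch.Size([4, 1, 28, 28])
print(discriminator(fake).shape)   # torch.Size([4])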


G_optimizer = torch.optim.Adam(generator.parameters(),lr=0.0001)
D_optimizer = torch.optim.Adam(discriminator.parameters(),lr=0.0002)


criterion = torch.nn.BCELoss()

num_epoch = 100

G_loss_save = []
D_loss_save = []
for epoch in range(num_epoch):
    G_epoch_loss = 0
    D_epoch_loss = 0
    count = len(train_loader)
    for i, (img,_) in enumerate(train_loader):
        size = img.size(0)
        # Sample random noise
        noise = torch.randn(size, 100)
        # Generate fake images from the noise
        output_fake = generator(noise)
        # Discriminator score on the fake samples (detached so no gradient flows into G)
        fake_score = discriminator(output_fake.detach())
        # Loss pushing fake scores toward 0
        D_fake_loss = criterion(fake_score, torch.zeros_like(fake_score))
        # Discriminator score on the real samples
        real_score = discriminator(img)
        # Loss pushing real scores toward 1
        D_real_loss = criterion(real_score, torch.ones_like(real_score))
        D_loss = D_fake_loss + D_real_loss
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # Train the generator: it wants the discriminator to output 1 on fakes
        fake_G_score = discriminator(output_fake)
        G_fake_loss = criterion(fake_G_score, torch.ones_like(fake_G_score))
        G_optimizer.zero_grad()
        G_fake_loss.backward()
        G_optimizer.step()
        # Accumulate scalar losses for epoch-level reporting
        G_epoch_loss += G_fake_loss.item()
        D_epoch_loss += D_loss.item()
    G_epoch_loss /= count
    D_epoch_loss /= count

    G_loss_save.append(G_epoch_loss)
    D_loss_save.append(D_epoch_loss)
    print('Epoch:[%d/%d] | G_loss:%.3f | D_loss:%.3f' % (epoch + 1, num_epoch, G_epoch_loss, D_epoch_loss))
    test_input = torch.randn(64, 100)
    gen_img_plot(generator, test_input)
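
After training, the per-epoch losses collected in G_loss_save and D_loss_save can be plotted to inspect the adversarial balance:

# Plot the loss curves gathered during training
plt.figure()
plt.plot(G_loss_save, label='G_loss')
plt.plot(D_loss_save, label='D_loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()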

Results after 50 training epochs: (sample image omitted)
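
To reuse the trained generator without retraining, its weights can be saved and loaded back; the filename generator_mnist.pth below is an arbitrary choice:

# Save the generator weights (path is arbitrary)
torch.save(generator.state_dict(), 'generator_mnist.pth')

# Reload into a fresh Generator and draw a new sample grid
G = Generator()
G.load_state_dict(torch.load('generator_mnist.pth'))
G.eval()
gen_img_plot(G, torch.randn(16, 100))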

