PyTorch demo: a "50-layer deep generative network" on the CIFAR-10 dataset (runs on a low-end GPU)

Published: 2025-07-24
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import os
import time
from sklearn.metrics.pairwise import cosine_similarity

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Root directory for dataset downloads
data_root = "./dataset/"

# Define data preprocessing
transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
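# Note: Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) maps pixel values from [0, 1]
# to [-1, 1], matching the Tanh output range of the decoder defined below.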

# Load the CIFAR-10 dataset
train_set = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform)

batch_size = 64  # batch size that fits in about 3.5 GB of GPU memory

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
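# Note: with num_workers > 0, platforms that spawn DataLoader worker processes
# (e.g. Windows) expect the script body to sit under an
# `if __name__ == "__main__":` guard; on Linux the code runs as-is.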

# Residual block
class ResidualBlock(nn.Module):
  def __init__(self, in_channels, out_channels, stride=1):
    super(ResidualBlock, self).__init__()
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(out_channels)
    self.relu = nn.ReLU(inplace=True)
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(out_channels)
    
    self.shortcut = nn.Sequential()
    if stride != 1 or in_channels != out_channels:
      self.shortcut = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(out_channels)
      )

  def forward(self, x):
    # [B, C_in, H, W] -> [B, C_out, H/stride, W/stride]
    identity = self.shortcut(x)
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)
    out = self.conv2(out)
    out = self.bn2(out)
    out += identity
    out = self.relu(out)
    return out
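
# Example (illustrative): ResidualBlock(32, 64, stride=2) maps
# [B, 32, 32, 32] -> [B, 64, 16, 16]; the 1x1 shortcut projects the identity
# so it can be added to the main branch.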

# 50-layer deep autoencoder
class DeepAutoencoder(nn.Module):
  def __init__(self):
    super(DeepAutoencoder, self).__init__()
    
    # Encoder (25 layers)
    self.encoder = nn.Sequential(
      # Input: [3, 32, 32]
      nn.Conv2d(3, 32, kernel_size=3, padding=1),
      nn.BatchNorm2d(32),
      nn.ReLU(),
      # Residual group 1 (4 layers)
      self._make_layer(32, 32, 2, stride=1),
      # Residual group 2 (4 layers)
      self._make_layer(32, 64, 2, stride=2),  # downsample to 16x16
      # Residual group 3 (4 layers)
      self._make_layer(64, 128, 2, stride=2),  # downsample to 8x8
      # Residual group 4 (4 layers)
      self._make_layer(128, 256, 2, stride=2),  # downsample to 4x4
      # Residual group 5 (4 layers)
      self._make_layer(256, 512, 2, stride=1),
      # Global average pooling
      nn.AdaptiveAvgPool2d(1),  # [512, 1, 1]
      nn.Flatten(),  # [512]
      nn.Linear(512, 128)  # final latent code: [128]
    )
    
    # Decoder (25 layers)
    self.decoder = nn.Sequential(
      # Input: [128]
      nn.Linear(128, 512),
      nn.Unflatten(1, (512, 1, 1)),  # [512, 1, 1]
      # Upsample to 4x4
      nn.ConvTranspose2d(512, 256, kernel_size=4, stride=1),
      nn.BatchNorm2d(256),
      nn.ReLU(),
      # Residual group 1 (4 layers)
      self._make_layer(256, 256, 2, stride=1),
      # Upsample to 8x8
      nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
      nn.BatchNorm2d(128),
      nn.ReLU(),
      # Residual group 2 (4 layers)
      self._make_layer(128, 128, 2, stride=1),
      # Upsample to 16x16
      nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
      nn.BatchNorm2d(64),
      nn.ReLU(),
      # Residual group 3 (4 layers)
      self._make_layer(64, 64, 2, stride=1),
      # Upsample to 32x32
      nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
      nn.BatchNorm2d(32),
      nn.ReLU(),
      # Residual group 4 (4 layers)
      self._make_layer(32, 32, 2, stride=1),
      # Final output layer
      nn.Conv2d(32, 3, kernel_size=3, padding=1),
      nn.Tanh()  # output range [-1, 1]
    )
  
  def _make_layer(self, in_channels, out_channels, num_blocks, stride):
    layers = []
    layers.append(ResidualBlock(in_channels, out_channels, stride))
    for _ in range(1, num_blocks):
      layers.append(ResidualBlock(out_channels, out_channels, stride=1))
    return nn.Sequential(*layers)
  
  def forward(self, x):
    # [B, 3, 32, 32] -> [B, 128]
    latent = self.encoder(x)
    # [B, 128] -> [B, 3, 32, 32]
    reconstructed = self.decoder(latent)
    return reconstructed

# Instantiate the model
model = DeepAutoencoder().to(device)
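
# Optional sanity check (not part of the original post): run a dummy batch
# through the network to confirm the encoder/decoder output shapes.
model.eval()
with torch.no_grad():
  _dummy = torch.randn(2, 3, 32, 32, device=device)
  print("latent shape:", model.encoder(_dummy).shape)  # expected: torch.Size([2, 128])
  print("output shape:", model(_dummy).shape)          # expected: torch.Size([2, 3, 32, 32])
model.train()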

# Loss function
criterion = nn.MSELoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

# Learning rate scheduler: halve the LR when the test loss plateaus for 2 epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)

# Track training and test losses
train_losses = []
test_losses = []
similarities = []  # per-epoch reconstruction similarity on the test set

# Create the results directory
os.makedirs('results', exist_ok=True)

def pltLoss():
  plt.figure(figsize=(10, 5))
  plt.plot(train_losses, label='Training Loss')
  plt.plot(test_losses, label='Test Loss')
  plt.xlabel('Epochs')
  plt.ylabel('Loss')
  plt.title('Training and Test Loss')
  plt.legend()
  plt.grid(True)
  plt.savefig('results/loss_curve.png')
  plt.close()
  
  # Plot the similarity curve separately
  plt.figure(figsize=(10, 5))
  plt.plot(similarities, label='Cosine Similarity', color='green')
  plt.xlabel('Epochs')
  plt.ylabel('Similarity')
  plt.title('Test Set Reconstruction Similarity')
  plt.legend()
  plt.grid(True)
  plt.savefig('results/similarity_curve.png')
  plt.close()

def calcTestDataSet_Rebuildsimilarity():
  model.eval()
  total_similarity = 0.0
  total_samples = 0
  
  with torch.no_grad():
    for images, _ in test_loader:
      images = images.to(device)
      outputs = model(images)
      
      # Flatten the images and their reconstructions
      images_flat = images.view(images.size(0), -1).cpu().numpy()
      outputs_flat = outputs.view(outputs.size(0), -1).cpu().numpy()
      
      # Compute cosine similarity
      batch_similarity = cosine_similarity(images_flat, outputs_flat)
      # Take the diagonal (each sample vs. its own reconstruction)
      total_similarity += np.diag(batch_similarity).sum()
      total_samples += images.size(0)
  
  avg_similarity = total_similarity / total_samples
  similarities.append(avg_similarity)
  print(f"Average Cosine Similarity: {avg_similarity:.4f}")
  return avg_similarity
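
# Optional alternative (a sketch, not used in the training loop below): the same
# per-sample cosine similarity can be computed on the GPU with
# nn.functional.cosine_similarity, avoiding the full NxN sklearn matrix.
def calcTestSimilarity_torch():
  model.eval()
  total_similarity, total_samples = 0.0, 0
  with torch.no_grad():
    for images, _ in test_loader:
      images = images.to(device)
      outputs = model(images)
      sim = nn.functional.cosine_similarity(
        images.view(images.size(0), -1),
        outputs.view(outputs.size(0), -1),
        dim=1)
      total_similarity += sim.sum().item()
      total_samples += images.size(0)
  return total_similarity / total_samples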

def findTopSample():
  model.eval()
  max_loss = 0.0
  worst_sample = None
  worst_recon = None
  worst_label = None
  
  with torch.no_grad():
    for images, labels in test_loader:
      images = images.to(device)
      outputs = model(images)
      
      # Per-sample reconstruction loss
      batch_loss = torch.mean((outputs - images)**2, dim=[1,2,3])
      # Find the sample with the largest loss in this batch
      batch_max_loss, idx = torch.max(batch_loss, 0)
      
      if batch_max_loss > max_loss:
        max_loss = batch_max_loss
        worst_sample = images[idx].cpu()
        worst_recon = outputs[idx].cpu()
        worst_label = labels[idx].item()
  
  # Save the worst sample and its reconstruction
  if worst_sample is not None:
    # Denormalize back to [0, 1]
    worst_sample = (worst_sample * 0.5) + 0.5
    worst_recon = (worst_recon * 0.5) + 0.5
    
    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(worst_sample.permute(1, 2, 0).numpy())
    plt.title(f"Original (Label: {worst_label})")
    plt.axis('off')
    
    plt.subplot(1, 2, 2)
    plt.imshow(worst_recon.permute(1, 2, 0).numpy())
    plt.title("Reconstruction")
    plt.axis('off')
    
    plt.suptitle(f"Worst Sample (Epoch {len(train_losses)})")
    plt.savefig(f'results/worst_sample_epoch_{len(train_losses)}.png')
    plt.close()
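
# Optional extra (a sketch, not called in the training loop below): save a
# side-by-side grid of a few test images (top rows) and their reconstructions
# (bottom rows) with torchvision's save_image utility.
def saveReconGrid(epoch, n=8):
  from torchvision.utils import save_image
  model.eval()
  with torch.no_grad():
    images, _ = next(iter(test_loader))
    images = images[:n].to(device)
    outputs = model(images)
  grid = torch.cat([images, outputs], dim=0) * 0.5 + 0.5  # denormalize to [0, 1]
  save_image(grid, f'results/recon_grid_epoch_{epoch}.png', nrow=n)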

# Training function
def train(epoch):
  model.train()
  running_loss = 0.0
  
  for images, _ in train_loader:
    images = images.to(device)
    optimizer.zero_grad()
    
    # Forward pass
    outputs = model(images)
    loss = criterion(outputs, images)
    
    # Backward pass
    loss.backward()
    optimizer.step()
    
    running_loss += loss.item()
  
  epoch_loss = running_loss / len(train_loader)
  train_losses.append(epoch_loss)
  print(f"Epoch {epoch} Training Loss: {epoch_loss:.4f}")

# Test function
def test(epoch):
  model.eval()
  running_loss = 0.0
  
  with torch.no_grad():
    for images, _ in test_loader:
      images = images.to(device)
      outputs = model(images)
      loss = criterion(outputs, images)
      running_loss += loss.item()
  
  epoch_loss = running_loss / len(test_loader)
  test_losses.append(epoch_loss)
  print(f"Epoch {epoch} Test Loss: {epoch_loss:.4f}")
  return epoch_loss

# Training loop
num_epochs = 10  # training completes in roughly 10 minutes

for epoch in range(1, num_epochs+1):
  start_time = time.time()
  
  train(epoch)
  test_loss = test(epoch)
  scheduler.step(test_loss)
  
  # Compute reconstruction similarity on the test set
  calcTestDataSet_Rebuildsimilarity()
  
  # Find and save the worst-reconstructed sample
  findTopSample()
  
  # Plot loss and similarity curves
  pltLoss()
  
  epoch_time = time.time() - start_time
  print(f"Epoch {epoch} completed in {epoch_time:.2f} seconds\n")

# Save the final model
torch.save(model.state_dict(), 'autoencoder.pth')
print("Training finished!")