import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import os
import time
from sklearn.metrics.pairwise import cosine_similarity
torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
data_root = "./dataset/"
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_set = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform)
batch_size = 64
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
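# Standard ResNet-style residual block: two 3x3 conv + BatchNorm layers with a
# skip connection; a 1x1 projection shortcut is used whenever the stride or the
# channel count changes the tensor shape.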
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out
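# Convolutional autoencoder built from residual blocks. The encoder downsamples
# 3x32x32 inputs to a 4x4 feature map, global-average-pools it, and projects to a
# 128-dimensional latent vector. The decoder mirrors this with transposed
# convolutions (1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32) and ends in Tanh so outputs
# lie in [-1, 1], the same range as the normalized inputs.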
class DeepAutoencoder(nn.Module):
    def __init__(self):
        super(DeepAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            self._make_layer(32, 32, 2, stride=1),
            self._make_layer(32, 64, 2, stride=2),
            self._make_layer(64, 128, 2, stride=2),
            self._make_layer(128, 256, 2, stride=2),
            self._make_layer(256, 512, 2, stride=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 128)
        )
        self.decoder = nn.Sequential(
            nn.Linear(128, 512),
            nn.Unflatten(1, (512, 1, 1)),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            self._make_layer(256, 256, 2, stride=1),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            self._make_layer(128, 128, 2, stride=1),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            self._make_layer(64, 64, 2, stride=1),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            self._make_layer(32, 32, 2, stride=1),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        layers = []
        layers.append(ResidualBlock(in_channels, out_channels, stride))
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_channels, out_channels, stride=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        latent = self.encoder(x)
        reconstructed = self.decoder(latent)
        return reconstructed
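# Training setup: plain MSE reconstruction loss, Adam with a small weight decay,
# and ReduceLROnPlateau halving the learning rate after 2 epochs without
# improvement in the test loss.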
model = DeepAutoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
train_losses = []
test_losses = []
similarities = []
os.makedirs('results', exist_ok=True)
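# Saves two diagnostic plots per call: the train/test loss curves and the
# test-set cosine-similarity curve, both written to the results/ folder.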
def pltLoss():
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(test_losses, label='Test Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Test Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig('results/loss_curve.png')
    plt.close()

    plt.figure(figsize=(10, 5))
    plt.plot(similarities, label='Cosine Similarity', color='green')
    plt.xlabel('Epochs')
    plt.ylabel('Similarity')
    plt.title('Test Set Reconstruction Similarity')
    plt.legend()
    plt.grid(True)
    plt.savefig('results/similarity_curve.png')
    plt.close()
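# Measures reconstruction quality on the test set as the mean cosine similarity
# between each flattened image and its reconstruction. sklearn's cosine_similarity
# builds a full batch_size x batch_size matrix and only the diagonal is used, so
# this is correct but does more work than strictly necessary.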
def calcTestDataSet_Rebuildsimilarity():
    model.eval()
    total_similarity = 0.0
    total_samples = 0
    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)
            outputs = model(images)
            images_flat = images.view(images.size(0), -1).cpu().numpy()
            outputs_flat = outputs.view(outputs.size(0), -1).cpu().numpy()
            batch_similarity = cosine_similarity(images_flat, outputs_flat)
            total_similarity += np.diag(batch_similarity).sum()
            total_samples += images.size(0)
    avg_similarity = total_similarity / total_samples
    similarities.append(avg_similarity)
    print(f"Average Cosine Similarity: {avg_similarity:.4f}")
    return avg_similarity
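# Finds the single test image with the highest per-image MSE and saves a
# side-by-side plot of the original and its reconstruction for the current epoch.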
def findTopSample():
    model.eval()
    max_loss = 0.0
    worst_sample = None
    worst_recon = None
    worst_label = None
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            outputs = model(images)
            # Per-image reconstruction error (mean squared error over C, H, W)
            batch_loss = torch.mean((outputs - images) ** 2, dim=[1, 2, 3])
            batch_max_loss, idx = torch.max(batch_loss, 0)
            if batch_max_loss.item() > max_loss:
                max_loss = batch_max_loss.item()
                worst_sample = images[idx].cpu()
                worst_recon = outputs[idx].cpu()
                worst_label = labels[idx].item()
    if worst_sample is not None:
        # Undo the (0.5, 0.5) normalization so the images display correctly
        worst_sample = (worst_sample * 0.5) + 0.5
        worst_recon = (worst_recon * 0.5) + 0.5
        plt.figure(figsize=(8, 4))
        plt.subplot(1, 2, 1)
        plt.imshow(worst_sample.permute(1, 2, 0).numpy())
        plt.title(f"Original (Label: {worst_label})")
        plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.imshow(worst_recon.permute(1, 2, 0).numpy())
        plt.title("Reconstruction")
        plt.axis('off')
        plt.suptitle(f"Worst Sample (Epoch {len(train_losses)})")
        plt.savefig(f'results/worst_sample_epoch_{len(train_losses)}.png')
        plt.close()
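# One training epoch: minimize the MSE between each reconstruction and its input.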
def train(epoch):
    model.train()
    running_loss = 0.0
    for images, _ in train_loader:
        images = images.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, images)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    epoch_loss = running_loss / len(train_loader)
    train_losses.append(epoch_loss)
    print(f"Epoch {epoch} Training Loss: {epoch_loss:.4f}")
def test(epoch):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)
            outputs = model(images)
            loss = criterion(outputs, images)
            running_loss += loss.item()
    epoch_loss = running_loss / len(test_loader)
    test_losses.append(epoch_loss)
    print(f"Epoch {epoch} Test Loss: {epoch_loss:.4f}")
    return epoch_loss
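# Main loop: train, evaluate, step the LR scheduler on the test loss, then log
# the similarity metric, the worst reconstruction, and the loss/similarity curves
# each epoch.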
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    start_time = time.time()
    train(epoch)
    test_loss = test(epoch)
    scheduler.step(test_loss)
    calcTestDataSet_Rebuildsimilarity()
    findTopSample()
    pltLoss()
    epoch_time = time.time() - start_time
    print(f"Epoch {epoch} completed in {epoch_time:.2f} seconds\n")
torch.save(model.state_dict(), 'autoencoder.pth')
print("Training finished!")