import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs('visualizations', exist_ok=True)
transform = transforms.Compose([
    transforms.ToTensor(),
    # Standard MNIST per-channel mean and std.
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
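# Optional sanity check (a minimal sketch): confirm the batch layout the model
# expects below. Each batch is (images, labels) with images of shape
# [B, 1, 28, 28]. The deterministic test_loader is used so the check does not
# consume RNG state from the shuffled train_loader.
_imgs, _labels = next(iter(test_loader))
assert _imgs.shape[1:] == (1, 28, 28) and _labels.shape == (_imgs.shape[0],)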
class VSSM(nn.Module):
    """Variational state-space model: a VAE-style encoder produces a latent
    state, a learned transition advances that state one step, and separate
    heads reconstruct the image (from the advanced state) and classify the
    digit (from the current state)."""
    def __init__(self, input_size=784, hidden_size=32, state_size=16, output_size=10):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.state_size = state_size
        self.output_size = output_size
        # Shared encoder mapping a flattened image to a hidden representation.
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )
        # Heads producing the mean and log-variance of the approximate posterior.
        self.fc_mu = nn.Linear(hidden_size, state_size)
        self.fc_logvar = nn.Linear(hidden_size, state_size)
        # Transition network advancing the latent state by one step.
        self.transition = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, state_size)
        )
        # Decoder reconstructing the flattened image from a latent state.
        self.decoder = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size)
        )
        # Classification head operating on the current latent state.
        self.classifier = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, output_size)
        )
def encode(self, x):
h = self.encoder(x)
mu = self.fc_mu(h)
logvar = self.fc_logvar(h)
return mu, logvar
    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)
        # and sigma = exp(0.5 * logvar), which keeps sampling differentiable
        # with respect to mu and logvar.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
def decode(self, z):
return self.decoder(z)
def classify(self, z):
return self.classifier(z)
    def forward(self, x):
        batch_size = x.size(0)
        x_flat = x.view(batch_size, -1)
        mu, logvar = self.encode(x_flat)
        z = self.reparameterize(mu, logvar)
        # Reconstruct from the one-step-advanced state; classify from the
        # current state.
        z_next = self.transition(z)
        recon_flat = self.decode(z_next)
        pred = self.classify(z)
        return recon_flat, pred, mu, logvar, z, x_flat
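# Optional smoke test (a sketch, not part of training): push a dummy batch
# through an untrained model to confirm forward()'s tensor contract under the
# default constructor sizes (784-dim input, 10 classes, 16-dim latent). The
# RNG state is saved and restored so this check does not perturb the weight
# initialization of the real model created below.
_rng_state = torch.get_rng_state()
with torch.no_grad():
    _recon, _logits, _mu, _logvar, _z, _flat = VSSM()(torch.zeros(2, 1, 28, 28))
assert _recon.shape == (2, 784) and _logits.shape == (2, 10) and _z.shape == (2, 16)
torch.set_rng_state(_rng_state)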
def vssm_loss(recon_x, x, pred, target, mu, logvar, lambda_kl=0.1, lambda_cls=1.0):
    # Reconstruction: summed squared error over pixels.
    recon_loss = F.mse_loss(recon_x, x.view(x.size(0), -1), reduction='sum')
    # KL divergence between N(mu, sigma^2) and the standard normal prior,
    # in closed form: -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    cls_loss = F.cross_entropy(pred, target, reduction='sum')
    batch_size = x.size(0)
    # Weighted sum of the three terms, normalized to a per-sample loss.
    total_loss = (recon_loss + lambda_kl * kl_loss + lambda_cls * cls_loss) / batch_size
    return total_loss, recon_loss.item() / batch_size, kl_loss.item() / batch_size, cls_loss.item() / batch_size
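# Illustrative sanity check of the loss on hand-built tensors (shapes assumed
# from the model above: 784-dim inputs, 10 classes, 16-dim latent). With a
# perfect reconstruction and mu = logvar = 0, the recon and KL terms vanish,
# leaving only classification: uniform logits give a cross-entropy of
# log(10) ~ 2.3026 per sample.
_x = torch.zeros(4, 784)
_t = torch.tensor([0, 1, 2, 3])
_total, _r, _k, _c = vssm_loss(_x, _x, torch.ones(4, 10), _t,
                               torch.zeros(4, 16), torch.zeros(4, 16))
assert _r == 0.0 and _k == 0.0 and abs(_c - 2.3026) < 1e-3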
def plot_loss(train_losses, test_losses, epochs):
plt.figure(figsize=(10, 5))
plt.plot(range(1, epochs+1), train_losses, 'b-', label='Training Loss')
plt.plot(range(1, epochs+1), test_losses, 'r-', label='Test Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Test Loss')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('loss_curve.png')
plt.close()
def plot_test(model, test_loader, device, epoch):
    model.eval()
    # Track the single most confident test prediction seen this epoch.
    best_sample = None
    best_confidence = -1.0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
recon_flat, pred, mu, logvar, z, x_flat = model(data)
            confidence = F.softmax(pred, dim=1).max(dim=1)[0]
            # Most confident prediction in this batch.
            max_idx = confidence.argmax().item()
            if confidence[max_idx] > best_confidence:
best_confidence = confidence[max_idx].item()
best_sample = {
'input': data[max_idx].cpu(),
'recon': recon_flat[max_idx].cpu().view(1, 28, 28),
'target': target[max_idx].cpu().item(),
'pred': pred[max_idx].argmax().cpu().item(),
'confidence': best_confidence,
'mu': mu[max_idx].cpu().numpy(),
'logvar': logvar[max_idx].cpu().numpy(),
'z': z[max_idx].cpu().numpy(),
'pred_dist': F.softmax(pred[max_idx], dim=0).cpu().numpy()
}
            # Drop per-batch references promptly; the CUDA cache is cleared
            # once per epoch in the main loop rather than after every batch.
            del data, target, recon_flat, pred, mu, logvar, z, x_flat, confidence, max_idx
if best_sample is not None:
plt.figure(figsize=(12, 8))
plt.subplot(2, 3, 1)
plt.title(f'Input Image (True: {best_sample["target"]})')
plt.imshow(best_sample['input'].squeeze().numpy(), cmap='gray')
plt.axis('off')
plt.subplot(2, 3, 2)
        plt.title('Reconstructed Image')
plt.imshow(best_sample['recon'].squeeze().numpy(), cmap='gray')
plt.axis('off')
plt.subplot(2, 3, 3)
plt.title('Latent Mean (μ)')
plt.bar(range(len(best_sample['mu'])), best_sample['mu'])
plt.xlabel('Dimension')
plt.ylabel('Value')
plt.subplot(2, 3, 4)
plt.title('Latent Log Variance (log σ²)')
plt.bar(range(len(best_sample['logvar'])), best_sample['logvar'])
plt.xlabel('Dimension')
plt.ylabel('Value')
plt.subplot(2, 3, 5)
plt.title('Sampled Latent Variable (z)')
plt.bar(range(len(best_sample['z'])), best_sample['z'])
plt.xlabel('Dimension')
plt.ylabel('Value')
plt.subplot(2, 3, 6)
plt.title(f'Prediction Distribution (Pred: {best_sample["pred"]}, Conf: {best_sample["confidence"]:.4f})')
plt.bar(range(10), best_sample['pred_dist'])
plt.xticks(range(10))
plt.xlabel('Class')
plt.ylabel('Probability')
plt.tight_layout()
plt.savefig(f'visualizations/epoch_{epoch}_best_sample.png')
plt.close()
model = VSSM().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Halve the learning rate after 3 epochs without test-loss improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)
def train(model, train_loader, optimizer, epoch, device):
model.train()
train_loss = 0
train_recon_loss = 0
train_kl_loss = 0
train_cls_loss = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
recon, pred, mu, logvar, z, x_flat = model(data)
loss, recon_loss, kl_loss, cls_loss = vssm_loss(recon, data, pred, target, mu, logvar)
loss.backward()
optimizer.step()
train_loss += loss.item()
train_recon_loss += recon_loss
train_kl_loss += kl_loss
train_cls_loss += cls_loss
if batch_idx % 100 == 0:
print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
avg_loss = train_loss / len(train_loader)
avg_recon_loss = train_recon_loss / len(train_loader)
avg_kl_loss = train_kl_loss / len(train_loader)
avg_cls_loss = train_cls_loss / len(train_loader)
print(f'Epoch: {epoch} Average training loss: {avg_loss:.4f} '
f'(Recon: {avg_recon_loss:.4f}, KL: {avg_kl_loss:.4f}, Cls: {avg_cls_loss:.4f})')
return avg_loss
def test(model, test_loader, device):
model.eval()
test_loss = 0
test_recon_loss = 0
test_kl_loss = 0
test_cls_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
recon, pred, mu, logvar, z, x_flat = model(data)
loss, recon_loss, kl_loss, cls_loss = vssm_loss(recon, data, pred, target, mu, logvar)
test_loss += loss.item()
test_recon_loss += recon_loss
test_kl_loss += kl_loss
test_cls_loss += cls_loss
pred_class = pred.argmax(dim=1, keepdim=True)
correct += pred_class.eq(target.view_as(pred_class)).sum().item()
avg_loss = test_loss / len(test_loader)
avg_recon_loss = test_recon_loss / len(test_loader)
avg_kl_loss = test_kl_loss / len(test_loader)
avg_cls_loss = test_cls_loss / len(test_loader)
accuracy = 100. * correct / len(test_loader.dataset)
print(f'Average test loss: {avg_loss:.4f} '
f'(Recon: {avg_recon_loss:.4f}, KL: {avg_kl_loss:.4f}, Cls: {avg_cls_loss:.4f})')
print(f'Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')
return avg_loss, accuracy
epochs = 10
train_losses = []
test_losses = []
best_accuracy = 0.0
for epoch in range(1, epochs + 1):
print(f'\nEpoch {epoch}/{epochs}')
train_loss = train(model, train_loader, optimizer, epoch, device)
train_losses.append(train_loss)
test_loss, accuracy = test(model, test_loader, device)
test_losses.append(test_loss)
    plot_test(model, test_loader, device, epoch)
scheduler.step(test_loss)
if accuracy > best_accuracy:
best_accuracy = accuracy
torch.save(model.state_dict(), 'best_model.pth')
print(f'Best model saved with accuracy: {accuracy:.2f}%')
    plot_loss(train_losses, test_losses, epoch)
torch.cuda.empty_cache()
print(f'\nTraining completed. Best accuracy: {best_accuracy:.2f}%')
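# Reload the best checkpoint for inference (a minimal sketch; best_model.pth
# is written above whenever test accuracy improves, so it exists after a
# completed run). map_location keeps the load portable across CPU and GPU.
best_model = VSSM().to(device)
best_model.load_state_dict(torch.load('best_model.pth', map_location=device))
best_model.eval()
test(best_model, test_loader, device)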