- Early stopping strategy

Early stopping watches the loss on held-out data during training and halts once it stops improving for a set number of epochs (the patience), which saves compute and limits overfitting. The script below trains a small MLP with this strategy, checkpoints the best-performing weights, and then resumes from that checkpoint for a few extra epochs.
```python
import torch
import torch.nn as nn
import torch.optim as optim
import time
import matplotlib.pyplot as plt
from tqdm import tqdm

# X_train, y_train, X_test, y_test, and device are assumed to come from the
# earlier data-preparation steps (tensors already moved to `device`).

# Define the MLP model
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(X_train.shape[1], 10)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 2)  # Binary classification: two output logits

    def forward(self, x):
        # Raw logits are returned; CrossEntropyLoss applies log-softmax internally
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# Instantiate the model
model = MLP().to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
```
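Before committing to a 20,000-epoch run, it can be worth a one-line sanity check that the model accepts inputs of the expected width; a quick sketch with a dummy batch (the batch size 4 is arbitrary):

```python
# Sanity check: forward a dummy batch and confirm the logits shape
dummy = torch.randn(4, X_train.shape[1], device=device)
print(model(dummy).shape)  # expected: torch.Size([4, 2])
```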
```python
# Training settings
num_epochs = 20000
early_stop_patience = 50  # Epochs to wait for improvement
best_loss = float('inf')
patience_counter = 0
best_epoch = 0
stop_epoch = num_epochs   # Actual epoch at which training stopped (updated on early stop)
early_stopped = False

# Track losses
train_losses = []
test_losses = []
epochs = []

# Start training
start_time = time.time()
with tqdm(total=num_epochs, desc="Training Progress", unit="epoch") as pbar:
    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        outputs = model(X_train)
        train_loss = criterion(outputs, y_train)
        train_loss.backward()
        optimizer.step()

        # Evaluate on the test set
        model.eval()
        with torch.no_grad():
            outputs_test = model(X_test)
            test_loss = criterion(outputs_test, y_test)

        # Record losses every 200 epochs
        if (epoch + 1) % 200 == 0:
            train_losses.append(train_loss.item())
            test_losses.append(test_loss.item())
            epochs.append(epoch + 1)

        # Early stopping check
        if test_loss.item() < best_loss:  # Current test loss beats the best so far
            best_loss = test_loss.item()  # Update best loss
            best_epoch = epoch + 1        # Update best epoch
            patience_counter = 0          # Reset counter
            # Save the best model
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            patience_counter += 1
            if patience_counter >= early_stop_patience:
                print(f"Early stopping triggered! No improvement for {early_stop_patience} epochs.")
                print(f"Best test loss was at epoch {best_epoch} with a loss of {best_loss:.4f}")
                stop_epoch = epoch + 1
                early_stopped = True
                break  # Stop the training loop

        # Update the progress bar
        pbar.set_postfix({'Train Loss': f'{train_loss.item():.4f}', 'Test Loss': f'{test_loss.item():.4f}'})
        # Advance the progress bar every 1000 epochs
        if (epoch + 1) % 1000 == 0:
            pbar.update(1000)

    # Ensure the progress bar reaches 100%
    if pbar.n < num_epochs:
        pbar.update(num_epochs - pbar.n)

time_all = time.time() - start_time  # Calculate total training time
print(f'Training time: {time_all:.2f} seconds')
```
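The patience bookkeeping above can also be factored into a small reusable helper, so the training loop stays readable when the pattern is reused. A minimal sketch (the class name and interface are mine, not from the original post):

```python
class EarlyStopping:
    """Minimal early-stopping helper: signal a stop when the monitored
    loss has not improved for `patience` consecutive checks."""

    def __init__(self, patience=50, save_path='best_model.pth'):
        self.patience = patience
        self.save_path = save_path
        self.best_loss = float('inf')
        self.counter = 0
        self.best_epoch = 0

    def step(self, loss, model, epoch):
        """Record one check; returns True when training should stop."""
        if loss < self.best_loss:
            self.best_loss = loss
            self.best_epoch = epoch
            self.counter = 0
            torch.save(model.state_dict(), self.save_path)  # Checkpoint the best weights
        else:
            self.counter += 1
        return self.counter >= self.patience
```

With this helper created once before the loop as `stopper = EarlyStopping(patience=early_stop_patience)`, the whole early-stopping branch inside the loop reduces to `if stopper.step(test_loss.item(), model, epoch + 1): break`.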
```python
# If early stopping occurred, load the best model
if early_stopped:
    print(f"Loading best model from epoch {best_epoch} for final evaluation...")
    model.load_state_dict(torch.load('best_model.pth'))

    # Continue training for 50 more epochs after loading the best model
    num_extra_epochs = 50
    for epoch in range(num_extra_epochs):
        model.train()
        optimizer.zero_grad()
        outputs = model(X_train)
        train_loss = criterion(outputs, y_train)
        train_loss.backward()
        optimizer.step()

        # Evaluate on the test set
        model.eval()
        with torch.no_grad():
            outputs_test = model(X_test)
            test_loss = criterion(outputs_test, y_test)

        # Number the extra epochs from where training actually stopped,
        # not from num_epochs, so the plot's x-axis has no artificial gap
        train_losses.append(train_loss.item())
        test_losses.append(test_loss.item())
        epochs.append(stop_epoch + epoch + 1)

        # Print progress for the extra epochs
        print(f"Epoch {stop_epoch + epoch + 1}: Train Loss = {train_loss.item():.4f}, Test Loss = {test_loss.item():.4f}")
```
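One caveat when resuming: `model.state_dict()` saves only the weights. Plain SGD as configured here carries no internal state, but optimizers such as Adam or SGD with momentum do, and that state is lost if only the model is checkpointed. A sketch of the fuller checkpoint pattern (the file name `best_checkpoint.pth` is arbitrary):

```python
# Sketch: checkpoint both model and optimizer state, useful once you
# switch to an optimizer with internal state (e.g. Adam or momentum SGD)
checkpoint = {
    'epoch': best_epoch,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'best_loss': best_loss,
}
torch.save(checkpoint, 'best_checkpoint.pth')

# ...later, to resume:
ckpt = torch.load('best_checkpoint.pth')
model.load_state_dict(ckpt['model_state'])
optimizer.load_state_dict(ckpt['optimizer_state'])
```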
```python
# Plot the loss curves
plt.figure(figsize=(10, 6))
plt.plot(epochs, train_losses, label='Train Loss')
plt.plot(epochs, test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Test Loss over Epochs')
plt.legend()
plt.grid(True)
plt.show()
```
```python
# Evaluate final accuracy on the test set
model.eval()
with torch.no_grad():
    outputs = model(X_test)
    _, predicted = torch.max(outputs, 1)
    correct = (predicted == y_test).sum().item()
    accuracy = correct / y_test.size(0)
    print(f'Test Accuracy: {accuracy * 100:.2f}%')
```
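A final caveat: the loop above early-stops on the same test set it later reports accuracy for, so the stopping point is tuned to the data that produces the final score. Common practice is to carve a validation split out of the training data, monitor that for early stopping, and touch the test set only once at the end. A minimal sketch, assuming scikit-learn is available (it is not imported in the original script):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical re-split: hold out 20% of the training rows as a
# validation set; early-stop on val loss, report accuracy on test only
idx_tr, idx_val = train_test_split(np.arange(len(X_train)), test_size=0.2, random_state=42)
idx_tr = torch.as_tensor(idx_tr, device=X_train.device)
idx_val = torch.as_tensor(idx_val, device=X_train.device)
X_tr, y_tr = X_train[idx_tr], y_train[idx_tr]
X_val, y_val = X_train[idx_val], y_train[idx_val]
```

The training loop would then compute its monitored loss on `X_val`/`y_val` for the patience check, with `X_test`/`y_test` reserved for the final accuracy report.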