import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import os
import torchvision
# Configure Matplotlib so Chinese text in figures renders correctly.
# "SimHei" ships with Windows only; elsewhere Matplotlib cannot find it, logs a
# "findfont" warning and draws Chinese glyphs as empty boxes. Matplotlib uses the
# first family in this list that is installed, so list common CJK fonts from
# Windows, macOS and Linux, ending with DejaVu Sans (bundled with Matplotlib,
# Latin-only) as a last resort.
plt.rcParams["font.sans-serif"] = [
    "SimHei",
    "Microsoft YaHei",
    "Heiti TC",
    "Arial Unicode MS",
    "Noto Sans CJK SC",
    "WenQuanYi Micro Hei",
    "DejaVu Sans",
]
# Render minus signs with an ASCII hyphen: many CJK fonts lack the Unicode
# minus glyph (U+2212), which would otherwise show up as a box.
plt.rcParams["axes.unicode_minus"] = False
# Pick the compute device: the default CUDA GPU when one is usable,
# otherwise fall back to the CPU. Report the choice on stdout.
if torch.cuda.is_available():
    computation_device = torch.device("cuda")
else:
    computation_device = torch.device("cpu")
print(f"当前运算设备: {computation_device}")
# Per-channel normalization statistics shared by the training and test
# pipelines. Defining them once guarantees both splits are normalized
# identically (previously the tuples were duplicated and had to be kept in sync).
# NOTE(review): these std values are the ones widely copied from public CIFAR-10
# training code; the per-channel pixel std of the training set is usually quoted
# as roughly (0.2470, 0.2435, 0.2616). Kept unchanged so results stay comparable
# with earlier runs — confirm before changing.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random augmentations on the PIL image, then conversion to
# a tensor and normalization.
training_transforms = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # pad each border by 4 px, crop back to 32x32
    transforms.RandomHorizontalFlip(),     # mirror with the default probability 0.5
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),         # rotate by a random angle in [-15, 15] degrees
    transforms.ToTensor(),                 # PIL image -> float tensor scaled to [0, 1]
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Test pipeline: no augmentation, only tensor conversion and the same
# normalization as training.
testing_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])
# Directory holding the CIFAR-10 files (shared by both splits).
DATA_ROOT = './dataset'
# Mini-batch size shared by the training and test loaders.
BATCH_SIZE = 64
# Page-locked (pinned) host memory makes host-to-GPU copies cheaper; it is only
# useful when the selected device is a CUDA GPU.
PIN_MEMORY = computation_device.type == "cuda"

# Training split; download=True fetches the CIFAR-10 archive into DATA_ROOT when
# it is not already present. That archive also contains the test split, so the
# test set below does not need its own download flag.
train_set = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=training_transforms)
# Test split, using the non-augmenting transform pipeline.
test_set = datasets.CIFAR10(root=DATA_ROOT, train=False, transform=testing_transforms)

# Reshuffle training samples every epoch; keep test order fixed so evaluation
# is reproducible. num_workers is intentionally left at its default (0): this
# script has no `if __name__ == "__main__":` guard, which worker processes
# started via spawn (e.g. on Windows) would require.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, pin_memory=PIN_MEMORY)
# @浙大疏锦行
# (Signature kept as a comment: as a bare line Python parsed it as a decorator,
# which is a SyntaxError at end of file and a NameError before any definition.)