day53

发布于:2025-07-07 ⋅ 阅读:(10) ⋅ 点赞:(0)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.datasets import load_iris
import warnings

# 忽略不必要的警告信息
warnings.filterwarnings("ignore")

# --------------------------
# 1. 配置训练参数与设备
# --------------------------

# 潜在空间维度(生成器的输入维度)
latent_dim = 10  
# 训练总轮数(GAN通常需要较多迭代才能收敛)
train_epochs = 10000  
# 批次大小(根据数据集规模调整)
batch_size = 32  
# 学习率(控制参数更新幅度)
learning_rate = 0.0002  
# Adam优化器的动量参数(影响收敛稳定性)
beta1 = 0.5  

# 自动选择运算设备(优先GPU,没有则用CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"当前使用设备: {device}")

# --------------------------
# 2. 数据加载与预处理
# --------------------------

# 加载鸢尾花数据集
iris_dataset = load_iris()
# 提取特征数据和标签
features = iris_dataset.data
labels = iris_dataset.target

# 只选取Setosa类别(标签为0)的数据进行训练
setosa_features = features[labels == 0]

# 将数据缩放到[-1, 1]区间(配合生成器的Tanh输出激活)
scaler = MinMaxScaler(feature_range=(-1, 1))
scaled_features = scaler.fit_transform(setosa_features)

# 转换为PyTorch张量并创建数据加载器
# 注意:必须转为float类型才能与模型参数兼容
data_tensor = torch.from_numpy(scaled_features).float()
dataset = TensorDataset(data_tensor)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 打印数据基本信息
print(f"训练样本数量: {len(scaled_features)}")
print(f"特征维度: {scaled_features.shape[1]}")  # 鸢尾花数据集固定为4维特征

# --------------------------
# 3. 定义生成器和判别器
# --------------------------

class Generator(nn.Module):
    """生成器:将随机噪声转换为模拟的鸢尾花特征数据"""
    def __init__(self):
        super(Generator, self).__init__()
        # 简单的全连接网络结构
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 16),  # 从潜在空间映射到16维
            nn.ReLU(),  # 激活函数增加非线性
            nn.Linear(16, 32),  # 进一步映射到32维
            nn.ReLU(),
            nn.Linear(32, 4),  # 输出4维特征(与真实数据一致)
            nn.Tanh()  # 确保输出在[-1, 1]范围内
        )
    
    def forward(self, x):
        # 前向传播:输入噪声,输出生成的数据
        return self.net(x)

class Discriminator(nn.Module):
    """判别器:区分输入数据是真实样本还是生成器伪造的"""
    def __init__(self):
        super(Discriminator, self).__init__()
        # 简单的全连接网络结构
        self.net = nn.Sequential(
            nn.Linear(4, 32),  # 输入4维特征
            nn.LeakyReLU(0.2),  # LeakyReLU避免梯度消失问题
            nn.Linear(32, 16),  # 压缩到16维
            nn.LeakyReLU(0.2),
            nn.Linear(16, 1),  # 输出单个概率值
            nn.Sigmoid()  # 将输出压缩到[0,1](表示真实数据的概率)
        )
    
    def forward(self, x):
        # 前向传播:输入数据,输出判断概率
        return self.net(x)

# 初始化模型并移动到运算设备
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# 打印模型结构
print("\n生成器结构:")
print(generator)
print("\n判别器结构:")
print(discriminator)

# --------------------------
# 4. 配置训练组件
# --------------------------

# 定义损失函数(二元交叉熵,适合二分类问题)
criterion = nn.BCELoss()

# 定义优化器(分别优化生成器和判别器)
gen_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
dis_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))

# --------------------------
# 5. 开始训练
# --------------------------

print("\n--- 训练开始 ---")
for epoch in range(train_epochs):
    # 遍历数据加载器中的每一批次
    for batch_idx, (real_data,) in enumerate(data_loader):
        # 将真实数据移动到运算设备
        real_data = real_data.to(device)
        current_batch_size = real_data.size(0)  # 获取当前批次的实际样本数(最后一批可能不满)
        
        # 创建标签:真实数据标为1,生成数据标为0
        real_labels = torch.ones(current_batch_size, 1).to(device)
        fake_labels = torch.zeros(current_batch_size, 1).to(device)
        
        # --------------------
        # 训练判别器
        # --------------------
        dis_optimizer.zero_grad()  # 清空判别器的梯度缓存
        
        # 1. 用真实数据训练
        real_output = discriminator(real_data)
        # 计算真实数据的损失(希望判别器能认出真实数据)
        loss_real = criterion(real_output, real_labels)
        
        # 2. 用生成的数据训练
        # 生成随机噪声(作为生成器的输入)
        noise = torch.randn(current_batch_size, latent_dim).to(device)
        # 生成假数据,并阻断梯度流向生成器(避免影响生成器参数)
        fake_data = generator(noise).detach()
        fake_output = discriminator(fake_data)
        # 计算假数据的损失(希望判别器能认出假数据)
        loss_fake = criterion(fake_output, fake_labels)
        
        # 总损失反向传播并更新判别器参数
        dis_loss = loss_real + loss_fake
        dis_loss.backward()
        dis_optimizer.step()
        
        # --------------------
        # 训练生成器
        # --------------------
        gen_optimizer.zero_grad()  # 清空生成器的梯度缓存
        
        # 重新生成假数据(这次需要计算生成器的梯度)
        noise = torch.randn(current_batch_size, latent_dim).to(device)
        fake_data = generator(noise)
        fake_output = discriminator(fake_data)
        
        # 生成器的损失:希望判别器把假数据当成真的(所以标签用real_labels)
        gen_loss = criterion(fake_output, real_labels)
        gen_loss.backward()
        gen_optimizer.step()
    
    # 每1000轮打印一次训练状态
    if (epoch + 1) % 1000 == 0:
        print(
            f"轮次 [{epoch+1}/{train_epochs}], "
            f"判别器损失: {dis_loss.item():.4f}, "
            f"生成器损失: {gen_loss.item():.4f}"
        )

print("\n--- 训练完成 ---")


网站公告

今日签到

点亮在社区的每一天
去签到