多模型协同预测在风机故障预测的应用(demo)

发布于:2025-05-10 ⋅ 阅读:(9) ⋅ 点赞:(0)

  1. 数据加载和预处理的真实性

    • 下面的代码中,DummyDataset 和数据加载部分仍然是高度简化和占位的。为了让这个训练循环真正有效,您必须用您自己的数据加载逻辑替换它。
    • 这意味着您需要创建一个 torch.utils.data.Dataset 的子类,它能够正确地从您的数据源(例如CSV文件、数据库、文件夹中的原始信号文件)加载每个样本的多种传感器数据。
    • __getitem__ 方法中,您需要调用 DataProcessor 的相应 process_... 方法来提取特征,然后进行归一化(如果需要模型直接处理归一化后的特征,而不是在predict中才做),并将所有数据转换成模型期望的张量格式和形状。
    • fit_scalers 的调用时机:DataProcessor 中的 fit_scalers 方法必须在创建 DataLoader 并开始训练之前,使用整个训练集提取出的特征进行调用。这一步至关重要。
  2. 特征准备的复杂性

    • 在批处理训练中,为每个模型(CNN, LSTM, Electrical)准备输入可能很复杂,特别是 LSTM 需要特征序列。您可能需要在 Dataset 的 __getitem__ 中就准备好这些序列,或者设计一个高效的批处理函数。
    • 规则引擎在验证批次中的使用也需要仔细考虑,因为它通常处理单个样本的特征。
  3. 模型和融合权重的调优

    • 此代码提供了一个结构。要获得高准确率,您仍然需要对各个模型的超参数、model_weights(融合权重)进行仔细的实验和调优。
  4. 计算资源

    • 训练深度学习模型(尤其是多个模型)可能需要大量的计算资源(GPU)和时间。
import numpy as np
import torch
import torch.nn as nn
from scipy import signal as sig
from scipy.stats import kurtosis, skew
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import pickle
from typing import Dict, List, Tuple, Optional, Any
import os # 用于创建目录

# --- Configuration constants (adjust as needed) ---
VIBRATION_RAW_SEQ_LEN = 1024
VIBRATION_SAMPLING_RATE = 10000
PRESSURE_SAMPLING_RATE = 100
LSTM_FEATURE_SIZE = 10  # temperature (5) + pressure (4) + one current feature (1)
LSTM_SEQ_LEN = 5  # example sequence length for the LSTM input features
ELECTRICAL_FEATURE_SIZE = 2
NUM_FAULT_CLASSES = 10

# Locations where trained models and fitted scalers are stored.
MODEL_SAVE_DIR = "saved_models"
SCALER_SAVE_PATH = os.path.join(MODEL_SAVE_DIR, "fitted_scalers.pkl")
BEST_CNN_PATH = os.path.join(MODEL_SAVE_DIR, "best_cnn.pth")
BEST_LSTM_PATH = os.path.join(MODEL_SAVE_DIR, "best_lstm.pth")
BEST_ELECTRICAL_PATH = os.path.join(MODEL_SAVE_DIR, "best_electrical.pth")

# exist_ok avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

class DataProcessor:
    """Extract fixed-length feature vectors from sensor readings and scale them.

    Every ``process_*`` method turns one sensor channel into a small NumPy
    vector. ``fit_scalers`` / ``normalize_features`` standardise those vectors
    with one ``StandardScaler`` per feature family, and ``save_scalers`` /
    ``load_scalers`` persist the scalers with pickle.
    """

    def __init__(self):
        # One scaler per feature family; the keys are the ``feature_type``
        # names accepted by fit_scalers() and normalize_features().
        self.scalers = self._default_scalers()
        # feature_type -> True once the matching scaler holds fitted statistics.
        self.fitted_scalers = {}
        # Class index -> fault name (the names are runtime data, not comments).
        self.fault_types = {
            0: "正常", 1: "叶轮不平衡", 2: "轴承失效", 3: "动叶卡涩",
            4: "喘振", 5: "积灰堵塞", 6: "电机绕组故障", 7: "传感器失效",
            8: "基础松动", 9: "密封失效"
        }
        assert len(self.fault_types) == NUM_FAULT_CLASSES, "NUM_FAULT_CLASSES 和 fault_types 长度不匹配"

    @staticmethod
    def _default_scalers() -> Dict[str, StandardScaler]:
        """Return a fresh, unfitted scaler for every known feature family."""
        return {
            'vibration_features': StandardScaler(),
            'temperature': StandardScaler(),
            'pressure': StandardScaler(),
            'blade_angle': StandardScaler(),
            'oil_particles': StandardScaler(),
            'current': StandardScaler()
        }

    @staticmethod
    def _is_fitted(scaler: Any) -> bool:
        """Report whether ``scaler`` carries fitted state.

        Uses the scikit-learn convention that attributes learned by ``fit``
        end in a single trailing underscore.
        """
        state = getattr(scaler, "__dict__", {})
        return any(name.endswith("_") and not name.startswith("__") for name in state)

    def fit_scalers(self, training_features_dict: Dict[str, List[np.ndarray]]):
        """Fit each known scaler on the feature vectors collected for it.

        Args:
            training_features_dict: feature_type -> list of per-sample feature
                vectors. Unknown feature types and empty lists are reported
                and skipped; a failing fit is reported and leaves that scaler
                marked as unfitted.
        """
        print("正在拟合缩放器...")
        for feature_type, features_list in training_features_dict.items():
            if feature_type not in self.scalers:
                print(f"未为特征类型定义缩放器: {feature_type}")
                continue
            # Report missing data instead of skipping the scaler silently.
            if features_list is None or len(features_list) == 0:
                print(f"警告:未提供用于拟合 {feature_type} 缩放器的数据。")
                continue
            stacked = np.array(features_list)
            if stacked.ndim == 1:
                # A list of scalars is one feature observed over many samples.
                stacked = stacked.reshape(-1, 1)
            try:
                self.scalers[feature_type].fit(stacked)
                self.fitted_scalers[feature_type] = True
                print(f"{feature_type} 的缩放器已拟合,数据形状: {stacked.shape}")
            except Exception as e:
                print(f"拟合 {feature_type} 缩放器时出错 (形状 {stacked.shape}): {e}")

    def save_scalers(self, path=SCALER_SAVE_PATH):
        """Pickle the whole scaler dictionary (fitted or not) to ``path``."""
        with open(path, 'wb') as f:
            pickle.dump(self.scalers, f)
        print(f"已拟合的缩放器保存至 {path}")

    def load_scalers(self, path=SCALER_SAVE_PATH):
        """Load the scaler dictionary from ``path``.

        A missing file leaves the current scalers untouched. Any other load
        problem resets the scalers to fresh, unfitted ones. After a successful
        load only scalers that actually hold fitted state are flagged as
        fitted, so an unfitted scaler in the file cannot reach ``transform``.
        """
        try:
            with open(path, 'rb') as f:
                # NOTE: unpickling runs arbitrary code; only load trusted files.
                loaded = pickle.load(f)
            if not isinstance(loaded, dict):
                raise TypeError(f"expected a dict of scalers, got {type(loaded).__name__}")
        except FileNotFoundError:
            print(f"缩放器文件 {path} 未找到。如果不是在训练阶段,则初始化新的缩放器。")
            return
        except Exception as e:
            print(f"加载缩放器错误: {e}。如果不是在训练阶段,则初始化新的缩放器。")
            self.scalers = self._default_scalers()
            self.fitted_scalers = {}
            return
        self.scalers = loaded
        self.fitted_scalers = {
            key: True for key, scaler in loaded.items() if self._is_fitted(scaler)
        }
        print(f"缩放器从 {path} 加载成功")

    def process_vibration(self, vibration_signal: np.ndarray, sampling_rate: int = VIBRATION_SAMPLING_RATE) -> np.ndarray:
        """Return 8 vibration features from a 1-D signal.

        The vector is ``[rms, peak, kurtosis, skewness, psd@50Hz, psd@100Hz,
        psd@150Hz, share of PSD energy in 5-50 Hz]``. Signals shorter than two
        samples yield zeros.

        Raises:
            ValueError: if ``vibration_signal`` is not a 1-D NumPy array.
        """
        if not isinstance(vibration_signal, np.ndarray) or vibration_signal.ndim != 1:
            raise ValueError("振动信号必须是一维numpy数组。")
        if len(vibration_signal) < 2:
            return np.zeros(8)

        rms = np.sqrt(np.mean(vibration_signal**2))
        peak = np.max(np.abs(vibration_signal))
        kurtosis_val = kurtosis(vibration_signal)
        skewness_val = skew(vibration_signal)

        # The signal has at least two samples here, so nperseg is >= 2.
        segment_len = min(len(vibration_signal), 1024)
        freqs, psd = sig.welch(vibration_signal, sampling_rate, nperseg=segment_len)

        def psd_at(freq_target):
            # Use the nearest bin only when it lies within one bin width of
            # the target; otherwise the target is outside the spectrum.
            distance = np.abs(freqs - freq_target)
            nearest = np.argmin(distance)
            bin_width = (freqs[1] - freqs[0]) if len(freqs) > 1 else sampling_rate / 2
            return psd[nearest] if distance[nearest] < bin_width else 0.0

        freq_1x = psd_at(50)
        freq_2x = psd_at(100)
        freq_3x = psd_at(150)

        total_energy = np.sum(psd)
        if total_energy > 0:
            low_freq_energy = np.sum(psd[(freqs >= 5) & (freqs <= 50)]) / total_energy
        else:
            low_freq_energy = 0.0
        return np.array([rms, peak, kurtosis_val, skewness_val, freq_1x, freq_2x, freq_3x, low_freq_energy])

    def process_temperature(self, temp_series: np.ndarray) -> np.ndarray:
        """Return ``[last, mean, max, mean first difference, second sample]``.

        Anything that is not a non-empty 1-D array yields five zeros.
        """
        if not isinstance(temp_series, np.ndarray) or temp_series.ndim != 1 or len(temp_series) == 0:
            return np.zeros(5)
        has_history = len(temp_series) > 1
        current_temp = temp_series[-1]
        avg_temp = np.mean(temp_series)
        max_temp = np.max(temp_series)
        temp_rate = np.diff(temp_series).mean() if has_history else 0.0
        # NOTE(review): the "stator" value is simply the second sample of the
        # series (the first one when only one exists) - confirm this layout.
        stator_temp = temp_series[1] if has_history else temp_series[0]
        return np.array([current_temp, avg_temp, max_temp, temp_rate, stator_temp])

    def process_pressure(self, pressure_series: np.ndarray, fs: int = PRESSURE_SAMPLING_RATE) -> np.ndarray:
        """Return ``[mean, std, largest step, share of PSD energy in 0.5-2 Hz]``.

        Anything that is not a non-empty 1-D array yields four zeros.
        """
        if not isinstance(pressure_series, np.ndarray) or pressure_series.ndim != 1 or len(pressure_series) == 0:
            return np.zeros(4)
        mean_pressure = np.mean(pressure_series)
        std_pressure = np.std(pressure_series)
        if len(pressure_series) > 1:
            max_fluctuation = np.max(np.abs(np.diff(pressure_series)))
        else:
            max_fluctuation = 0.0
        low_freq_energy = self._calculate_low_freq_energy(pressure_series, fs, 0.5, 2.0)
        return np.array([mean_pressure, std_pressure, max_fluctuation, low_freq_energy])

    def process_blade_angle(self, angle_series: np.ndarray, target_angle: float) -> np.ndarray:
        """Return ``[last angle, target, |last - target|, mean |step|, stuck share]``.

        The stuck share is the number of steps smaller than 0.5 divided by the
        number of samples. Anything that is not a non-empty 1-D array yields
        five zeros.
        """
        if not isinstance(angle_series, np.ndarray) or angle_series.ndim != 1 or len(angle_series) == 0:
            return np.zeros(5)
        current_angle = angle_series[-1]
        angle_deviation = np.abs(current_angle - target_angle)
        if len(angle_series) > 1:
            step_sizes = np.abs(np.diff(angle_series))
            angle_rate = step_sizes.mean()
            stuck_points = np.sum(step_sizes < 0.5) / len(angle_series)
        else:
            angle_rate = 0.0
            stuck_points = 0.0
        return np.array([current_angle, target_angle, angle_deviation, angle_rate, stuck_points])

    def process_oil_analysis(self, particles: float, viscosity: float) -> np.ndarray:
        """Pack the two oil readings into ``[particles, viscosity]``."""
        return np.array([particles, viscosity])

    def process_current(self, current_series: np.ndarray) -> np.ndarray:
        """Return ``[std / mean, max |x - mean| / mean]`` of the current readings.

        Non-array or empty input, or readings whose mean is zero, yield two
        zeros. Input that is not 1-D (including a 0-d array) only gets the
        first ratio; the second feature is reported as 0.
        """
        # ``size`` also handles 0-d arrays, on which ``len()`` raises TypeError.
        if not isinstance(current_series, np.ndarray) or current_series.size == 0:
            return np.zeros(2)
        mean_curr = np.mean(current_series)
        if mean_curr == 0:
            return np.array([0.0, 0.0])
        harmonic_ratio = np.std(current_series) / mean_curr
        if current_series.ndim == 1:
            unbalance_metric = np.max(np.abs(current_series - mean_curr)) / mean_curr
        else:
            unbalance_metric = 0
        return np.array([harmonic_ratio, unbalance_metric])

    def _calculate_low_freq_energy(self, input_signal: np.ndarray, fs: float, f_low: float, f_high: float) -> float:
        """Return the share of Welch PSD energy inside ``[f_low, f_high]``.

        Returns 0.0 when the input is not an array, has fewer than two
        samples, has zero total energy, or has no bin inside the band.
        """
        if not isinstance(input_signal, np.ndarray) or len(input_signal) == 0:
            return 0.0
        segment_len = min(len(input_signal), 128)
        if segment_len < 2:
            return 0.0
        freqs, psd = sig.welch(input_signal, fs=fs, nperseg=segment_len)
        total_energy = np.sum(psd)
        if psd.size == 0 or total_energy == 0:
            return 0.0
        band = psd[(freqs >= f_low) & (freqs <= f_high)]
        if band.size == 0:
            return 0.0
        return np.sum(band) / total_energy

    def normalize_features(self, features: np.ndarray, feature_type: str) -> np.ndarray:
        """Standardise one feature vector with the scaler of ``feature_type``.

        Unknown feature types are returned unchanged. An unfitted scaler also
        returns the input unchanged, unless the module-level flag
        ``predicting_now`` is truthy, in which case it is an error.

        Raises:
            RuntimeError: scaler not fitted while ``predicting_now`` is set.
            ValueError: ``features`` is neither 1-D nor of shape ``(1, n)``.
        """
        if feature_type not in self.scalers:
            return features
        if not self.fitted_scalers.get(feature_type):
            if globals().get("predicting_now"):
                raise RuntimeError(f"{feature_type} 的缩放器必须在预测前拟合。")
            return features
        if features.ndim == 1:
            single_row = features.reshape(1, -1)
        elif features.ndim == 2 and features.shape[0] == 1:
            single_row = features
        else:
            raise ValueError(f"{feature_type} 的特征形状对于缩放不符合预期: {features.shape}")
        return self.scalers[feature_type].transform(single_row).flatten()

class CNNModel(nn.Module):
    """1-D convolutional classifier built from three conv/ReLU/max-pool stages.

    The width of the first fully connected layer is measured at construction
    time by pushing one random tensor of shape
    ``(1, input_channels, example_input_len)`` through the conv stack.
    """

    def __init__(self, input_channels: int = 1, num_classes: int = NUM_FAULT_CLASSES, example_input_len: int = VIBRATION_RAW_SEQ_LEN):
        super().__init__()
        # (in_channels, out_channels, kernel_size); padding of k // 2 gives the
        # same 3 / 2 / 1 padding values as spelling each layer out.
        stage_specs = ((input_channels, 32, 7), (32, 64, 5), (64, 128, 3))
        stages = []
        for in_ch, out_ch, kernel in stage_specs:
            stages.append(nn.Conv1d(in_ch, out_ch, kernel_size=kernel, stride=1, padding=kernel // 2))
            stages.append(nn.ReLU())
            stages.append(nn.MaxPool1d(kernel_size=2, stride=2))
        self.conv_layers = nn.Sequential(*stages)

        # Probe the conv stack once to learn the flattened feature width.
        probe = torch.randn(1, input_channels, example_input_len)
        with torch.no_grad():
            flat_width = self.conv_layers(probe).numel()

        self.fc = nn.Sequential(
            nn.Linear(flat_width, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of signals."""
        feature_maps = self.conv_layers(x)
        flattened = feature_maps.reshape(feature_maps.size(0), -1)
        return self.fc(flattened)

class LSTMModel(nn.Module):
    """Two-layer bidirectional LSTM followed by a small classification head.

    The LSTM is built with ``batch_first=True``; the head reads the output of
    the last time step, which is ``2 * hidden_size`` wide because the LSTM is
    bidirectional.
    """

    def __init__(self, input_size: int = LSTM_FEATURE_SIZE, hidden_size: int = 128, num_classes: int = NUM_FAULT_CLASSES):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers=2,
            batch_first=True,
            dropout=0.3,
            bidirectional=True,
        )
        head_layers = [
            nn.Linear(2 * hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits computed from the final time step."""
        sequence_out, _state = self.lstm(x)
        last_step = sequence_out[:, -1, :]
        return self.fc(last_step)

class DSEvidenceFusion:
    """Fuse per-model class probabilities into one belief vector.

    Despite the name, the combination implemented here is a normalised
    weighted average of the evidence vectors, not Dempster's rule.
    """

    def __init__(self, num_classes: int = NUM_FAULT_CLASSES):
        self.num_classes = num_classes

    def combine_evidence(self, evidences: List[np.ndarray], weights: Optional[List[float]] = None) -> np.ndarray:
        """Return the weighted average of ``evidences``.

        Args:
            evidences: one vector of shape ``(num_classes,)`` per source.
                An empty list yields the uniform distribution.
            weights: one weight per evidence; they are rescaled to sum to 1.
                ``None`` weighs every source equally.

        Raises:
            ValueError: if the number of weights differs from the number of
                evidences, the weights sum to zero or a non-finite value, or
                an evidence vector has the wrong shape.
        """
        if not evidences:
            return np.ones(self.num_classes) / self.num_classes

        if weights is None:
            norm_weights = np.full(len(evidences), 1.0 / len(evidences))
        else:
            if len(weights) != len(evidences):
                raise ValueError("权重数量必须与证据数量匹配。")
            raw_weights = np.asarray(weights, dtype=float)
            weight_sum = raw_weights.sum()
            # Dividing by a zero / non-finite sum would silently yield NaN or
            # inf beliefs, which argmax then maps to class 0.
            if weight_sum == 0 or not np.isfinite(weight_sum):
                raise ValueError(f"weights must have a finite, non-zero sum, got {weight_sum}")
            norm_weights = raw_weights / weight_sum

        # Validate every evidence vector before accumulating anything.
        checked = []
        for evidence in evidences:
            evidence = np.asarray(evidence)
            if evidence.shape != (self.num_classes,):
                raise ValueError(f"证据形状不匹配。期望 ({self.num_classes},),得到 {evidence.shape}")
            checked.append(evidence)

        combined_belief = np.zeros(self.num_classes)
        for weight, evidence in zip(norm_weights, checked):
            combined_belief += weight * evidence
        return combined_belief

class FaultDiagnosisSystem:
    """Fan fault diagnosis built from three neural models and a rule engine.

    A CNN reads the raw vibration signal, an LSTM reads temperature / pressure
    / current features, a small MLP reads the current features, and a
    threshold rule engine reads the physical feature values. Their class
    probabilities are fused with ``self.model_weights`` in the order
    ``[cnn, lstm, electrical, rules]``.
    """

    def __init__(self, cnn_input_len=VIBRATION_RAW_SEQ_LEN, lstm_feat_size=LSTM_FEATURE_SIZE, electrical_feat_size=ELECTRICAL_FEATURE_SIZE):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.data_processor = DataProcessor()
        self.num_classes = len(self.data_processor.fault_types)
        self.cnn_model = CNNModel(num_classes=self.num_classes, example_input_len=cnn_input_len).to(self.device)
        self.lstm_model = LSTMModel(input_size=lstm_feat_size, num_classes=self.num_classes).to(self.device)
        self.electrical_model = nn.Sequential(
            nn.Linear(electrical_feat_size, 32), nn.ReLU(),
            nn.Linear(32, self.num_classes),
        ).to(self.device)
        self.fusion = DSEvidenceFusion(num_classes=self.num_classes)
        # Fusion weights in evidence order: CNN, LSTM, electrical model, rules.
        self.model_weights = [0.4, 0.3, 0.2, 0.1]
        self.design_pressure_default = 100.0

    def load_models(self, cnn_path=BEST_CNN_PATH, lstm_path=BEST_LSTM_PATH, electrical_path=BEST_ELECTRICAL_PATH, scalers_path=SCALER_SAVE_PATH):
        """Load the three state dicts, switch to eval mode, then load scalers.

        Any failure while loading a network is reported and re-raised. A falsy
        ``scalers_path`` skips scaler loading with a warning.
        """
        try:
            self.cnn_model.load_state_dict(torch.load(cnn_path, map_location=self.device))
            self.lstm_model.load_state_dict(torch.load(lstm_path, map_location=self.device))
            self.electrical_model.load_state_dict(torch.load(electrical_path, map_location=self.device))
            print("神经网络模型加载成功。")
        except Exception as e:
            print(f"加载神经网络模型时出错: {e}。请确保路径正确且模型匹配。")
            raise
        self.cnn_model.eval()
        self.lstm_model.eval()
        self.electrical_model.eval()
        if scalers_path:
            self.data_processor.load_scalers(scalers_path)
        else:
            print("警告:未提供缩放器路径。除非单独拟合/加载缩放器,否则特征不会被归一化。")

    def save_models(self, cnn_path=BEST_CNN_PATH, lstm_path=BEST_LSTM_PATH, electrical_path=BEST_ELECTRICAL_PATH, scalers_path=SCALER_SAVE_PATH):
        """Save the three state dicts and, when a path is given, the scalers."""
        torch.save(self.cnn_model.state_dict(), cnn_path)
        torch.save(self.lstm_model.state_dict(), lstm_path)
        torch.save(self.electrical_model.state_dict(), electrical_path)
        print(f"神经网络模型已保存至 {cnn_path}, {lstm_path}, {electrical_path}")
        if scalers_path:
            self.data_processor.save_scalers(scalers_path)

    def _rule_engine_predict(self, **kwargs) -> np.ndarray:
        """Score faults with fixed thresholds and return normalised probabilities.

        Keyword arguments (all optional, zeros by default):
        ``vibration_features``, ``temperature_features``, ``pressure_features``,
        ``blade_angle_features``, ``oil_features``, ``current_features`` and
        ``design_pressure``. The thresholds are compared against the values as
        given, so callers should pass un-normalised features.
        """
        fault_probs = np.zeros(self.num_classes)
        fault_probs[0] = 0.1  # small prior for the "normal" class
        v_feats = kwargs.get('vibration_features', np.zeros(8))
        t_feats = kwargs.get('temperature_features', np.zeros(5))
        p_feats = kwargs.get('pressure_features', np.zeros(4))
        b_feats = kwargs.get('blade_angle_features', np.zeros(5))
        o_feats = kwargs.get('oil_features', np.zeros(2))
        c_feats = kwargs.get('current_features', np.zeros(2))
        design_pressure = kwargs.get('design_pressure', self.design_pressure_default)

        # Resolve fault names through the fault table directly rather than via
        # a helper attached to DataProcessor from outside the class.
        index_of = {name: idx for idx, name in self.data_processor.fault_types.items()}

        def raise_score(fault_name, score):
            # A rule can only raise a fault's score, never lower it.
            idx = index_of[fault_name]
            fault_probs[idx] = max(fault_probs[idx], score)

        if v_feats[2] > 5:
            raise_score("轴承失效", 0.7)
        if v_feats[7] > 0.1:
            raise_score("基础松动", 0.6)
        if t_feats[0] > 85:
            raise_score("轴承失效", 0.7)
        if b_feats[2] > 5:
            raise_score("动叶卡涩", 0.8)
        if p_feats[3] > 0.2:
            raise_score("喘振", 0.7)
        if p_feats[0] > 1.1 * design_pressure:
            raise_score("积灰堵塞", 0.6)
        if c_feats[0] > 0.1:
            raise_score("电机绕组故障", 0.7)
        if c_feats[1] > 0.1:
            raise_score("电机绕组故障", 0.6)
        if o_feats[0] > 50:
            raise_score("轴承失效", 0.5)
            raise_score("密封失效", 0.4)
        if o_feats[1] < 70:
            raise_score("轴承失效", 0.4)
            raise_score("密封失效", 0.3)
        return self._normalize_prob(fault_probs)

    def _normalize_prob(self, probs: np.ndarray) -> np.ndarray:
        """Scale ``probs`` to sum to 1; an all-zero vector becomes uniform."""
        total = np.sum(probs)
        if total == 0:
            return np.ones_like(probs) / len(probs)
        return probs / total

    def predict(self, vibration_raw: np.ndarray, temperature_series: np.ndarray, 
                pressure_series: np.ndarray, blade_angle_series: np.ndarray, 
                oil_particles_val: float, oil_viscosity_val: float, 
                current_signal: np.ndarray,
                target_blade_angle: float, 
                design_pressure: float,
                vibration_sampling_rate: int = VIBRATION_SAMPLING_RATE,
                pressure_sampling_rate: int = PRESSURE_SAMPLING_RATE
                ) -> Dict[str, float]:
        """Diagnose one sample and return ``{fault name: fused probability}``.

        The neural models receive scaled features (the CNN receives the raw
        vibration signal padded or truncated to ``VIBRATION_RAW_SEQ_LEN``);
        the rule engine receives the un-scaled features because its
        thresholds are expressed in physical units.

        Raises:
            ValueError: if the assembled LSTM or electrical feature vector has
                an unexpected length (or the vibration signal is not 1-D).
            RuntimeError: if a scaler needed by the neural models is unfitted.
        """
        global predicting_now
        # normalize_features() reads this module-level flag to decide whether
        # an unfitted scaler is an error.
        predicting_now = True
        try:
            dp = self.data_processor
            vibration_features = dp.process_vibration(vibration_raw, vibration_sampling_rate)
            temperature_features = dp.process_temperature(temperature_series)
            pressure_features = dp.process_pressure(pressure_series, pressure_sampling_rate)
            blade_angle_features = dp.process_blade_angle(blade_angle_series, target_blade_angle)
            oil_features = dp.process_oil_analysis(oil_particles_val, oil_viscosity_val)
            current_features = dp.process_current(current_signal)

            # Only the features consumed by a neural model need scaling.
            norm_temperature_f = dp.normalize_features(temperature_features, 'temperature')
            norm_pressure_f = dp.normalize_features(pressure_features, 'pressure')
            norm_current_f = dp.normalize_features(current_features, 'current')

            missing = VIBRATION_RAW_SEQ_LEN - len(vibration_raw)
            if missing > 0:
                vibration_padded = np.pad(vibration_raw, (0, missing), 'constant')
            else:
                vibration_padded = vibration_raw[:VIBRATION_RAW_SEQ_LEN]
            cnn_input = torch.tensor(vibration_padded.reshape(1, 1, -1), dtype=torch.float32).to(self.device)

            lstm_input_feats_combined = np.concatenate([
                norm_temperature_f, norm_pressure_f, norm_current_f[:1]
            ])
            if lstm_input_feats_combined.shape[0] != LSTM_FEATURE_SIZE:
                raise ValueError(f"LSTM 输入特征大小不匹配。期望 {LSTM_FEATURE_SIZE}, 得到 {lstm_input_feats_combined.shape[0]}")
            # A single snapshot is presented as a sequence of length 1.
            lstm_input = torch.tensor(lstm_input_feats_combined.reshape(1, 1, -1), dtype=torch.float32).to(self.device)

            if norm_current_f.shape[0] != ELECTRICAL_FEATURE_SIZE:
                raise ValueError(f"电气模型输入特征大小不匹配。期望 {ELECTRICAL_FEATURE_SIZE}, 得到 {norm_current_f.shape[0]}")
            electrical_input = torch.tensor(norm_current_f.reshape(1, -1), dtype=torch.float32).to(self.device)

            # Inference must not run with dropout active.
            self.cnn_model.eval()
            self.lstm_model.eval()
            self.electrical_model.eval()
            with torch.no_grad():
                cnn_output_probs = torch.softmax(self.cnn_model(cnn_input), dim=1).cpu().numpy()[0]
                lstm_output_probs = torch.softmax(self.lstm_model(lstm_input), dim=1).cpu().numpy()[0]
                electrical_output_probs = torch.softmax(self.electrical_model(electrical_input), dim=1).cpu().numpy()[0]

            # The rule thresholds (e.g. temperature > 85, pressure > 1.1 x
            # design pressure) only make sense on un-scaled values.
            rule_output_probs = self._rule_engine_predict(
                vibration_features=vibration_features,
                temperature_features=temperature_features,
                pressure_features=pressure_features,
                blade_angle_features=blade_angle_features,
                oil_features=oil_features,
                current_features=current_features,
                design_pressure=design_pressure
            )
            evidences = [cnn_output_probs, lstm_output_probs, electrical_output_probs, rule_output_probs]
            fused_probs = self.fusion.combine_evidence(evidences, self.model_weights)
        finally:
            # Reset the flag even when a step above raised.
            predicting_now = False
        return {
            self.data_processor.fault_types[i]: float(fused_probs[i])
            for i in range(len(fused_probs))
        }

    @staticmethod
    def _train_step(model, optimizer, criterion, inputs, labels) -> float:
        """Run one optimisation step for one model and return the batch loss."""
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        return loss.item()

    def _to_model_input(self, tensor, add_channel_dim: bool):
        """Move a batch tensor to the device, optionally adding a channel axis."""
        if add_channel_dim:
            tensor = tensor.unsqueeze(1)
        return tensor.to(self.device)

    def train(self, train_loader, val_loader, epochs=10, lr=0.001):
        """Train the three neural models and keep the best fused checkpoint.

        Both loaders must yield ``(data_dict, labels)`` pairs where
        ``data_dict`` may contain ``'vibration_raw'``,
        ``'lstm_input_feature_sequence'`` and ``'electrical_features'``
        tensors and ``labels`` is a tensor of class indices. Malformed batches
        are skipped with a warning; a missing key skips only that model.

        After every epoch the fused validation accuracy is computed and the
        models are saved whenever it improves.
        """
        fitted = self.data_processor.fitted_scalers
        if not all(fitted.get(s_type) for s_type in self.data_processor.scalers):
            print("警告:并非所有缩放器都已标记为已拟合。请确保在有代表性的训练数据上调用了 fit_scalers。")

        criterion = nn.CrossEntropyLoss()
        # (data_dict key, model, optimizer, name used in messages, add channel axis?)
        branches = [
            ('vibration_raw', self.cnn_model,
             torch.optim.Adam(self.cnn_model.parameters(), lr=lr), "CNN", True),
            ('lstm_input_feature_sequence', self.lstm_model,
             torch.optim.Adam(self.lstm_model.parameters(), lr=lr), "LSTM", False),
            ('electrical_features', self.electrical_model,
             torch.optim.Adam(self.electrical_model.parameters(), lr=lr), "Electrical模型", False),
        ]
        class_ids = list(range(self.num_classes))
        class_names = [self.data_processor.fault_types[i] for i in class_ids]
        # TODO: feed real per-sample rule-engine output once the Dataset
        # exposes the needed features. Until then a uniform vector is neutral
        # (it cannot change the argmax) and keeps checkpoint selection
        # reproducible, unlike random evidence.
        neutral_rule_probs = np.full(self.num_classes, 1.0 / self.num_classes)
        best_val_accuracy = 0.0

        for epoch in range(epochs):
            for _, model, _, _, _ in branches:
                model.train()
            loss_totals = [0.0] * len(branches)

            for batch_idx, batch_content in enumerate(train_loader):
                if not isinstance(batch_content, (tuple, list)) or len(batch_content) != 2:
                    print(f"警告: train_loader 的批次内容格式不符合预期。跳过批次 {batch_idx}。")
                    continue
                data_dict, labels_batch_cpu = batch_content
                if not isinstance(data_dict, dict) or not isinstance(labels_batch_cpu, torch.Tensor):
                    print(f"警告: train_loader 的 data_dict 或 labels 格式不符合预期。跳过批次 {batch_idx}。")
                    continue
                labels_batch = labels_batch_cpu.to(self.device)

                for pos, (key, model, optimizer, display_name, add_channel) in enumerate(branches):
                    if key not in data_dict:
                        print(f"警告: 批次 {batch_idx} 中缺少 '{key}'。跳过{display_name}训练。")
                        continue
                    inputs = self._to_model_input(data_dict[key], add_channel)
                    loss_totals[pos] += self._train_step(model, optimizer, criterion, inputs, labels_batch)

                if batch_idx > 0 and batch_idx % 50 == 0:
                    # batch_idx is zero-based, so batch_idx + 1 batches were seen.
                    seen = batch_idx + 1
                    print(f"Epoch {epoch+1}, Batch {batch_idx}/{len(train_loader)} | "
                          f"CNN Loss: {loss_totals[0]/seen:.4f} | "
                          f"LSTM Loss: {loss_totals[1]/seen:.4f} | "
                          f"Elec Loss: {loss_totals[2]/seen:.4f}")

            n_batches = len(train_loader)
            avg_loss_cnn, avg_loss_lstm, avg_loss_elec = (
                (total / n_batches if n_batches > 0 else 0) for total in loss_totals
            )
            print(f"Epoch {epoch+1}/{epochs} Training Complete. Avg Losses: CNN={avg_loss_cnn:.4f}, LSTM={avg_loss_lstm:.4f}, Elec={avg_loss_elec:.4f}")

            # --- Validation phase ---
            for _, model, _, _, _ in branches:
                model.eval()
            all_val_preds_fused_indices = []
            all_val_labels_list = []

            with torch.no_grad():
                for batch_val_content in val_loader:
                    if not isinstance(batch_val_content, (tuple, list)) or len(batch_val_content) != 2:
                        print("警告: val_loader 的批次内容格式不符合预期。跳过验证批次。")
                        continue
                    data_dict_val, labels_batch_val_cpu = batch_val_content
                    if not isinstance(data_dict_val, dict) or not isinstance(labels_batch_val_cpu, torch.Tensor):
                        print("警告: val_loader 的 data_dict 或 labels 格式不符合预期。跳过验证批次。")
                        continue
                    batch_size_val = labels_batch_val_cpu.size(0)

                    # One (batch, classes) probability array per neural model;
                    # a missing modality contributes all-zero evidence.
                    branch_probs = []
                    for key, model, _, _, add_channel in branches:
                        if key in data_dict_val:
                            logits = model(self._to_model_input(data_dict_val[key], add_channel))
                            probs = torch.softmax(logits, dim=1)
                        else:
                            print(f"验证中缺少 '{key}'")
                            probs = torch.zeros((batch_size_val, self.num_classes), device=self.device)
                        # Move to NumPy once per batch instead of once per sample.
                        branch_probs.append(probs.cpu().numpy())

                    for i in range(batch_size_val):
                        evidences_sample_i = [probs[i] for probs in branch_probs]
                        evidences_sample_i.append(neutral_rule_probs)
                        fused_probs_sample_i = self.fusion.combine_evidence(evidences_sample_i, self.model_weights)
                        all_val_preds_fused_indices.append(np.argmax(fused_probs_sample_i))

                    all_val_labels_list.extend(labels_batch_val_cpu.cpu().numpy())

            if all_val_labels_list:  # at least one validation batch was processed
                val_accuracy = accuracy_score(all_val_labels_list, all_val_preds_fused_indices)
                print(f"Epoch {epoch+1}/{epochs}, 验证准确率 (融合后): {val_accuracy:.4f}")
                # ``labels`` pins the report to every class; without it the
                # call fails when a class is absent from this validation pass.
                print(classification_report(
                    all_val_labels_list,
                    all_val_preds_fused_indices,
                    labels=class_ids,
                    target_names=class_names,
                    zero_division=0,
                ))
                if val_accuracy > best_val_accuracy:
                    best_val_accuracy = val_accuracy
                    self.save_models()  # default paths hold the best checkpoint
                    print(f"新最佳模型已保存,验证准确率: {best_val_accuracy:.4f}")
            else:
                print(f"Epoch {epoch+1}/{epochs}, 验证: 未处理数据。")
        print(f"训练完成。最佳验证准确率: {best_val_accuracy:.4f}")

def get_fault_type_int(fault_types_dict, fault_name_str):
    """Return the integer class index whose fault name equals ``fault_name_str``.

    Args:
        fault_types_dict: Mapping of integer class index -> fault name.
            This function is also attached to ``DataProcessor`` as a class
            attribute, so when it is invoked on an instance Python passes that
            instance here; in that case its ``fault_types`` attribute is used
            as the mapping.
        fault_name_str: Fault name to look up.

    Returns:
        The first key (in mapping iteration order) whose value equals
        ``fault_name_str``.

    Raises:
        ValueError: If no entry of the mapping has that name.
    """
    # A plain dict has no ``fault_types`` attribute, so the ordinary
    # two-argument call keeps using the mapping it was given.
    mapping = getattr(fault_types_dict, "fault_types", fault_types_dict)
    for class_index, name in mapping.items():
        if name == fault_name_str:
            return class_index
    raise ValueError(f"故障名称 '{fault_name_str}' 不在 fault_types 中。")
# Expose the name -> index lookup on DataProcessor as a class attribute.
# Accessed through the class it takes (fault_types_dict, fault_name_str);
# accessed through an instance, Python binds that instance as the first argument.
DataProcessor.fault_types_str_to_int = get_fault_type_int

if __name__ == "__main__":
    # Module-level flag set before anything else runs.
    # NOTE(review): no reader of `predicting_now` is visible in this part of the file.
    globals()["predicting_now"] = False
    system = FaultDiagnosisSystem()
    print(f"系统已初始化。设备: {system.device}")
    print(f"故障类别数量: {system.num_classes}")

    print("\n--- (占位符) 准备数据并拟合缩放器 ---")
    num_training_samples = 100  # size of the synthetic training set

    # --- Synthesize the raw sensor readings the scalers are fitted on ---
    def _synthesize_raw_sample():
        """Return one dict of randomly generated raw readings for every sensor."""
        return {
            # Variable length: VIBRATION_RAW_SEQ_LEN shifted by a random offset in [-100, 100).
            'vibration_raw_unprocessed': np.random.normal(0, 1, VIBRATION_RAW_SEQ_LEN + np.random.randint(-100, 100)),
            'temperature_series_unprocessed': np.random.normal(60, 5, np.random.randint(5, 15)),
            'pressure_series_unprocessed': np.random.normal(100, 10, np.random.randint(50, 150)),
            'blade_angle_series_unprocessed': np.random.normal(45, 2, np.random.randint(10, 30)),
            'oil_particles_val_unprocessed': np.random.uniform(10, 100),
            'oil_viscosity_val_unprocessed': np.random.uniform(60, 90),
            # Three values, described by the original author as three-phase current readings.
            'current_signal_unprocessed': np.random.normal(50, 5, 3),
            'target_blade_angle_unprocessed': 45.0,
            'design_pressure_unprocessed': 100.0,
        }

    raw_training_data_for_scalers = [
        _synthesize_raw_sample() for _ in range(num_training_samples)
    ]

    # --- Extract features with the same process_* methods used for training ---
    scaler_dp = system.data_processor
    scaler_keys = (
        'vibration_features', 'temperature', 'pressure',
        'blade_angle', 'oil_particles', 'current',
    )
    features_for_scaler_fitting = {key: [] for key in scaler_keys}
    for raw in raw_training_data_for_scalers:
        extracted = {
            # Only truncates to VIBRATION_RAW_SEQ_LEN; shorter signals are passed through unpadded.
            'vibration_features': scaler_dp.process_vibration(
                raw['vibration_raw_unprocessed'][:VIBRATION_RAW_SEQ_LEN]
            ),
            'temperature': scaler_dp.process_temperature(
                raw['temperature_series_unprocessed']
            ),
            'pressure': scaler_dp.process_pressure(
                raw['pressure_series_unprocessed'], fs=PRESSURE_SAMPLING_RATE
            ),
            'blade_angle': scaler_dp.process_blade_angle(
                raw['blade_angle_series_unprocessed'],
                raw['target_blade_angle_unprocessed'],
            ),
            # The 'oil_particles' scaler is fitted on the combined particles/viscosity output.
            'oil_particles': scaler_dp.process_oil_analysis(
                raw['oil_particles_val_unprocessed'],
                raw['oil_viscosity_val_unprocessed'],
            ),
            'current': scaler_dp.process_current(
                raw['current_signal_unprocessed']
            ),
        }
        for key in scaler_keys:
            features_for_scaler_fitting[key].append(extracted[key])

    # Fit on the extracted features, then persist the fitted scalers.
    scaler_dp.fit_scalers(features_for_scaler_fitting)
    scaler_dp.save_scalers()

    print("\n--- (占位符) 训练模型 ---")
    # --- Dummy dataset: turns the synthetic raw readings into model-ready tensors ---
    class AdvancedDummyDataset(torch.utils.data.Dataset):
        """Map-style dataset that builds per-model inputs from raw sensor dicts.

        Each item is ``(inputs, label)``. ``inputs`` holds the raw vibration
        signal padded or truncated to ``VIBRATION_RAW_SEQ_LEN``, a feature
        sequence for the LSTM, and the normalized current features. Labels are
        drawn at random once, at construction time.
        """

        def __init__(self, num_samples, num_classes, data_processor_ref: DataProcessor, for_scaler_data):
            """Store the data sources and draw random labels.

            Args:
                num_samples: Number of items the dataset reports via ``len()``.
                num_classes: Labels are drawn uniformly from ``[0, num_classes)``.
                data_processor_ref: DataProcessor used for feature extraction and
                    normalization; its scalers are expected to be fitted already.
                for_scaler_data: Sequence of raw sample dicts providing the keys
                    ``'vibration_raw_unprocessed'``,
                    ``'temperature_series_unprocessed'``,
                    ``'pressure_series_unprocessed'`` and
                    ``'current_signal_unprocessed'``.

            Raises:
                ValueError: If samples are requested but ``for_scaler_data`` is empty.
            """
            # Fail early: an empty list would otherwise surface as a
            # ZeroDivisionError from ``idx % 0`` inside __getitem__.
            if num_samples > 0 and len(for_scaler_data) == 0:
                raise ValueError(
                    f"for_scaler_data 为空,无法为 {num_samples} 个样本提供原始数据。"
                )
            self.num_samples = num_samples
            self.num_classes = num_classes
            self.data_processor = data_processor_ref
            self.raw_sensor_data_list = for_scaler_data

            # Random labels, one per sample.
            self.labels = torch.randint(0, num_classes, (num_samples,))

        def __len__(self):
            """Return the configured number of samples."""
            return self.num_samples

        def __getitem__(self, idx):
            """Build the model inputs for sample ``idx``.

            Returns:
                Tuple ``(inputs, label)`` where ``inputs`` is a dict of float32
                tensors under ``'vibration_raw'``,
                ``'lstm_input_feature_sequence'`` and ``'electrical_features'``.
            """
            # Wrap around when num_samples exceeds the number of raw samples.
            raw_sample_data = self.raw_sensor_data_list[idx % len(self.raw_sensor_data_list)]

            # 1. Raw vibration for the CNN: zero-pad at the end or truncate to a fixed length.
            vib_raw_unpr = raw_sample_data['vibration_raw_unprocessed']
            missing = VIBRATION_RAW_SEQ_LEN - len(vib_raw_unpr)
            if missing > 0:
                cnn_vib_input_np = np.pad(vib_raw_unpr, (0, missing), 'constant')
            else:
                cnn_vib_input_np = vib_raw_unpr[:VIBRATION_RAW_SEQ_LEN]

            # 2. LSTM feature sequence. For simplicity a single time step of
            #    features is computed and repeated; real data would need a
            #    genuine time series of features.
            dp = self.data_processor
            temp_f = dp.normalize_features(
                dp.process_temperature(raw_sample_data['temperature_series_unprocessed']),
                'temperature',
            )
            pres_f = dp.normalize_features(
                dp.process_pressure(raw_sample_data['pressure_series_unprocessed'], fs=PRESSURE_SAMPLING_RATE),
                'pressure',
            )
            curr_f_all = dp.normalize_features(
                dp.process_current(raw_sample_data['current_signal_unprocessed']),
                'current',
            )

            # Original author's note: 5 temperature + 4 pressure + 1 current = 10
            # features — depends on the process_* outputs, TODO confirm.
            lstm_single_step_features_np = np.concatenate([temp_f, pres_f, curr_f_all[:1]])
            # Repeat the single step LSTM_SEQ_LEN times along a new leading axis.
            lstm_feature_sequence_np = np.tile(lstm_single_step_features_np, (LSTM_SEQ_LEN, 1))

            # 3. Electrical model input: the full normalized current feature
            #    vector (assumed by the original author to have 2 entries).
            electrical_features_np = curr_f_all

            processed_data_dict = {
                'vibration_raw': torch.tensor(cnn_vib_input_np, dtype=torch.float32),
                'lstm_input_feature_sequence': torch.tensor(lstm_feature_sequence_np, dtype=torch.float32),
                'electrical_features': torch.tensor(electrical_features_np, dtype=torch.float32)
            }
            return processed_data_dict, self.labels[idx]

    # Build the datasets. The DataProcessor instance is handed to the dataset
    # because it holds the already-fitted scalers.
    train_dataset = AdvancedDummyDataset(num_training_samples, system.num_classes, system.data_processor, raw_training_data_for_scalers)
    # The validation set reuses the first half of the training raw data; ideally
    # it would come from an independent set of raw readings.
    num_val_samples = num_training_samples // 2
    val_dataset = AdvancedDummyDataset(num_val_samples, system.num_classes, system.data_processor, raw_training_data_for_scalers[:num_val_samples])

    # num_workers=0 keeps loading in the main process for simplicity.
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16 if num_training_samples >= 16 else 1, shuffle=True, num_workers=0)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=16 if num_val_samples >= 16 else 1, num_workers=0)

    print("\n--- 开始实际训练(使用虚拟数据) ---")
    system.train(train_loader, val_loader, epochs=3)  # only a few epochs, as a demonstration

    print("\n--- 加载最佳模型 (如果训练中保存了) 并进行预测 ---")
    system_for_prediction = FaultDiagnosisSystem()
    try:
        system_for_prediction.load_models()  # default paths
    except Exception as e:
        # Best-effort fallback for the demo: keep the freshly initialised
        # models and load only the scalers.
        print(f"无法加载已训练的模型 (可能是因为验证准确率未提高,未保存最佳模型): {e}")
        print("为了预测结构演示,继续使用新初始化的模型 (需要加载scaler)。")
        system_for_prediction.data_processor.load_scalers()

    print("\n--- 对新样本数据进行预测 ---")
    # Shorter than VIBRATION_RAW_SEQ_LEN on purpose, to exercise padding.
    new_vibration_raw = np.random.normal(0.5, 0.2, VIBRATION_RAW_SEQ_LEN - 50)
    new_temperature_series = np.array([70, 72, 130, 73, 74.5])
    new_pressure_series = np.random.normal(105, 12, 200)
    new_blade_angle_series = np.array([44, 44.5, 45, 45.1, 44.8])
    new_oil_particles = 55.0
    new_oil_viscosity = 78.0
    new_current_signal = np.array([51.0, 48.5, 50.5])
    target_angle_setting = 45.0
    current_design_pressure = 100.0

    try:
        prediction_result = system_for_prediction.predict(
            vibration_raw=new_vibration_raw,
            temperature_series=new_temperature_series,
            pressure_series=new_pressure_series,
            blade_angle_series=new_blade_angle_series,
            oil_particles_val=new_oil_particles,
            oil_viscosity_val=new_oil_viscosity,
            current_signal=new_current_signal,
            target_blade_angle=target_angle_setting,
            design_pressure=current_design_pressure
        )
        print("\n预测的故障概率:")
        # Highest probability first; near-zero entries are not printed.
        for fault_name, probability in sorted(prediction_result.items(), key=lambda item: -item[1]):
            if probability > 0.001:
                print(f"  {fault_name}: {probability:.4f}")
        # Take the winning fault name straight from the result keys, so the
        # report does not depend on the dict order matching the class indices.
        predicted_fault_name = max(prediction_result, key=prediction_result.get)
        print(f"主要预测故障: {predicted_fault_name}")
    except RuntimeError as e:
        # The message fragment below identifies the "scaler not fitted" error.
        if "的缩放器必须在预测前拟合" in str(e):
            print(f"预测失败: {e}。如果缩放器未正确拟合/加载,这是预期的。")
        else:
            print(f"预测期间发生运行时错误: {e}")
    except Exception:
        # Top-level boundary of the demo script: report the full traceback.
        import traceback
        print(f"预测期间发生意外错误:\n{traceback.format_exc()}")

网站公告

今日签到

点亮在社区的每一天
去签到