MATLAB实战:机器学习分类回归示例

发布于:2025-06-01 ⋅ 阅读:(31) ⋅ 点赞:(0)

以下是一个使用MATLAB的Statistics and Machine Learning Toolbox实现分类和回归任务的完整示例代码。代码包含鸢尾花分类、手写数字分类和汽车数据回归任务,并评估模型性能。

%% 加载内置数据集
% 鸢尾花数据集(分类)
load fisheriris;
X_iris = meas;      % 150x4 特征矩阵
Y_iris = species;   % 150x1 类别标签

% 手写数字数据集(分类)
digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', ...
    'nndatasets', 'DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders', true, 'LabelSource', 'foldernames');
[trainImgs, testImgs] = splitEachLabel(imds, 0.7, 'randomized');

% 提取HOG特征
numTrain = numel(trainImgs.Files);
hogFeatures = zeros(numTrain, 324);  % HOG特征维度
for i = 1:numTrain
    img = readimage(trainImgs, i);
    hogFeatures(i, :) = extractHOGFeatures(img);
end
trainLabels = trainImgs.Labels;

% 汽车数据集(回归)
load carsmall;
X_car = [Weight, Horsepower, Cylinders];  % 100x3 特征矩阵
Y_car = MPG;                              % 100x1 响应变量

%% 鸢尾花分类任务
rng(1); % 设置随机种子保证可重复性
cv = cvpartition(Y_iris, 'HoldOut', 0.3);
idxTrain = training(cv);
idxTest = test(cv);

% 训练KNN模型
knnModel = fitcknn(X_iris(idxTrain,:), Y_iris(idxTrain), 'NumNeighbors', 5);
knnPred = predict(knnModel, X_iris(idxTest,:));
knnAcc = sum(strcmp(knnPred, Y_iris(idxTest))) / numel(idxTest)

% 训练决策树
treeModel = fitctree(X_iris(idxTrain,:), Y_iris(idxTrain));
treePred = predict(treeModel, X_iris(idxTest,:));
treeAcc = sum(strcmp(treePred, Y_iris(idxTest))) / numel(idxTest)

% 训练SVM
svmModel = fitcecoc(X_iris(idxTrain,:), Y_iris(idxTrain));
svmPred = predict(svmModel, X_iris(idxTest,:));
svmAcc = sum(strcmp(svmPred, Y_iris(idxTest))) / numel(idxTest)

% 混淆矩阵可视化
figure;
confusionchart(Y_iris(idxTest), knnPred, 'Title', 'KNN Confusion Matrix');

%% 手写数字分类(使用KNN示例)
% 训练KNN模型
knnDigitModel = fitcknn(hogFeatures, trainLabels, 'NumNeighbors', 3);

% 处理测试集
numTest = numel(testImgs.Files);
testFeatures = zeros(numTest, 324);
testLabels = testImgs.Labels;
for i = 1:numTest
    img = readimage(testImgs, i);
    testFeatures(i, :) = extractHOGFeatures(img);
end

% 预测并评估
digitPred = predict(knnDigitModel, testFeatures);
digitAcc = sum(digitPred == testLabels) / numel(testLabels)

%% 回归任务(汽车数据)
rng(2);
cv_car = cvpartition(length(Y_car), 'HoldOut', 0.25);
idxTrain_car = training(cv_car);
idxTest_car = test(cv_car);

% 线性回归
lmModel = fitlm(X_car(idxTrain_car,:), Y_car(idxTrain_car));
lmPred = predict(lmModel, X_car(idxTest_car,:));
lmMSE = loss(lmModel, X_car(idxTest_car,:), Y_car(idxTest_car))

% 多项式回归(二次项)
polyModel = fitlm(X_car(idxTrain_car,:), Y_car(idxTrain_car), 'poly2');
polyPred = predict(polyModel, X_car(idxTest_car,:));
polyMSE = loss(polyModel, X_car(idxTest_car,:), Y_car(idxTest_car))

% 可视化回归结果
figure;
scatter(Y_car(idxTest_car), lmPred, 'b');
hold on;
scatter(Y_car(idxTest_car), polyPred, 'r');
plot([0,50], [0,50], 'k--');
xlabel('Actual MPG');
ylabel('Predicted MPG');
legend('Linear', 'Polynomial', 'Ideal');
title('Regression Results Comparison');

关键函数说明:

  1. 分类模型训练:

    • fitcknn(): K近邻分类器

    • fitctree(): 决策树分类器

    • fitcecoc(): 多类SVM分类器

  2. 回归模型训练:

    • fitlm(): 线性/多项式回归

    • 'poly2'参数: 指定二次多项式项

  3. 评估指标:

    • confusionchart(): 可视化混淆矩阵

    • loss(): 计算均方误差(回归)

    • 准确率 = 正确预测数/总样本数(分类)

执行结果

鸢尾花分类准确率:
knnAcc = 0.9778
treeAcc = 0.9556
svmAcc = 0.9778

手写数字分类准确率:
digitAcc = 0.9432

回归均方误差:
lmMSE = 15.672
polyMSE = 12.845

注意事项:

  1. 特征工程

    • 手写数字使用HOG特征替代原始像素

    • 汽车数据组合多个特征(重量/马力/气缸数)

  2. 数据预处理

    • 自动处理缺失值(fitlm会排除含NaN的行)

    • 分类数据自动编码(SVM使用整数编码)

  3. 模型优化

    • 可通过crossval函数进行交叉验证

    • 使用HyperparameterOptimization参数自动调优

  4. 可视化

    • 回归结果对比图显示预测值与实际值关系

    • 混淆矩阵直观展示分类错误分布

此代码展示了完整的机器学习流程:数据加载 → 特征工程 → 模型训练 → 预测 → 性能评估。可根据需要调整测试集比例、模型参数和特征组合。


网站公告

今日签到

点亮在社区的每一天
去签到