%% Load the classification dataset
% Build a datastore over every image below the dataset folder. With
% LabelSource="foldernames" each image is labeled by its parent folder name;
% renamecats then replaces the folder IDs ('n01514668', 'n02129604',
% 'n02325366') with readable class names.
digitDatasetPath = fullfile('D:','MATLAB WORK', 'PR','ImageNet');
imds = imageDatastore(digitDatasetPath, ...
    IncludeSubfolders=true,LabelSource="foldernames");
imds.Labels = renamecats(imds.Labels, {'n01514668', 'n02129604','n02325366'}, {'chicken', 'tiger','rabbit'});
% No trailing semicolon: print the per-class image counts.
numObsPerClass = countEachLabel(imds)
totalImages = sum(numObsPerClass.Count);
% classNames = categories(imds.Labels)
% numClasses = numel(classNames);
%% Load the pretrained classifier
% imagePretrainedNetwork returns the network together with the class names
% that are later paired with its output scores. Set modelName to another
% architecture name (e.g. "resnet101") to compare networks.
modelName = "squeezenet";
[net,classNames] = imagePretrainedNetwork(modelName);
% analyzeNetwork(net)   % uncomment to inspect the layer graph
%% Classify one randomly chosen image
% Pick a random image index in 1..totalImages and read that image.
obj = randi(totalImages);
img = readimage(imds,obj);

% The network's input layer fixes the accepted image size, so resize the
% image to it. A single-channel (grayscale) image is replicated to three
% channels. img itself is left untouched for display later.
inputSize = net.Layers(1).InputSize;
X = imresize(img, inputSize(1:2));
if size(X,3) == 1
    X = repmat(X, [1 1 3]);
end
X = single(X);
if canUseGPU
    X = gpuArray(X);
end

% gather copies a GPU result back to host memory (no-op for CPU arrays), so
% the label, title, and bar chart work with ordinary numeric data.
scores = gather(predict(net,X));
[label,score] = scores2label(scores,classNames);

% Take the indices of the five highest scores in ascending order, so the
% best prediction is drawn as the top bar of a horizontal bar chart.
[~,idx] = sort(scores,"descend");
idx = idx(5:-1:1);
classNamesTop = classNames(idx);
scoresTop = scores(idx);
%% Visualize the prediction
% Left tile: the original image titled with its predicted label and score.
% Right tile: the top-5 class probabilities as a horizontal bar chart.
% gather is a no-op for CPU arrays; it is needed when prediction ran on a
% GPU so that string conversion and plotting receive host data.
figure('Position',[868 204 800 320]);
tiledlayout(1,2,"TileSpacing","tight")
nexttile
imshow(img)
title(string(label) + "(" + string(gather(score))+ ")" )
nexttile
barh(gather(scoresTop))
xlim([0 1])
title("Top 5 Predictions")
xlabel("Probability")
yticklabels(classNamesTop)