基于鸢尾花数据集实施自组织神经网络聚类分析

发布于:2024-05-10 ⋅ 阅读:(21) ⋅ 点赞:(0)

基于鸢尾花数据集实施自组织神经网络聚类分析

1. 自组织神经网络的基础知识

自组织神经网络也称自组织映射(SOM)或自组织特征映射(SOFM),是一种使用非监督式学习来产生训练样本的输入空间的一个低维(通常是二维)离散化的表示的人工神经网络(ANN)。自组织映射与其他人工神经网络的不同之处在于它使用一个邻近函数来保持输入空间的拓扑性质。
在这里插入图片描述
在这里插入图片描述

2. 鸢尾花数据集的自组织分类

# 导入必要的库
import numpy as np
from minisom import MiniSom
import matplotlib.pyplot as plt
from sklearn import datasets

# 载入鸢尾花数据集
iris = datasets.load_iris()
data = iris.data
labels = iris.target

# 数据归一化
data = (data - np.min(data, axis=0)) / (np.max(data, axis=0) - np.min(data, axis=0))

# 定义 SOM 网络的参数
som_shape = (10, 10)  # SOM 网格的形状
som = MiniSom(som_shape[0], som_shape[1], data.shape[1], sigma=1.0, learning_rate=0.5)

# 初始化权重并开始训练
som.random_weights_init(data)
som.train_random(data, 100)  # 100 次迭代
# 创建 U-matrix
umatrix = som.distance_map()
print(umatrix)
# 绘制 U-matrix
plt.figure(figsize=(4,4))
plt.pcolor(umatrix.T, cmap='bone_r', alpha=0.8)
plt.colorbar()

# 绘制聚类中心
for i, target in enumerate(labels):
    x, y = som.winner(data[i])
    plt.text(x + 0.5, y + 0.5, str(target), color=plt.cm.rainbow(target / 2.0), fontdict={'weight': 'bold', 'size': 11})

plt.xticks(np.arange(som_shape[0] + 1))
plt.yticks(np.arange(som_shape[1] + 1))
plt.grid()
plt.show()

聚类结果
在这里插入图片描述

新聚类

import math
import numpy as np
from minisom import MiniSom
from sklearn import datasets
from numpy import sum as npsum
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.gridspec import GridSpec


# 分类函数
def classify(som,data,winmap):
    default_class = npsum(list(winmap.values())).most_common()[0][0]
    result = []
    for d in data:
        win_position = som.winner(d)
        if win_position in winmap:
            result.append(winmap[win_position].most_common()[0][0])
        else:
            result.append(default_class)
    return result


# 可视化
def show(som):
    """ 在输出层画标签图案 """
    plt.figure(figsize=(5, 5))
    # 定义不同标签的图案标记
    markers = ['o', 's', 'D']
    colors = ['C0', 'C1', 'C2']
    category_color = {'setosa': 'C0', 'versicolor': 'C1', 'virginica': 'C2'}

    # 背景上画U-Matrix
    heatmap = som.distance_map()
    # 画背景图
    plt.pcolor(heatmap, cmap='bone_r')

    for cnt, xx in enumerate(X_train):
        w = som.winner(xx)
        # 在样本Heat的地方画上标记
        plt.plot(w[0] + .5, w[1] + .5, markers[Y_train[cnt]], markerfacecolor='None',
                 markeredgecolor=colors[Y_train[cnt]], markersize=12, markeredgewidth=2)

    plt.axis([0, size, 0, size])
    ax = plt.gca()
    # 颠倒y轴方向
    ax.invert_yaxis()
    legend_elements = [Patch(facecolor=clr, edgecolor='w', label=l) for l, clr in category_color.items()]
    plt.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1, .95))
    plt.show()

    # """ 在每个格子里画饼图,且用颜色表示类别,用数字表示总样本数量 """
    # plt.figure(figsize=(16, 16))
    # the_grid = GridSpec(size, size)

    # for position in winmap.keys():
    #     label_fracs = [winmap[position][label] for label in [0, 1, 2]]
    #     plt.subplot(the_grid[position[1], position[0]], aspect=1)
    #     patches, texts = plt.pie(label_fracs)
    #     plt.text(position[0] / 100, position[1] / 100, str(len(list(winmap[position].elements()))),
    #              color='black', fontdict={'weight': 'bold', 'size': 15}, va='center', ha='center')
    # plt.legend(patches, class_names, loc='center right', bbox_to_anchor=(-1, 9), ncol=3)
    # plt.show()


if __name__ == '__main__':
    # 导入数据集
    iris = datasets.load_iris()
    # 提取iris数据集的标签与数据
    feature_names = iris.feature_names
    class_names = iris.target_names
    X = iris.data
    Y = iris.target
    # 划分训练集、测试集  7:3
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3, random_state=0)

    # 样本数量
    N = X_train.shape[0]
    # 维度/特征数量
    M = X_train.shape[1]
    # 最大迭代次数
    max_iter = 200
    # 经验公式:决定输出层尺寸
    size = math.ceil(np.sqrt(5 * np.sqrt(N)))
    print("训练样本个数:{}  测试样本个数:{}".format(N, X_test.shape[0]))
    print("输出网格最佳边长为:", size)

    # 初始化模型
    som = MiniSom(size, size, M, sigma=3, learning_rate=0.5, neighborhood_function='bubble')
    # 初始化权值
    som.pca_weights_init(X_train)
    # 模型训练
    som.train_batch(X_train, max_iter, verbose=False)

    # 利用标签信息,标注训练好的som网络
    winmap = som.labels_map(X_train, Y_train)
    # 进行分类预测
    y_pred = classify(som, X_test, winmap)
    print(y_pred)
    # 展示在测试集上的效果
    print(classification_report(Y_test, np.array(y_pred)))

    # 可视化
    show(som)

在这里插入图片描述

3. SOM的无监督聚类

from minisom import MiniSom
import numpy as np
import pandas as pd
from sklearn import datasets
import matplotlib.pyplot as plt

# 加载鸢尾花数据
iris = datasets.load_iris()
data = iris.data
ds = pd.DataFrame(data)
dt=ds.values
som_shape = (1,3)
som = MiniSom(som_shape[0], som_shape[1], dt.shape[1], sigma=.5, learning_rate=.5,
              neighborhood_function='gaussian', random_seed=42)
som.train_batch(dt, 1000, verbose=True)
winner_coordinates = np.array([som.winner(x) for x in dt]).T
cluster_index = np.ravel_multi_index(winner_coordinates, som_shape)
## 可视化聚类结果
import matplotlib.pyplot as plt 
from matplotlib import rcParams
from matplotlib.pyplot import MultipleLocator

config = {
    "font.family": 'serif', # 衬线字体
    "font.size": 10, # 相当于小四大小
    "font.serif": ['SimSun'], # 宋体
    "mathtext.fontset": 'stix', # matplotlib渲染数学字体时使用的字体,和Times New Roman差别不大
    'axes.unicode_minus': False # 处理负号,即-号
}
rcParams.update(config)

fig = plt.figure(figsize=(5,5))
ax = plt.subplot(1,1,1,projection='3d')
# plotting the clusters using the first 2 dimentions of the data
for c in np.unique(cluster_index):
    ax.scatter3D(dt[cluster_index == c, 0],
                  dt[cluster_index == c, 1], 
                  dt[cluster_index == c, 2], 
                  label='cluster='+str(c), 
                  alpha=.7)

# plotting centroids
for centroid in som.get_weights():
    ax.scatter3D(centroid[:, 0], centroid[:, 1],centroid[:, 2], 
                 marker='x', alpha=.7,s=10, linewidths=10, color='k', 
                 label='centroid')
ax.set_xlabel('Production Function')		
ax.set_ylabel('Economy Function')
ax.set_zlabel('Social Function')
ax.set_title('SOM clustering result')
plt.legend()
# plt.margins(0,0,0)
plt.show()
# plt.savefig('./result.png',dpi=600)

在这里插入图片描述