knn实现手写数字识别

发布于:2022-12-01 ⋅ 阅读:(733) ⋅ 点赞:(0)

在这里插入图片描述

数据集

手写数字识别数据集 其中数据集特征包括

样本特征

raw 特征1 特征2 特征3 标签
9 0 0 10 9

该样本8个特征[ 0. 0. 10. 8. 8. 4. 0. 0.]

算法步骤

  1. 数据集导入
  2. 分析处理数据
  3. 训练数据
  4. 测试数据
  5. 计算模型准确率

数据集导入

这里使用的是sklearn官方的数据集 导入比较简单

from sklearn import datasets
digits = datasets.load_digits()

分析处理数据

将数据集分为训练集与测试集两部分 测试集用于训练生成的模型的准确率 其比例为8:2

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.2)

训练数据

使用knn训练数据 knn原理为某个样本在空间中的k个最近的样本中的最多数属于某一个类别

对于knn距离使用欧拉距离

多维度欧拉公式为 d ( p , q ) = ( p 1 − q 1 ) 2 + ( p 2 − q 2 ) 2 + ⋯ + ( p i − q i ) 2 + ⋯ + ( p n − q n ) 2 . \displaystyle d(p,q)={\sqrt {(p_{1}-q_{1})^{2}+(p_{2}-q_{2})^{2}+\cdots +(p_{i}-q_{i})^{2}+\cdots +(p_{n}-q_{n})^{2}}}. d(p,q)=(p1q1)2+(p2q2)2++(piqi)2++(pnqn)2 .

公式用python表达也非常简单 np.sqrt(np.sum(np.square(X - x)))

对于训练集每个点使用下述方式计算距离

import numpy as np
dis = []
for X in X_train:
    dis.append(np.sqrt(np.sum(np.square(X - x))))
X_dis_idx = np.argsort(dis)

投票算法使用少数服从多数 y表示target sum表示vote数量

predict_target = {'y': 0, 'sum': 0}
for y in y_label:
    if not predict_target['sum'] < np.sum(y_train[X_dis_idx[:k]] == y):
        continue
    predict_target['y'] = y
    predict_target['sum'] = np.sum(y_train[X_dis_idx[:k]] == y)
# print(predict_target)

上述代码选择了最大的vote作为预测结果

测试数据

取前两个预测数据和测试数据比较下

y_predict = clsfy.predict(X_test)
print(y_predict[:2], y_test[:2])
raw 特征1 特征2 特征3 预测标签 标签
3 0 0 10 3 3
4 0 0 10 4 4

计算模型准确率

预测准确率接近99.3%

from sklearn.metrics import accuracy_score
accuracy_score(y_test, y_predict)

示例

import numpy as np
from matplotlib import pyplot as plt
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# np.random.seed(123)

digits = datasets.load_digits()
# print(digits.keys(), digits['feature_names'], digits['target_names'])

print(digits.data.shape)
# rand = np.random.randint(len(digits.data))
# print(digits.target[rand])
# plt.imshow(digits.data[rand].reshape(8, 8), cmap=matplotlib.cm.binary)
# plt.show()

X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.2)
# print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)

clsf = KNeighborsClassifier(n_neighbors=6)
clsf.fit(X_train, y_train)
print(clsf.score(X_test, y_test))  # accuracy_score(y_test, y_predict)
y_predict = clsf.predict(X_test)
print(y_predict[:5], y_test[:5])
本文含有隐藏内容,请 开通VIP 后查看

网站公告


今日签到

点亮在社区的每一天
去签到