K-means 实现图像像素分类

发布于:2024-04-25 ⋅ 阅读:(19) ⋅ 点赞:(0)

代码

import tkinter as tk

from tkinter import filedialog

from PIL import Image, ImageTk

import numpy as np
import random
import math

class Cluster(object):
    """One k-means cluster: the pixels currently assigned to it and its centroid."""

    def __init__(self):
        # Pixels (sequences whose first three items are R, G, B) assigned to
        # this cluster during the current assignment pass.
        self.pixels = []
        # Current centre colour; stays None until the caller seeds it.
        self.centroid = None

    def addPoint(self, pixels):
        """Assign a single pixel to this cluster."""
        self.pixels.append(pixels)

    def setNewCentroid(self):
        """Move the centroid to the mean colour of the assigned pixels.

        The assigned pixels are cleared afterwards so that the next
        assignment pass starts from an empty cluster instead of averaging
        over pixels collected in earlier iterations.

        Returns:
            The new centroid as an ``(R, G, B)`` tuple of ints.  If no pixel
            was assigned, the previous centroid is kept and returned
            unchanged (avoids a division by zero).
        """
        count = len(self.pixels)
        if count == 0:
            return self.centroid

        # Convert every channel value to a Python int before summing: the
        # pixels may be numpy uint8 values, whose sum can wrap around at 256.
        totals = [
            sum(int(pixel[channel]) for pixel in self.pixels)
            for channel in range(3)
        ]
        self.centroid = tuple(round(total / count) for total in totals)
        self.pixels = []

        return self.centroid
class Kmeans(object):
    """K-means clustering of the pixel colours of an image.

    Args:
        k: Number of clusters (colours) to find.
        max_iteration: Maximum number of assignment/update passes.
        min_distance: Convergence threshold; iteration stops once every
            centroid moved less than this Euclidean distance in a pass.
        size: The image is shrunk to fit inside ``size`` x ``size`` pixels
            before clustering.
    """

    def __init__(self, k=10, max_iteration=10, min_distance=5.0, size=200):
        self.k = k
        self.max_iterations = max_iteration
        self.min_distance = min_distance
        self.size = (size, size)

    def run(self, image):
        """Cluster the colours of *image* and return the ``k`` centroids.

        An RGB image is shrunk in place with ``thumbnail`` (as before); an
        image in any other mode is first converted to an RGB copy so that
        every pixel is a 3-component colour.

        Args:
            image: A PIL image.

        Returns:
            A list with the centroid of each cluster.

        Raises:
            ValueError: If the image has fewer than ``k`` pixels.
        """
        # Palette / greyscale pixels are not (R, G, B) triples, so normalise.
        self.image = image if image.mode == "RGB" else image.convert("RGB")
        # Shrink to at most self.size, keeping the aspect ratio.
        self.image.thumbnail(self.size)
        # Image as an array, and the same data as one row per pixel.
        self.p = np.array(self.image)
        self.pixels = self.p.reshape(-1, 3)

        # Seed every cluster with the colour of a randomly chosen pixel.
        # Centroids are stored as tuples of Python ints so that later
        # arithmetic never happens in uint8.
        seeds = random.sample(range(len(self.pixels)), self.k)
        self.clusters = []
        for seed in seeds:
            cluster = Cluster()
            cluster.centroid = tuple(int(value) for value in self.pixels[seed])
            self.clusters.append(cluster)
        self.oldClusters = None

        # Plain Python ints: avoids uint8 wrap-around further down.
        pixel_rows = self.pixels.tolist()

        iterations = 0
        while not self.shouldExit(iterations):
            self.oldClusters = [cluster.centroid for cluster in self.clusters]

            # Each pass starts from empty clusters.
            for cluster in self.clusters:
                cluster.pixels = []

            for pixel in pixel_rows:
                self.assignClusters(pixel)

            for cluster in self.clusters:
                # An empty cluster (e.g. duplicate seed colours) keeps its
                # centroid; averaging zero pixels would divide by zero.
                if cluster.pixels:
                    cluster.setNewCentroid()

            iterations += 1

        return [cluster.centroid for cluster in self.clusters]

    def assignClusters(self, pixel):
        """Add *pixel* to the cluster whose centroid is closest to it."""
        # min() returns the first of several equally close clusters.
        nearest = min(
            self.clusters,
            key=lambda cluster: self.calcDistance(cluster.centroid, pixel),
        )
        nearest.addPoint(pixel)

    def calcDistance(self, a, b):
        """Return the Euclidean distance between the colours *a* and *b*.

        Only the first three components are used.  Values are converted to
        float first because subtracting numpy uint8 values wraps around
        (e.g. ``0 - 200`` becomes ``56``).
        """
        return math.sqrt(
            sum((float(a[i]) - float(b[i])) ** 2 for i in range(3))
        )

    def shouldExit(self, iterations):
        """Tell whether the clustering loop should stop.

        Args:
            iterations: Number of passes completed so far.

        Returns:
            ``False`` before the first pass; afterwards ``True`` once
            ``max_iterations`` passes have run or every centroid moved less
            than ``min_distance`` during the last pass.
        """
        if self.oldClusters is None:
            return False

        if iterations >= self.max_iterations:
            return True

        # Converged only when ALL centroids have settled, not just one.
        return all(
            self.calcDistance(cluster.centroid, previous) < self.min_distance
            for cluster, previous in zip(self.clusters, self.oldClusters)
        )

    def showImage(self):
        """Open the (shrunk) input image in the default image viewer."""
        self.image.show()

    def showCentroidColours(self):
        """Return a 500x30 strip with one colour swatch per centroid."""
        swatch = 200
        strip = Image.new("RGB", (len(self.clusters) * swatch, swatch), "white")
        for position, cluster in enumerate(self.clusters):
            colour = tuple(int(value) for value in cluster.centroid[:3])
            tile = Image.new("RGB", (swatch, swatch), colour)
            # Tiles are laid out left to right.
            strip.paste(tile, (position * swatch, 0))
        return strip.resize((500, 30))

    def _nearestCentroidIndices(self, pixels):
        """Return, for each row of *pixels*, the index of the closest centroid."""
        centroids = np.asarray(
            [cluster.centroid for cluster in self.clusters], dtype=np.float64
        )[:, :3]
        points = np.asarray(pixels, dtype=np.float64)[:, :3]
        # Squared distances give the same argmin as true distances; argmin
        # returns the first of several equally close centroids.
        squared = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        return squared.argmin(axis=1)

    def showClustering(self):
        """Return a 200x200 image with every pixel replaced by its centroid colour."""
        width, height = self.image.size
        palette = np.asarray(
            [cluster.centroid for cluster in self.clusters]
        )[:, :3].astype(np.uint8)
        labels = self._nearestCentroidIndices(self.pixels)
        # Rebuild a (height, width, 3) image from the per-pixel centroids.
        recoloured = palette[labels].reshape((height, width, 3))
        return Image.fromarray(recoloured).resize((200, 200))
# Create the Tkinter root window.
root = tk.Tk()
root.title("图片处理GUI")

# Module-level state shared by the button callbacks: the path of the opened
# image and its pixel data (both None until an image has been opened).
image_path = image_data = None

def open_image():
    """Let the user pick an image file and show a 200x200 preview of it.

    On success the module globals ``image`` (resized PIL image),
    ``image_path`` (selected file path) and ``image_data`` (the preview as a
    numpy array) are updated.  Nothing changes if the dialog is cancelled.
    """
    global image_path, image_data, image

    # Tk expects the patterns of one file type as a whitespace-separated
    # list; a ';'-separated string only works with the native Windows dialog.
    file_path = filedialog.askopenfilename(
        title="选择图片",
        filetypes=[("图像文件", "*.png *.jpg *.jpeg *.bmp *.gif")],
    )
    if not file_path:
        return

    # resize() loads the pixel data and returns a new image, so the source
    # file can be closed right away.
    with Image.open(file_path) as source:
        image = source.resize((200, 200))

    tk_image = ImageTk.PhotoImage(image)
    label_image.config(image=tk_image)
    label_image.config(padx=10, pady=5)
    # Keep a reference on the widget so the photo is not garbage-collected.
    label_image.image = tk_image

    image_path = file_path
    image_data = np.array(image)


def apply_kmeans():
    """Run k-means on the opened image and display the centroid colour strip.

    Prints a hint and does nothing else when no image has been opened yet.
    """
    if image_data is None:
        print("请先打开一张图片!")
        return

    # Convert to RGB first: palette or greyscale images do not yield
    # (R, G, B) pixels, which the clustering code indexes by channel.
    with Image.open(image_path) as source:
        image = source.convert("RGB").resize((200, 200))

    k = Kmeans()
    k.run(image)
    segmented_image_pil = k.showCentroidColours()

    segmented_tk_image = ImageTk.PhotoImage(segmented_image_pil)
    label_segmented.config(image=segmented_tk_image)
    # Keep a reference on the widget so the photo is not garbage-collected.
    label_segmented.image = segmented_tk_image

def apply_kmeans1():
    """Run k-means on the opened image and display the recoloured image.

    Every pixel of the result carries the colour of its nearest centroid.
    Prints a hint and does nothing else when no image has been opened yet.
    """
    if image_data is None:
        print("请先打开一张图片!")
        return

    # Convert to RGB first: palette or greyscale images do not yield
    # (R, G, B) pixels, which the clustering code indexes by channel.
    with Image.open(image_path) as source:
        image = source.convert("RGB").resize((200, 200))

    k = Kmeans()
    k.run(image)
    segmented_image_pil = k.showClustering()

    segmented_tk_image = ImageTk.PhotoImage(segmented_image_pil)
    label_segmented1.config(image=segmented_tk_image)
    # Keep a reference on the widget so the photo is not garbage-collected.
    label_segmented1.image = segmented_tk_image


# Column 0: button that opens an image, with the original picture below it.
button_open = tk.Button(root, text="打开图片", command=open_image)
button_open.grid(row=0, column=0, sticky="ew")
label_image = tk.Label(root)
# sticky="news" makes the label fill its grid cell in every direction.
label_image.grid(row=1, column=0, sticky="news")

# Column 1: button that runs k-means, with the centroid colour strip below it.
button_kmeans = tk.Button(root, text="应用K-means", command=apply_kmeans)
button_kmeans.grid(row=0, column=1, sticky="ew")
label_segmented = tk.Label(root)
label_segmented.grid(row=1, column=1, sticky="news")

# Column 2: button that runs k-means, with the recoloured image below it.
button_kmeans1 = tk.Button(root, text="返回图像", command=apply_kmeans1)
button_kmeans1.grid(row=0, column=2, sticky="ew")
label_segmented1 = tk.Label(root)
label_segmented1.grid(row=1, column=2, sticky="news")

# Give the three columns equal weight so they share extra width evenly.
for column_index in range(3):
    root.columnconfigure(column_index, weight=1)

# Enter the Tkinter event loop.
root.mainloop()