YOLO11目标检测运行推理简约GUI界面

发布于:2025-09-12 ⋅ 阅读:(20) ⋅ 点赞:(0)

YOLO11推理简约GUI界面

使用方法:

支持pt和onnx格式模型,并且自动检测设备,选择推理设备

选择推理图片所在的文件夹
选择推理后的结果保存地址

选择所需要的置信度阈值

点击开始推理,程序自动运行 并在下方实时显示推理进度

非常方便不用每次都改代码来推理了

界面如下所示:

代码如下:

# -*- coding: utf-8 -*-
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
import os, sys, threading, subprocess
from pathlib import Path
import cv2
import torch
from ultralytics import YOLO


class App(tk.Tk):
    def __init__(self):
        super().__init__()
        self.title("YOLOv11 批量推理 (支持 .pt / .onnx)")
        self.geometry("540x480")
        self.resizable(False, False)

        # 设备信息显示
        self.device_info = self.get_device_info()
        tk.Label(self, text=f"检测到的设备: {self.device_info}", fg="blue").place(x=20, y=5)

        # ---- 权重文件 ----
        tk.Label(self, text="权重文件 (.pt / .onnx):").place(x=20, y=40)
        self.ent_w = tk.Entry(self, width=45)
        self.ent_w.place(x=180, y=40)
        tk.Button(self, text="浏览", command=lambda: self.browse_file(self.ent_w, [("模型文件", "*.pt *.onnx")])).place(
            x=460, y=36)

        # ---- 图片文件夹 ----
        tk.Label(self, text="图片文件夹:").place(x=20, y=80)
        self.ent_i = tk.Entry(self, width=45)
        self.ent_i.place(x=180, y=80)
        tk.Button(self, text="浏览", command=lambda: self.browse_directory(self.ent_i)).place(x=460, y=76)

        # ---- 输出文件夹 ----
        tk.Label(self, text="结果保存到:").place(x=20, y=120)
        self.ent_o = tk.Entry(self, width=45)
        self.ent_o.place(x=180, y=120)
        tk.Button(self, text="浏览", command=lambda: self.browse_directory(self.ent_o)).place(x=460, y=116)

        # ---- 置信度 ----
        tk.Label(self, text="置信度阈值:").place(x=20, y=160)
        self.scale_conf = tk.Scale(self, from_=0.01, to=1.0, resolution=0.01,
                                   orient=tk.HORIZONTAL, length=300)
        self.scale_conf.set(0.35)
        self.scale_conf.place(x=180, y=140)

        # ---- 设备选择 ----
        tk.Label(self, text="推理设备:").place(x=20, y=200)
        self.device_var = tk.StringVar(value="auto")
        devices = self.get_available_devices()
        self.device_combo = ttk.Combobox(self, textvariable=self.device_var, values=devices, width=15, state="readonly")
        self.device_combo.place(x=180, y=200)

        # ---- 复选框 ----
        self.var_empty = tk.BooleanVar(value=True)
        self.var_box = tk.BooleanVar(value=True)
        self.var_recursive = tk.BooleanVar(value=False)

        tk.Checkbutton(self, text="保存无目标的图片", variable=self.var_empty).place(x=20, y=240)
        tk.Checkbutton(self, text="在结果图片上画框", variable=self.var_box).place(x=220, y=240)
        tk.Checkbutton(self, text="递归子文件夹", variable=self.var_recursive).place(x=20, y=270)

        # ---- 运行按钮 / 进度条 ----
        self.btn_run = tk.Button(self, text="开始推理", width=15, command=self.run_thread)
        self.btn_run.place(x=20, y=310)
        self.pb = ttk.Progressbar(self, length=480, mode='determinate')
        self.pb.place(x=20, y=350)

        # ---- 日志 ----
        self.txt = tk.Text(self, height=6, width=70, state="disabled")
        self.txt.place(x=20, y=380)

    def get_device_info(self):
        """获取设备信息"""
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            gpu_name = torch.cuda.get_device_name(0)
            return f"GPU: {gpu_name} ({gpu_count}个)"
        else:
            return "CPU only"

    def get_available_devices(self):
        """获取可用设备列表"""
        devices = ["auto"]
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                devices.append(f"cuda:{i}")
        devices.append("cpu")
        return devices

    def browse_file(self, entry, filetypes):
        """浏览文件"""
        f = filedialog.askopenfilename(filetypes=filetypes)
        if f:
            entry.delete(0, tk.END)
            entry.insert(0, f)

    def browse_directory(self, entry):
        """浏览目录"""
        f = filedialog.askdirectory()
        if f:
            entry.delete(0, tk.END)
            entry.insert(0, f)

    def log(self, msg):
        """日志输出"""
        self.txt.configure(state="normal")
        self.txt.insert(tk.END, msg + "\n")
        self.txt.see(tk.END)
        self.txt.configure(state="disabled")
        self.update()

    # ---------- 推理 ----------
    def run_thread(self):
        """启动推理线程"""
        if not self.validate():
            return
        self.btn_run.config(state="disabled")
        self.pb["value"] = 0
        threading.Thread(target=self.infer, daemon=True).start()

    def validate(self):
        """验证输入"""
        for e in (self.ent_w, self.ent_i, self.ent_o):
            if not e.get():
                messagebox.showerror("提示", "请完整填写路径!")
                return False

        w_path = Path(self.ent_w.get())
        if not w_path.exists():
            messagebox.showerror("错误", "权重文件不存在!")
            return False
        if w_path.suffix.lower() not in ['.pt', '.onnx']:
            messagebox.showerror("错误", "只支持 .pt 或 .onnx 格式的权重文件!")
            return False

        i_path = Path(self.ent_i.get())
        if not i_path.exists():
            messagebox.showerror("错误", "图片文件夹不存在!")
            return False

        return True

    def infer(self):
        """执行推理"""
        try:
            # 获取设备设置
            device_choice = self.device_var.get()
            if device_choice == "auto":
                device = "0" if torch.cuda.is_available() else "cpu"
            else:
                device = device_choice

            w_path = self.ent_w.get()
            ext = Path(w_path).suffix.lower()

            self.log(f"正在加载模型,使用设备: {device}...")

            # 关键修改:初始化时不传递device参数[1,3,4](@ref)
            model = YOLO(w_path)

            # 对于PyTorch模型,使用.to()方法迁移到指定设备[6,7](@ref)
            if ext == '.pt' and device != 'cpu':
                model.to(device)
                self.log(f"PyTorch模型已迁移到设备: {device}")
            elif ext == '.onnx':
                self.log("注意: ONNX模型当前使用CPU推理,如需GPU加速请使用TensorRT转换")
                device = 'cpu'

            in_dir = Path(self.ent_i.get())
            out_dir = Path(self.ent_o.get())
            out_dir.mkdir(parents=True, exist_ok=True)

            # 收集图片文件
            pattern = '​**​/*' if self.var_recursive.get() else '*'
            img_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
            imgs = [p for p in in_dir.glob(pattern)
                    if p.suffix.lower() in img_extensions and p.is_file()]

            total = len(imgs)
            if total == 0:
                messagebox.showwarning("警告", "未找到任何图片!")
                return

            self.pb["maximum"] = total
            self.log(f"找到 {total} 张图片,开始推理...")

            # 批量推理
            for idx, img_p in enumerate(imgs, 1):
                rel_path = img_p.relative_to(in_dir) if in_dir in img_p.parents else Path()
                save_dir = out_dir / rel_path.parent
                save_dir.mkdir(parents=True, exist_ok=True)
                save_img = save_dir / f"{img_p.stem}_result.jpg"

                # 执行推理,在predict方法中指定设备[1,3,4](@ref)
                results = model.predict(
                    source=str(img_p),
                    conf=self.scale_conf.get(),
                    save=False,
                    verbose=False,
                    device=device  # 在这里指定设备
                )

                result = results[0]
                if len(result.boxes) == 0 and not self.var_empty.get():
                    continue

                # 处理结果图像
                img_out = result.plot() if self.var_box.get() else result.orig_img
                cv2.imwrite(str(save_img), img_out)

                # 更新进度
                self.pb["value"] = idx
                self.log(f"[{idx:03d}/{total:03d}] {img_p.name}")

            # 完成后打开结果文件夹
            subprocess.Popen(f'explorer "{out_dir}"')
            messagebox.showinfo("完成", f"推理完成!处理了 {total} 张图片。")

        except Exception as e:
            error_msg = f"推理错误: {str(e)}"
            self.log(error_msg)
            messagebox.showerror("错误", error_msg)
        finally:
            self.btn_run.config(state="normal")
            self.pb["value"] = 0


if __name__ == "__main__":
    # 在打包环境下调整路径
    if getattr(sys, 'frozen', False):
        os.chdir(os.path.dirname(sys.executable))
    app = App()
    app.mainloop()


网站公告

今日签到

点亮在社区的每一天
去签到