使用OpenCV训练自有模型的实践

发布于:2025-07-03 ⋅ 阅读:(20) ⋅ 点赞:(0)

技术博客:基于OpenCV与YOLOv8的自有模型训练实践

1. 项目背景与目标

在众多计算机视觉任务中,目标检测是一项核心且广泛应用于实际场景的技术。本项目旨在通过使用 OpenCV 和 YOLOv8 模型,实现对特定目标(如舌头)的检测。整个过程涵盖了数据准备、模型训练、验证评估及结果输出等多个环节。

2. 系统架构
2.1 数据集结构

数据集的组织结构直接影响到模型的训练效果。本项目的数据集结构如下:

dataset/
├── images/
│   ├── train/
│   │   ├── img1.jpg
│   │   └── img2.jpg
│   └── val/
│       ├── img1.jpg
│       └── img2.jpg
└── labels/
    ├── train/
    │   ├── img1.xml
    │   └── img2.xml
    └── val/
        ├── img1.txt
        └── img2.txt
  • images/:存放图像文件,分为 train(训练集)和 val(验证集)两个子目录。
  • labels/:存放对应的标注文件,格式为 YOLO 标注格式(.txt 文件)。
2.2 模型架构

本项目采用的是 YOLOv8 模型,该模型具有以下特点:

  • 高效性:YOLOv8 在保持高精度的同时,拥有较快的推理速度。
  • 灵活性:支持多种任务(如分类、检测、分割等),并且可以通过简单的配置进行定制。
  • 易用性:Ultralytics 提供了简洁的 API,方便用户快速上手。
3. 实现原理
3.1 YOLO 标注格式

YOLO 标注格式是一种简单且高效的标注方式,每个 .txt 文件对应一张图像,内容格式如下:

<class_id> <x_center> <y_center> <width> <height>
  • class_id:类别编号(从0开始)。
  • x_center, y_center:边界框中心点坐标(相对于图像宽度和高度的归一化值)。
  • width, height:边界框宽高(同样归一化)。
3.2 数据预处理

数据预处理是模型训练前的重要步骤,主要包括:

  • 图像读取与转换:使用 OpenCV 读取图像,并将其转换为适合模型输入的格式。
  • 标注文件解析:解析 YOLO 格式的标注文件,提取边界框和类别信息。
  • 数据增强:通过随机裁剪、翻转、颜色抖动等方法增加数据多样性,提升模型泛化能力。
3.3 模型训练

模型训练基于 YOLOv8 的标准流程,具体步骤如下:

  1. 加载模型:通过 YOLO('yolov8n.yaml') 或加载预训练模型来初始化模型。
  2. 配置数据集:编写 [custom_dataset.yaml]文件,指定训练集和验证集的路径。
  3. 启动训练:调用 model.train(data='custom_dataset.yaml') 方法开始训练。
train: ../images/train
val: ../images/val

nc: 1  # 类别数量,根据你的数据集修改
names: ['tongue']  # 类别名称列表,例如 ['car'], ['person'] 等

3.4 模型验证

模型训练完成后,需要对其进行验证以评估性能。主要步骤包括:

  • 加载模型:使用训练好的模型权重文件(如 [best.pt](file:///Users/franks/workspace/AI/ModelTrain/models/yolov8_custom_train/weights/best.pt))。
  • 验证评估:调用 model.val() 方法进行验证,并输出 mAP 等指标。
4. 技术实现细节
4.1 数据集准备

数据集的准备主要包括图像采集、标注及格式转换。具体实现可参考 [valImg.py]文件:

import os
import xml.etree.ElementTree as ET
import cv2

# 设置路径
ANNOTATION_DIR = 'data/labels/train'
IMAGE_DIR = 'data/images/train'
VAL_IMAGE_DIR = 'data/images/val'
VAL_LABEL_DIR = 'data/labels/val'

os.makedirs(VAL_IMAGE_DIR, exist_ok=True)
os.makedirs(VAL_LABEL_DIR, exist_ok=True)

def parse_voc_xml(xml_path):
    tree = ET.parse(xml_path)
    root = tree.getroot()

    filename = root.find('filename').text
    img_path = os.path.join(IMAGE_DIR, filename)

    # 支持常见图像格式(.jpg .png .jpeg .webp)
    for ext in ['.jpg', '.png', '.jpeg', '.webp']:
        if os.path.exists(img_path.replace('.webp', ext)):
            img_path = img_path.replace('.webp', ext)
            break

    if not os.path.exists(img_path):
        print(f"Image not found: {img_path}")
        return None

    size = root.find('size')
    img_w = int(size.find('width').text)
    img_h = int(size.find('height').text)

    objects = []
    for obj in root.findall('object'):
        name = obj.find('name').text
        bndbox = obj.find('bndbox')
        xmin = int(bndbox.find('xmin').text)
        ymin = int(bndbox.find('ymin').text)
        xmax = int(bndbox.find('xmax').text)
        ymax = int(bndbox.find('ymax').text)

        objects.append({
            'class': name,
            'bbox': [xmin, ymin, xmax, ymax]
        })

    return img_path, img_w, img_h, objects

# 处理所有 XML 文件
for xml_file in os.listdir(ANNOTATION_DIR):
    if not xml_file.endswith('.xml'):
        continue

    xml_path = os.path.join(ANNOTATION_DIR, xml_file)
    result = parse_voc_xml(xml_path)
    if not result:
        continue

    img_path, img_w, img_h, objects = result
    img = cv2.imread(img_path)

    for i, obj in enumerate(objects):
        cls_name = obj['class']
        xmin, ymin, xmax, ymax = obj['bbox']

        # 裁剪图像
        cropped = img[ymin:ymax, xmin:xmax]

        # 保存为 val 图像
        base_name = os.path.splitext(os.path.basename(img_path))[0]
        new_img_name = f"{base_name}_{i}.jpg"
        new_img_path = os.path.join(VAL_IMAGE_DIR, new_img_name)
        cv2.imwrite(new_img_path, cropped)

        # 保存标注文件(YOLO 格式)
        xc = (xmin + xmax) / 2 / img_w
        yc = (ymin + ymax) / 2 / img_h
        w = (xmax - xmin) / img_w
        h = (ymax - ymin) / img_h

        label_path = os.path.join(VAL_LABEL_DIR, new_img_name.replace('.jpg', '.txt'))
        with open(label_path, 'w') as f:
            class_id = 0  # 假设只有一类 "tongue"
            f.write(f"{class_id} {xc:.6f} {yc:.6f} {w:.6f} {h:.6f}\n")

print("✅ Val 图像及标注文件已生成!")

4.2 模型训练

模型训练的核心代码位于 train.py 文件中:

from ultralytics import YOLO

# 加载预训练模型(如 yolov8n.pt, yolov8s.pt 等)
model = YOLO('models/yolov8s.pt')

# 开始训练
results = model.train(
    data='data/custom_dataset.yaml',  # 自定义数据集配置文件路径
    epochs=100,                       # 总训练轮数
    imgsz=640,                        # 输入图像尺寸
    batch=16,                         # 批处理大小
    name='yolov8_custom_train',       # 训练结果保存的目录名
    project='models/',                # 模型保存的项目路径
    exist_ok=True                     # 是否允许覆盖已存在的训练结果目录
)

4.3 模型导出

模型训练完成后,可以将其导出为标准的 .pt 文件,便于部署和使用:

from ultralytics import YOLO

# 加载自定义训练模型
model = YOLO('models/yolov8_custom_train/weights/best.pt')

# 保存为标准 pt 文件(可选)
model.save('models/yolov8_model_tongue.pt')

5. 结果输出与分析
5.1 验证指标

模型验证的主要指标包括 mAP(Mean Average Precision):

  • mAP50:IoU 阈值为 0.5 时的平均精度。
  • mAP50-95:IoU 阈值在 [0.5, 0.95] 范围内的平均精度。

这些指标越高,表示模型的检测性能越好。通常,mAP50 > 0.7 即认为模型表现良好。

from ultralytics import YOLO

model = YOLO('models/yolov8_custom_train/weights/best.pt')  # 加载训练好的模型
metrics = model.val()  # 开始验证
print(metrics.box.map)    # mAP50
print(metrics.box.map50)  # mAP50-95  这些指标越高越好,通常 mAP50 > 0.7 表示模型表现良好。
5.2 可视化结果

为了更直观地展示模型的检测效果,可以使用 Ultralytics 提供的可视化工具:

from ultralytics import YOLO

model = YOLO('models/yolov8_custom_train/weights/best.pt')
model.val(data='data/custom_dataset.yaml')

这将显示每张图上的边界框,帮助我们验证标注是否准确以及模型的检测效果。

6. 总结与展望

本文介绍了基于 OpenCV 和 YOLOv8 的自有模型训练实践,从数据集准备、模型训练、验证评估到结果输出,提供了一套完整的解决方案。未来的工作可以进一步优化数据增强策略、调整模型超参数,以提升模型的检测精度和鲁棒性。

希望本文能为读者在目标检测领域的研究和应用提供有价值的参考。如果你有任何问题或建议,欢迎留言交流!