RTDETRv2 pytorch 官方版自己数据集训练遇到的问题解决

发布于:2025-06-16 ⋅ 阅读:(22) ⋅ 点赞:(0)

记录用 RT-DETRv2(PyTorch 官方版)训练自己的数据集时遇到的问题及解决方法。

pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2  --index-url https://download.pytorch.org/whl/cu117

1

Please make sure torchvision version >= 0.15.2

发现自己实际安装的是 torchvision==0.15.2+cu117,版本号带有 +cu117 后缀。

在 _misc.py 中把版本判断改为实际安装的版本号:

if importlib.metadata.version('torchvision') == '0.15.2+cu117':

2

Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"

报错很多行。

原因是标注的 COCO 类别序号不对:使用自己的数据集时,类别 id 需要从 0 开始。

COCO 标注文件的格式如下。“info”字段也必须有,否则同样会报错。

{
  "info": {
    "description": "COCO Dataset"
  },
  "licenses": [
    {
      "name": ""
    }
  ],
  "images": [
    {
      "id": 1,
      "file_name": "00002.png",
      "height": 1080,
      "width": 1920
    },
    {
      "id": 2,
      "file_name": "00009.png",
      "height": 1080,
      "width": 1920
    }
  ],
  "annotations": [
    {
      "id": 1,
      "image_id": 1,
      "category_id": 0,
      "segmentation": [
        [
          642.6923076923077,
          234.23076923076925,
          1377.3076923076924,
          234.23076923076925,
          1377.3076923076924,
          782.3076923076923,
          642.6923076923077,
          782.3076923076923
        ]
      ],
      "area": 402625.7396449703,
      "bbox": [
        642.6923076923077,
        234.23076923076925,
        734.6153846153846,
        548.076923076923
      ],
      "iscrowd": 0
    },
    {
      "id": 2,
      "image_id": 2,
      "category_id": 1,
      "segmentation": [
        [
          490.76923076923083,
          222.6923076923077,
          1252.3076923076924,
          222.6923076923077,
          1252.3076923076924,
          784.2307692307692,
          490.76923076923083,
          784.2307692307692
        ]
      ],
      "area": 427633.1360946745,
      "bbox": [
        490.76923076923083,
        222.6923076923077,
        761.5384615384615,
        561.5384615384614
      ],
      "iscrowd": 0
    }
  ],
  "categories": [
    {
      "id": 0,
      "name": "ng",
      "supercategory": ""
    },
    {
      "id": 1,
      "name": "ok",
      "supercategory": ""
    }
  ]
}

完整的转换脚本,把 XML 标注转换为 COCO 格式:

import os
import json
import xml.etree.ElementTree as ET
from collections import defaultdict
from tqdm import tqdm
import argparse
import shutil
import cv2
import numpy as np

def parse_args():
    """Parse the converter's command-line options.

    Three path options are mandatory (``--xml_dir``, ``--img_dir``,
    ``--output_json``). ``--copy_images`` is a boolean switch and
    ``--output_img_dir`` defaults to ``coco_dataset``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description='Convert Pascal VOC XML annotations to COCO format')

    # The mandatory options share the same shape, so register them from a table.
    required_paths = (
        ('--xml_dir', 'Directory containing XML annotation files'),
        ('--img_dir', 'Directory containing corresponding images'),
        ('--output_json', 'Output COCO format JSON file path'),
    )
    for flag, description in required_paths:
        cli.add_argument(flag, type=str, required=True, help=description)

    cli.add_argument('--copy_images', action='store_true',
                     help='Copy images to a new directory structure')
    cli.add_argument('--output_img_dir', type=str, default='coco_dataset',
                     help='Output directory for images if copying is enabled')

    return cli.parse_args()

def get_image_size(image_path):
    """Load an image with OpenCV and report its dimensions.

    Args:
        image_path: Path of the image file handed to ``cv2.imread``.

    Returns:
        A ``(width, height)`` tuple. Any failure (including ``cv2.imread``
        yielding ``None``) is printed and reported as ``(0, 0)`` instead of
        being raised.
    """
    try:
        pixels = cv2.imread(image_path)
        if pixels is None:
            raise IOError(f"无法读取图像: {image_path}")
        # Array shape is (rows, cols, ...), i.e. height comes first.
        rows, cols = pixels.shape[0], pixels.shape[1]
    except Exception as e:
        print(f"错误: {e}")
        return 0, 0
    else:
        return cols, rows

def _read_declared_size(root):
    """Return the (width, height) declared in the annotation's <size> element.

    Returns (0, 0) when the element, one of its children, or a parseable
    integer value is missing, so the caller can fall back to the image file.
    """
    size = root.find('size')
    if size is None:
        return 0, 0
    try:
        return int(size.findtext('width')), int(size.findtext('height'))
    except (TypeError, ValueError):
        # findtext() gave None (child missing) or the text is not an integer.
        return 0, 0


def convert_xml_to_coco(xml_dir, img_dir, output_json, copy_images=False,
                        output_img_dir=None, class_names=None):
    """Convert a directory of Pascal VOC XML annotations into one COCO JSON file.

    Every ``*.xml`` file in ``xml_dir`` is parsed. Files whose image is missing
    from ``img_dir`` are skipped with a warning, and a file that raises while
    being processed is reported and skipped so one bad annotation does not
    abort the whole run.

    Category ids start at 0 and are assigned in first-seen order; image and
    annotation ids start at 1. XML files are visited in sorted name order so
    that repeated runs over the same directory yield the same ids.

    Besides ``output_json``, a ``category_mapping.txt`` (``name: id`` per line)
    is written into the same directory.

    Args:
        xml_dir: Directory containing the XML annotation files.
        img_dir: Directory containing the images named by each XML's
            ``<filename>``.
        output_json: Path of the COCO JSON file to write. Its parent
            directory is created if it does not exist.
        copy_images: When true (and ``output_img_dir`` is set), copy every
            converted image into ``output_img_dir``.
        output_img_dir: Destination directory for copied images.
        class_names: Optional iterable of class names to register up front, in
            order, so they receive ids 0, 1, 2, ... regardless of the order in
            which they appear in the annotations. Pass the same list when
            converting train and val splits to keep their ids identical.
            Names not listed are still appended in first-seen order.

    Returns:
        The COCO dictionary that was written to ``output_json``.
    """
    coco_data = {
        "info": {
            "description": "COCO Dataset converted from Pascal VOC XML",
            "version": "1.0",
            "year": 2023,
            "contributor": "XML to COCO Converter",
            "date_created": "2023-01-01"
        },
        "licenses": [{
            "url": "https://creativecommons.org/licenses/by/4.0/",
            "id": 1,
            "name": "CC BY 4.0"
        }],
        "images": [],
        "annotations": [],
        "categories": []
    }

    # class name -> category id (0-based, contiguous)
    category_dict = {}

    def category_id_for(name):
        """Return the id for *name*, registering a new category on first use."""
        if name not in category_dict:
            category_dict[name] = len(category_dict)
            coco_data["categories"].append({
                "id": category_dict[name],
                "name": name,
                "supercategory": "object"
            })
        return category_dict[name]

    # Pre-register caller-supplied classes so their ids are fixed.
    for name in class_names or ():
        category_id_for(name)

    # image file name -> image id (1-based)
    image_dict = {}
    next_ann_id = 1

    # os.listdir() returns entries in arbitrary order; sort for reproducible ids.
    xml_files = sorted(f for f in os.listdir(xml_dir) if f.endswith('.xml'))

    if copy_images and output_img_dir:
        os.makedirs(output_img_dir, exist_ok=True)

    print(f"找到 {len(xml_files)} 个XML文件,开始转换...")

    for xml_file in tqdm(xml_files):
        xml_path = os.path.join(xml_dir, xml_file)

        try:
            root = ET.parse(xml_path).getroot()

            filename = root.find('filename').text
            img_path = os.path.join(img_dir, filename)

            if not os.path.exists(img_path):
                print(f"警告: 图像文件不存在 - {img_path}")
                continue

            # Prefer the size declared in the XML; if it is absent or not
            # positive, fall back to reading the image itself.
            width, height = _read_declared_size(root)
            if width <= 0 or height <= 0:
                width, height = get_image_size(img_path)
                if width == 0 or height == 0:
                    print(f"警告: 无法获取图像尺寸 - {img_path}")
                    continue

            if copy_images and output_img_dir:
                shutil.copy2(img_path, os.path.join(output_img_dir, filename))

            # Several XML files may name the same image; register it only once.
            if filename not in image_dict:
                image_dict[filename] = len(image_dict) + 1
                coco_data["images"].append({
                    "id": image_dict[filename],
                    "file_name": filename,
                    "width": width,
                    "height": height,
                    "license": 1,
                    "date_captured": "2023-01-01"
                })

            image_id = image_dict[filename]

            for obj in root.findall('object'):
                category_id = category_id_for(obj.find('name').text)

                bbox = obj.find('bndbox')
                if bbox is None:
                    continue

                xmin = float(bbox.find('xmin').text)
                ymin = float(bbox.find('ymin').text)
                xmax = float(bbox.find('xmax').text)
                ymax = float(bbox.find('ymax').text)

                # VOC stores corner coordinates; COCO wants [x, y, width, height].
                bbox_width = xmax - xmin
                bbox_height = ymax - ymin

                coco_data["annotations"].append({
                    "id": next_ann_id,
                    "image_id": image_id,
                    "category_id": category_id,
                    "bbox": [xmin, ymin, bbox_width, bbox_height],
                    "area": bbox_width * bbox_height,
                    "segmentation": [],
                    "iscrowd": 0
                })
                next_ann_id += 1

        except Exception as e:
            # Best effort: report the broken file and carry on with the rest.
            print(f"处理文件 {xml_file} 时出错: {str(e)}")

    # Make sure the destination directory exists before writing anything;
    # otherwise the run fails only after all annotations have been parsed.
    output_dir = os.path.dirname(output_json)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # json.dump keeps its default ASCII escaping, so the file stays readable
    # by loaders that open it with the platform's default encoding.
    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(coco_data, f, indent=2)

    print("转换完成!")
    print(f"共处理 {len(coco_data['images'])} 张图像")
    print(f"共处理 {len(coco_data['annotations'])} 个标注")
    print(f"共发现 {len(coco_data['categories'])} 个类别")
    print(f"结果已保存到: {output_json}")

    # Explicit UTF-8 so non-ASCII class names cannot fail on a narrow locale.
    category_map_path = os.path.join(output_dir, 'category_mapping.txt')
    with open(category_map_path, 'w', encoding='utf-8') as f:
        for name, cid in category_dict.items():
            f.write(f"{name}: {cid}\n")
    print(f"类别映射已保存到: {category_map_path}")

    return coco_data

if __name__ == "__main__":
    # Script entry point: read the CLI options and run one conversion.
    cli_options = parse_args()

    coco_data = convert_xml_to_coco(
        xml_dir=cli_options.xml_dir,
        img_dir=cli_options.img_dir,
        output_json=cli_options.output_json,
        copy_images=cli_options.copy_images,
        output_img_dir=cli_options.output_img_dir,
    )

调用:生成coco的json

python xml_to_coco.py    --xml_dir  train2017   --img_dir  train2017   --output_json  annotations/instances_train2017.json

python xml_to_coco.py    --xml_dir  val2017   --img_dir  val2017   --output_json  annotations/instances_val2017.json

数据集结构图:

然后训练:

python tools/train.py  --config=configs/rtdetrv2/rtdetrv2_r18vd_120e_coco.yml   --use-amp --seed=0

转换onnx

python tools/export_onnx.py -c=configs/rtdetrv2/rtdetrv2_r18vd_120e_coco.yml -r last.pth --check

转换trt,python 版本

python tools/export_trt.py -i model.onnx

如果已经安装了 TensorRT,也可以直接用命令行转换。

tensorrt 版本要大于8.5.2,不然有的算子不支持,会报错。

trtexec --onnx=model.onnx --saveEngine=model.trt

上面 python 文件夹里的 whl 文件可以直接用 pip 安装:pip install tensorrt-8.6.0-cp39-none-win_amd64.whl

这样可以安装 TensorRT 的 Python 版本,适用于直接用 pip 安装装不上的情况。

生成的权重还是挺大的,个人感觉没有yolo好用。


网站公告

今日签到

点亮在社区的每一天
去签到