数据集划分与格式转换:从原始数据到模型训练的关键步骤

发布于:2025-05-21 ⋅ 阅读:(20) ⋅ 点赞:(0)

在计算机视觉项目中,数据集的合理划分和格式转换是实现高效模型训练的基础。本文将详细介绍如何将图片和标注数据按比例切分为训练集和测试集,以及常见的数据格式转换方法,包括 JSON 转 YOLO 格式和 XML 转 TXT 格式。

一、将图片和标注数据按比例切分为训练集和测试集

# 将图片和标注数据按比例切分为 训练集和测试集
import shutil
import random
import os
 
# 原始路径
image_original_path = r"D:\adavance\rm\position\datasets\myseg/images/"
label_original_path = r"D:\adavance\rm\position\datasets\myseg/txt/"
 
cur_path = os.getcwd()
# 训练集路径
train_image_path = os.path.join(cur_path, "datasets/images/train/")
train_label_path = os.path.join(cur_path, "datasets/labels/train/")
 
# 验证集路径
val_image_path = os.path.join(cur_path, "datasets/images/val/")
val_label_path = os.path.join(cur_path, "datasets/labels/val/")
 
 
# 训练集目录
list_train = os.path.join(cur_path, "datasets/train.txt")
list_val = os.path.join(cur_path, "datasets/val.txt")

 
train_percent = 0.9
val_percent = 0.1

 
 
def del_file(path):
    for i in os.listdir(path):
        file_data = path + "\\" + i
        os.remove(file_data)
 
 
def mkdir():
    if not os.path.exists(train_image_path):
        os.makedirs(train_image_path)
    else:
        del_file(train_image_path)
    if not os.path.exists(train_label_path):
        os.makedirs(train_label_path)
    else:
        del_file(train_label_path)
 
    if not os.path.exists(val_image_path):
        os.makedirs(val_image_path)
    else:
        del_file(val_image_path)
    if not os.path.exists(val_label_path):
        os.makedirs(val_label_path)
    else:
        del_file(val_label_path)

 
def clearfile():
    if os.path.exists(list_train):
        os.remove(list_train)
    if os.path.exists(list_val):
        os.remove(list_val)

 
 
def main():
    mkdir()
    clearfile()
 
    file_train = open(list_train, 'w')
    file_val = open(list_val, 'w')

 
    total_txt = os.listdir(label_original_path)
    num_txt = len(total_txt)
    list_all_txt = range(num_txt)
 
    num_train = int(num_txt * train_percent)
    num_val = int(num_txt * val_percent)
    num_test = num_txt - num_train - num_val
 
    train = random.sample(list_all_txt, num_train)
    # train从list_all_txt取出num_train个元素
    # 所以list_all_txt列表只剩下了这些元素
    val_test = [i for i in list_all_txt if not i in train]
    # 再从val_test取出num_val个元素,val_test剩下的元素就是test
    val = random.sample(val_test, num_val)
 
    print("训练集数目:{}, 验证集数目:{}, 测试集数目:{}".format(len(train), len(val), len(val_test) - len(val)))
    for i in list_all_txt:
        name = total_txt[i][:-4]
 
        srcImage = image_original_path + name + '.jpg'
        srcLabel = label_original_path + name + ".txt"
 
        if i in train:
            dst_train_Image = train_image_path + name + '.jpg'
            dst_train_Label = train_label_path + name + '.txt'
            shutil.copyfile(srcImage, dst_train_Image)
            shutil.copyfile(srcLabel, dst_train_Label)
            file_train.write(dst_train_Image + '\n')
        elif i in val:
            dst_val_Image = val_image_path + name + '.jpg'
            dst_val_Label = val_label_path + name + '.txt'
            shutil.copyfile(srcImage, dst_val_Image)
            shutil.copyfile(srcLabel, dst_label)
            file_val.write(dst_val_Image + '\n')
        
    file_train.close()
    file_val.close()

if __name__ == "__main__":
    main()

代码说明

  1. 路径设置 :定义了原始图片路径(image_original_path)、原始标注路径(label_original_path)、训练集路径(train_image_pathtrain_label_path)、验证集路径(val_image_pathval_label_path)以及训练集和验证集的目录文件路径(list_trainlist_val)。
  2. 数据集划分比例 :设置了训练集占总数据集的比例为 train_percent(0.9),验证集占总数据集的比例为 val_percent(0.1)。
  3. 清理和创建目录del_file 函数用于删除指定目录下的文件,mkdir 函数用于创建训练集和验证集的目录,并在目录存在时清理其中的文件。
  4. 主函数main 函数首先调用 mkdirclearfile 函数创建和清理目录,打开训练集和验证集的目录文件。然后获取原始标注文件列表,计算数据集的大小和训练集、验证集的大小。使用 random.sample 函数随机选择训练集和验证集的索引。最后,根据索引将图片和标注文件复制到相应的训练集或验证集目录,并将训练集和验证集的文件路径写入对应的目录文件中。

二、JSON 转 YOLO 格式

import cv2  
import os  
import json  
import glob  
import numpy as np  
  
def convert_json_label_to_yolov_seg_label():  
    json_path = r"D:\adavance\rm\position\datasets\myseg\image"  # 本地 json 路径
    json_files = glob.glob(json_path + "/*.json")  
    print(json_files)  
  
    # 指定输出文件夹  
    output_folder = r"D:\adavance\rm\position\datasets\myseg\txt"  # txt 存放路径
    if not os.path.exists(output_folder):  
        os.makedirs(output_folder)  
    
    # 类别映射表,将类别名称映射为整数标签
    label_map = {
        "wood": 0,  # 示例类别1
        "head": 1,  # 示例类别2
        "nut": 2,  # 示例类别2
        # 添加更多类别映射
    }
  
    for json_file in json_files:  
        print(json_file)  
        with open(json_file, 'r') as f:  
            json_info = json.load(f)  
  
        img = cv2.imread(os.path.join(json_path, os.path.basename(json_file).replace(".json", ".jpg")))  
        height, width, _ = img.shape  
        np_w_h = np.array([[width, height]], np.int32)  
  
          
        txt_file = os.path.join(output_folder, os.path.basename(json_file).replace(".json", ".txt"))  
  
        with open(txt_file, "w") as f:
            for point_json in json_info["shapes"]:
                # 获取类别名称
                label_name = point_json["label"]
                # 获取对应的整数标签
                label_index = label_map.get(label_name)
                if label_index is None:
                    raise ValueError(f"未知类别: {label_name}")

                txt_content = ""
                np_points = np.array(point_json["points"], np.int32)
                norm_points = np_points / np_w_h
                norm_points_list = norm_points.tolist()
                txt_content += f"{label_index} " + " ".join([" ".join([str(cell[0]), str(cell[1])]) for cell in norm_points_list]) + "\n"
                f.write(txt_content)
  
convert_json_label_to_yolov_seg_label()

代码说明

  1. 路径设置 :定义了 JSON 文件的路径(json_path)和输出的 TXT 文件夹路径(output_folder)。
  2. 类别映射表 :建立了类别名称到整数标签的映射关系(label_map)。
  3. 主函数convert_json_label_to_yolov_seg_label 函数遍历 JSON 文件,加载 JSON 数据,读取对应的图片以获取图片尺寸。然后,对每个 JSON 文件中的标注信息进行处理,将类别名称转换为整数标签,将标注点坐标进行归一化处理,并将结果写入对应的 TXT 文件中。

三、XML 转 TXT 格式

import xml.etree.ElementTree as ET
 
import pickle
import os
from os import listdir , getcwd
from os.path import join
import glob
 
classes = ["bolt", "bolt_lost"]    # xml 文件中标记的种类

def convert(size, box):
    dw = 1.0/size[0]
    dh = 1.0/size[1]
    x = (box[0]+box[1])/2.0
    y = (box[2]+box[3])/2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x*dw
    w = w*dw
    y = y*dh
    h = h*dh
    return (x,y,w,h)

def convert_annotation(image_name):
    try:
        in_file = open(r'D:\adavance\bts\ssqz03/'+image_name[:-3]+'xml', encoding='utf-8')    # 原来的 xml 文件路径
    except FileNotFoundError:
        print(f"Warning: XML file for {image_name} not found, skipping.")
        return
    out_file = open(r'D:\adavance\bts\ssqz03/'+image_name[:-3]+'txt', 'w', encoding='utf-8')  # 转换后的 txt 文件存放路径
    tree=ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        bb = convert((w,h), b)
        out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
    in_file.close()   
    out_file.close()  

wd = getcwd()
 
if __name__ == '__main__':
 
    for image_path in glob.glob(r"D:\adavance\bts\ssqz03/*.jpg"):  # xml 对应的图片的路径
        image_name = image_path.split('\\')[-1]
        convert_annotation(image_name)

代码说明

  1. 类别定义 :定义了 XML 文件中标记的种类(classes)。
  2. 坐标转换函数convert 函数用于将 XML 文件中的边界框坐标转换为 YOLO 格式所需的归一化坐标。
  3. 转换标注函数convert_annotation 函数读取 XML 文件,解析其中的标注信息,调用 convert 函数进行坐标转换,并将结果写入对应的 TXT 文件中。
  4. 主函数 :遍历图片路径,获取图片名称,并调用 convert_annotation 函数进行标注转换。

通过以上三种代码实现,我们可以完成数据集的划分以及不同格式之间的转换,为模型训练做好数据准备。这些步骤在实际的计算机视觉项目中具有重要的应用价值,能够提高数据处理的效率和准确性,进而提升模型的性能。

在实际应用中,可以根据具体的数据集和项目需求对代码进行适当的修改和优化。例如,可以调整数据集的划分比例、增加更多的类别映射、处理不同格式的标注文件等。同时,也要注意数据的一致性和完整性,确保转换后的数据能够正确地用于模型训练和评估。

希望本文对大家在数据集处理方面有所帮助。如果你有任何问题或建议,欢迎在评论区留言交流。


网站公告

今日签到

点亮在社区的每一天
去签到