augment.py
ultralytics\data\augment.py
目录
5.class CutMix(BaseMixTransform):
6.class CopyPaste(BaseMixTransform):
7.def v8_transforms(dataset, imgsz, hyp, stretch=False):
1.所需的库和模块
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import math
import random
from copy import deepcopy
from typing import List, Tuple, Union
import cv2
import numpy as np
import torch
from PIL import Image
from torch.nn import functional as F
from ultralytics.data.utils import polygons2masks, polygons2masks_overlap
from ultralytics.utils import LOGGER, colorstr
from ultralytics.utils.checks import check_version
from ultralytics.utils.instance import Instances
from ultralytics.utils.metrics import bbox_ioa
from ultralytics.utils.ops import segment2box, xywh2xyxy, xyxyxyxy2xywhr
from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13
DEFAULT_MEAN = (0.0, 0.0, 0.0)
DEFAULT_STD = (1.0, 1.0, 1.0)
2.class BaseTransform:
# 这段代码定义了一个名为 BaseTransform 的类,它是一个基础的变换类,用于对图像及其相关标签(如实例分割和语义分割)进行一系列的变换操作。
# 定义了一个名为 BaseTransform 的类。这个类可以作为其他具体变换类的基类,提供了一个通用的框架,用于对图像及其相关标签进行变换。
class BaseTransform:
# Ultralytics 库中图像转换的基类。
# 此类是实现各种图像处理操作的基础,旨在兼容分类和语义分割任务。
# 方法:
# apply_image:将图像转换应用于标签。
# apply_instances:将转换应用于标签中的对象实例。
# apply_semantic:将语义分割应用于图像。
# __call__:将所有标签转换应用于图像、实例和语义蒙版。
"""
Base class for image transformations in the Ultralytics library.
This class serves as a foundation for implementing various image processing operations, designed to be
compatible with both classification and semantic segmentation tasks.
Methods:
apply_image: Apply image transformations to labels.
apply_instances: Apply transformations to object instances in labels.
apply_semantic: Apply semantic segmentation to an image.
__call__: Apply all label transformations to an image, instances, and semantic masks.
Examples:
>>> transform = BaseTransform()
>>> labels = {"image": np.array(...), "instances": [...], "semantic": np.array(...)}
>>> transformed_labels = transform(labels)
"""
# 定义了类的初始化方法 __init__ 。这个方法在创建类的实例时被调用。
# -> None 表示这个方法没有返回值。
# 这里 pass 表示初始化方法目前什么也不做。在实际使用中,可以根据需要在这个方法中初始化一些属性或参数。
def __init__(self) -> None:
# 初始化 BaseTransform 对象。
# 此构造函数设置基础转换对象,该对象可针对特定图像处理任务进行扩展。它旨在兼容分类和语义分割。
"""
Initialize the BaseTransform object.
This constructor sets up the base transformation object, which can be extended for specific image
processing tasks. It is designed to be compatible with both classification and semantic segmentation.
Examples:
>>> transform = BaseTransform()
"""
pass
# 定义了一个名为 apply_image 的方法,用于对图像进行变换。
# 参数 1.labels 可能是一个包含图像数据及其标签的结构,但具体含义需要根据实际使用场景来确定。
# 这里 pass 表示这个方法目前什么也不做。在实际使用中,可以根据需要在这个方法中实现具体的图像变换逻辑,如裁剪、缩放、旋转等。
def apply_image(self, labels):
# 将图像转换应用于标签。
# 此方法旨在被子类重写,以实现特定的图像转换逻辑。在其基本形式中,它返回未更改的输入标签。
"""
Apply image transformations to labels.
This method is intended to be overridden by subclasses to implement specific image transformation
logic. In its base form, it returns the input labels unchanged.
Args:
labels (Any): The input labels to be transformed. The exact type and structure of labels may
vary depending on the specific implementation.
Returns:
(Any): The transformed labels. In the base implementation, this is identical to the input.
Examples:
>>> transform = BaseTransform()
>>> original_labels = [1, 2, 3]
>>> transformed_labels = transform.apply_image(original_labels)
>>> print(transformed_labels)
[1, 2, 3]
"""
pass
# 定义了一个名为 apply_instances 的方法,用于对实例分割标签进行变换。
# 参数 1.labels 可能是一个包含实例分割标签的数据结构,但具体含义需要根据实际使用场景来确定。
# 这里 pass 表示这个方法目前什么也不做。在实际使用中,可以根据需要在这个方法中实现具体的实例分割标签变换逻辑,如调整边界框、掩码等。
def apply_instances(self, labels):
# 对标签中的对象实例应用变换。
# 此方法负责对给定标签中的对象实例应用各种变换。它旨在被子类重写,以实现特定的实例变换逻辑。
"""
Apply transformations to object instances in labels.
This method is responsible for applying various transformations to object instances within the given
labels. It is designed to be overridden by subclasses to implement specific instance transformation
logic.
Args:
labels (dict): A dictionary containing label information, including object instances.
Returns:
(dict): The modified labels dictionary with transformed object instances.
Examples:
>>> transform = BaseTransform()
>>> labels = {"instances": Instances(xyxy=torch.rand(5, 4), cls=torch.randint(0, 80, (5,)))}
>>> transformed_labels = transform.apply_instances(labels)
"""
pass
# 定义了一个名为 apply_semantic 的方法,用于对语义分割标签进行变换。
# 参数 1.labels 可能是一个包含语义分割标签的数据结构,但具体含义需要根据实际使用场景来确定。
# 这里 pass 表示这个方法目前什么也不做。在实际使用中,可以根据需要在这个方法中实现具体的语义分割标签变换逻辑,如调整像素标签等。
def apply_semantic(self, labels):
# 对图像应用语义分割转换。
# 此方法旨在被子类重写,以实现特定的语义分割转换。其基本形式不执行任何操作。
"""
Apply semantic segmentation transformations to an image.
This method is intended to be overridden by subclasses to implement specific semantic segmentation
transformations. In its base form, it does not perform any operations.
Args:
labels (Any): The input labels or semantic segmentation mask to be transformed.
Returns:
(Any): The transformed semantic segmentation mask or labels.
Examples:
>>> transform = BaseTransform()
>>> semantic_mask = np.zeros((100, 100), dtype=np.uint8)
>>> transformed_mask = transform.apply_semantic(semantic_mask)
"""
pass
# 定义了一个名为 __call__ 的方法,使得类的实例可以像函数一样被调用。
# 参数 1.labels 是一个包含图像及其标签的数据结构。
# 在这个方法中,依次调用了 apply_image 、 apply_instances 和 apply_semantic 方法,对图像及其标签进行一系列的变换操作。
def __call__(self, labels):
# 将所有标签转换应用于图像、实例和语义蒙版。
# 此方法协调将 BaseTransform 类中定义的各种转换应用于输入标签。它依次调用 apply_image 和 apply_instances 方法分别处理图像和对象实例。
# 参数:
# labels (dict):包含图像数据和注释的字典。预期键包括“img”(表示图像数据)和“instances”(表示对象实例)。
# 返回:
# (dict):包含转换后图像和实例的输入标签字典。
"""
Apply all label transformations to an image, instances, and semantic masks.
This method orchestrates the application of various transformations defined in the BaseTransform class
to the input labels. It sequentially calls the apply_image and apply_instances methods to process the
image and object instances, respectively.
Args:
labels (dict): A dictionary containing image data and annotations. Expected keys include 'img' for
the image data, and 'instances' for object instances.
Returns:
(dict): The input labels dictionary with transformed image and instances.
Examples:
>>> transform = BaseTransform()
>>> labels = {"img": np.random.rand(640, 640, 3), "instances": []}
>>> transformed_labels = transform(labels)
"""
# 这三行代码分别调用了 apply_image 、 apply_instances 和 apply_semantic 方法,对输入的 labels 进行相应的变换操作。
# 这种设计使得 BaseTransform 类可以作为一个通用的变换框架,通过在子类中实现具体的变换逻辑,来实现不同的变换操作。
self.apply_image(labels)
self.apply_instances(labels)
self.apply_semantic(labels)
# 这段代码定义了一个名为 BaseTransform 的基础变换类,它提供了一个通用的框架,用于对图像及其相关标签(如实例分割和语义分割)进行一系列的变换操作。通过定义 apply_image 、 apply_instances 和 apply_semantic 方法,类可以分别对图像、实例分割标签和语义分割标签进行变换。而 __call__ 方法则使得类的实例可以像函数一样被调用,并依次执行这些变换操作。这种设计使得 BaseTransform 类可以方便地被扩展和复用,适用于各种图像处理和计算机视觉任务。
3.class Compose:
# 这段代码定义了一个名为 Compose 的类,用于将多个变换操作组合起来,形成一个可调用的变换序列。
# 定义了一个名为 Compose 的类,用于组合多个变换操作。
class Compose:
# 用于组合多个图像变换的类。
# 属性:
# transforms (List[Callable]):按顺序应用的变换函数列表。
# 方法:
# __call__:对输入数据应用一系列变换。
# append:将新的变换附加到现有的变换列表中。
# insert:在变换列表中的指定索引处插入新的变换。
# __getitem__:使用索引检索特定变换或一组变换。
# __setitem__:使用索引设置特定变换或一组变换。
# tolist:将变换列表转换为标准 Python 列表。
"""
A class for composing multiple image transformations.
Attributes:
transforms (List[Callable]): A list of transformation functions to be applied sequentially.
Methods:
__call__: Apply a series of transformations to input data.
append: Append a new transform to the existing list of transforms.
insert: Insert a new transform at a specified index in the list of transforms.
__getitem__: Retrieve a specific transform or a set of transforms using indexing.
__setitem__: Set a specific transform or a set of transforms using indexing.
tolist: Convert the list of transforms to a standard Python list.
Examples:
>>> transforms = [RandomFlip(), RandomPerspective(30)]
>>> compose = Compose(transforms)
>>> transformed_data = compose(data)
>>> compose.append(CenterCrop((224, 224)))
>>> compose.insert(0, RandomFlip())
"""
# 定义了类的初始化方法 __init__ ,接收一个参数 1.transforms ,它是一个变换操作的列表或单个变换操作。 self.transforms 用于存储变换操作的列表。
def __init__(self, transforms):
# 使用转换列表初始化 Compose 对象。
"""
Initialize the Compose object with a list of transforms.
Args:
transforms (List[Callable]): A list of callable transform objects to be applied sequentially.
Examples:
>>> from ultralytics.data.augment import Compose, RandomHSV, RandomFlip
>>> transforms = [RandomHSV(), RandomFlip()]
>>> compose = Compose(transforms)
"""
# 如果 transforms 是一个列表,则直接赋值给 self.transforms 。
# 如果 transforms 不是列表,则将其封装成一个列表,确保 self.transforms 始终是一个列表。
self.transforms = transforms if isinstance(transforms, list) else [transforms]
# 定义了类的 __call__ 方法,使得类的实例可以像函数一样被调用。 参数 1.data 是需要被变换的数据。
def __call__(self, data):
# 对输入数据应用一系列转换。
# 此方法按顺序将 Compose 对象转换中的每个转换应用于输入数据。
"""
Apply a series of transformations to input data.
This method sequentially applies each transformation in the Compose object's transforms to the input data.
Args:
data (Any): The input data to be transformed. This can be of any type, depending on the
transformations in the list.
Returns:
(Any): The transformed data after applying all transformations in sequence.
Examples:
>>> transforms = [Transform1(), Transform2(), Transform3()]
>>> compose = Compose(transforms)
>>> transformed_data = compose(input_data)
"""
# 遍历 self.transforms 中的每个变换操作 t ,并将 data 依次传递给每个变换操作。
# 最终返回经过所有变换操作后的 data 。
for t in self.transforms:
data = t(data)
return data
# 定义了一个 append 方法,用于向变换操作列表中添加一个新的变换操作。 参数 1.transform 是一个变换操作。
def append(self, transform):
# 将新的变换附加到现有的变换列表中。
"""
Append a new transform to the existing list of transforms.
Args:
transform (BaseTransform): The transformation to be added to the composition.
Examples:
>>> compose = Compose([RandomFlip(), RandomPerspective()])
>>> compose.append(RandomHSV())
"""
# 使用列表的 append 方法将新的变换操作添加到 self.transforms 中。
self.transforms.append(transform)
# 定义了一个 insert 方法,用于在指定位置插入一个新的变换操作。 参数 1.index 是插入的位置, transform 是新的变换操作。
def insert(self, index, transform):
# 在现有变换列表中的指定索引处插入新的变换。
"""
Insert a new transform at a specified index in the existing list of transforms.
Args:
index (int): The index at which to insert the new transform.
transform (BaseTransform): The transform object to be inserted.
Examples:
>>> compose = Compose([Transform1(), Transform2()])
>>> compose.insert(1, Transform3())
>>> len(compose.transforms)
3
"""
# 使用列表的 insert 方法将新的变换操作插入到指定位置。
self.transforms.insert(index, transform)
# 定义了 __getitem__ 方法,用于通过索引获取子变换序列。 参数 1.index 可以是一个整数或整数列表,表示需要获取的变换操作的索引。
def __getitem__(self, index: Union[list, int]) -> "Compose":
# 使用索引检索特定变换或一组变换。
# 参数:
# index (int | List[int]):要检索的变换的索引或索引列表。
# 返回:
# (Compose):一个包含所选变换的新 Compose 对象。
# 引发:
# AssertionError:如果索引不是 int 或 list 类型。
"""
Retrieve a specific transform or a set of transforms using indexing.
Args:
index (int | List[int]): Index or list of indices of the transforms to retrieve.
Returns:
(Compose): A new Compose object containing the selected transform(s).
Raises:
AssertionError: If the index is not of type int or list.
Examples:
>>> transforms = [RandomFlip(), RandomPerspective(10), RandomHSV(0.5, 0.5, 0.5)]
>>> compose = Compose(transforms)
>>> single_transform = compose[1] # Returns a Compose object with only RandomPerspective
>>> multiple_transforms = compose[0:2] # Returns a Compose object with RandomFlip and RandomPerspective
"""
# 首先检查 index 是否为整数或整数列表,如果不是则抛出异常。
assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}" # 索引应该是列表或 int 类型,但得到的是 {type(index)} 。
# 如果 index 是整数,则将其转换为列表。
index = [index] if isinstance(index, int) else index
# 使用列表推导式从 self.transforms 中提取指定索引的变换操作,并返回一个新的 Compose 实例。
return Compose([self.transforms[i] for i in index])
# 这段代码定义了 Compose 类的 __setitem__ 方法,用于设置指定索引处的变换操作。
# 定义了 __setitem__ 方法,该方法允许通过索引设置 self.transforms 中的变换操作。
# 参数 1.index 可以是一个整数或整数列表,表示需要设置的索引位置。
# 参数 2.value 可以是一个变换操作或变换操作列表,表示要设置的新值。
# 返回值为 None ,表示该方法不返回任何值。
def __setitem__(self, index: Union[list, int], value: Union[list, int]) -> None:
# 使用索引在组合中设置一个或多个变换。
# 参数:
# index (int | List[int]):要设置变换的索引或索引列表。
# value (Any | List[Any]):要在指定索引处设置的变换或变换列表。
# 引发:
# AssertionError:如果索引类型无效、值类型与索引类型不匹配或索引超出范围。
"""
Set one or more transforms in the composition using indexing.
Args:
index (int | List[int]): Index or list of indices to set transforms at.
value (Any | List[Any]): Transform or list of transforms to set at the specified index(es).
Raises:
AssertionError: If index type is invalid, value type doesn't match index type, or index is out of range.
Examples:
>>> compose = Compose([Transform1(), Transform2(), Transform3()])
>>> compose[1] = NewTransform() # Replace second transform
>>> compose[0:2] = [NewTransform1(), NewTransform2()] # Replace first two transforms
"""
# 使用 assert 语句检查 index 是否为整数或整数列表。 如果 index 既不是整数也不是列表,抛出一个 AssertionError ,并提供错误信息。
assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}" # 索引应该是列表或 int 类型,但得到的是 {type(index)} 。
# 如果 index 是一个列表,使用 assert 语句检查 value 是否也是一个列表。 如果 value 不是列表,抛出一个 AssertionError ,并提供错误信息。
if isinstance(index, list):
assert isinstance(value, list), (
f"The indices should be the same type as values, but got {type(index)} and {type(value)}" # 索引应与值的类型相同,但得到的是 {type(index)} 和 {type(value)} 。
)
# 如果 index 是一个整数,将其转换为一个包含该整数的列表。 同时,将 value 也转换为一个包含该值的列表。 这样可以统一处理单个索引和多个索引的情况。
if isinstance(index, int):
index, value = [index], [value]
# 使用 zip 函数将 index 和 value 配对,形成一个迭代器。 遍历这个迭代器,每次迭代中, i 是索引, v 是对应的值。
for i, v in zip(index, value):
# 使用 assert 语句检查索引 i 是否在 self.transforms 的有效范围内。 如果索引超出范围,抛出一个 AssertionError ,并提供错误信息。
assert i < len(self.transforms), f"list index {i} out of range {len(self.transforms)}." # 列表索引 {i} 超出范围 {len(self.transforms)}。
# 将 self.transforms 中索引为 i 的变换操作设置为新的值 v 。
self.transforms[i] = v
# 这段代码定义了 Compose 类的 __setitem__ 方法,用于设置指定索引处的变换操作。该方法支持单个索引和多个索引的设置,并且在设置过程中进行了严格的类型检查和范围检查,确保操作的正确性和安全性。通过这种方式, Compose 类可以灵活地管理和修改变换操作列表,适用于各种数据预处理和增强任务。
# 定义了一个 tolist 方法,用于将变换操作列表转换为普通列表。
def tolist(self):
# 将变换列表转换为标准 Python 列表。
"""
Convert the list of transforms to a standard Python list.
Returns:
(list): A list containing all the transform objects in the Compose instance.
Examples:
>>> transforms = [RandomFlip(), RandomPerspective(10), CenterCrop()]
>>> compose = Compose(transforms)
>>> transform_list = compose.tolist()
>>> print(len(transform_list))
3
"""
# 直接返回 self.transforms 。
return self.transforms
# 定义了 __repr__ 方法,用于返回类的字符串表示。
def __repr__(self):
# 返回 Compose 对象的字符串表示形式。
"""
Return a string representation of the Compose object.
Returns:
(str): A string representation of the Compose object, including the list of transforms.
Examples:
>>> transforms = [RandomFlip(), RandomPerspective(degrees=10, translate=0.1, scale=0.1)]
>>> compose = Compose(transforms)
>>> print(compose)
Compose([
RandomFlip(),
RandomPerspective(degrees=10, translate=0.1, scale=0.1)
])
"""
# 使用字符串格式化方法返回类的名称和变换操作列表的字符串表示。
return f"{self.__class__.__name__}({', '.join([f'{t}' for t in self.transforms])})"
# 这段代码定义了一个名为 Compose 的类,用于将多个变换操作组合起来,形成一个可调用的变换序列。通过 __call__ 方法,可以依次执行所有变换操作。此外,类还提供了 append 、 insert 、 __getitem__ 、 __setitem__ 、 tolist 和 __repr__ 等方法,用于方便地管理变换操作列表。这种设计使得 Compose 类可以灵活地组合和管理多个变换操作,适用于各种数据预处理和增强任务。
4.class BaseMixTransform:
# 这段代码定义了一个名为 BaseMixTransform 的类,它是一个基础的混合变换类,用于实现如Mosaic、CutMix或MixUp等数据增强技术。这些技术通过混合多个图像及其标签来生成新的训练样本,从而提高模型的泛化能力。
class BaseMixTransform:
# Cutmix、MixUp 和 Mosaic 等混合变换的基类。
# 此类为在数据集上实现混合变换提供了基础。它处理基于概率的变换应用,并管理多幅图像和标签的混合。
# 方法:
# __call__:将混合变换应用于输入标签。
# _mix_transform:由子类实现的抽象方法,用于特定的混合操作。
# get_indexes:获取待混合图像索引的抽象方法。
# _update_label_text:更新混合图像的标签文本。
"""
Base class for mix transformations like Cutmix, MixUp and Mosaic.
This class provides a foundation for implementing mix transformations on datasets. It handles the
probability-based application of transforms and manages the mixing of multiple images and labels.
Attributes:
dataset (Any): The dataset object containing images and labels.
pre_transform (Callable | None): Optional transform to apply before mixing.
p (float): Probability of applying the mix transformation.
Methods:
__call__: Apply the mix transformation to the input labels.
_mix_transform: Abstract method to be implemented by subclasses for specific mix operations.
get_indexes: Abstract method to get indexes of images to be mixed.
_update_label_text: Update label text for mixed images.
Examples:
>>> class CustomMixTransform(BaseMixTransform):
... def _mix_transform(self, labels):
... # Implement custom mix logic here
... return labels
...
... def get_indexes(self):
... return [random.randint(0, len(self.dataset) - 1) for _ in range(3)]
>>> dataset = YourDataset()
>>> transform = CustomMixTransform(dataset, p=0.5)
>>> mixed_labels = transform(original_labels)
"""
# 定义了类的初始化方法 __init__ ,接收三个参数:
# 1.dataset :数据集对象,用于获取图像及其标签。
# 2.pre_transform :可选的预变换操作,用于在混合之前对图像进行预处理。
# 3.p :混合变换的概率,默认为0.0,表示不进行混合变换。
def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
# 初始化 BaseMixTransform 对象,用于 CutMix、MixUp 和 Mosaic 等混合变换。
# 此类是图像处理流程中实现混合变换的基础。
"""
Initialize the BaseMixTransform object for mix transformations like CutMix, MixUp and Mosaic.
This class serves as a base for implementing mix transformations in image processing pipelines.
Args:
dataset (Any): The dataset object containing images and labels for mixing.
pre_transform (Callable | None): Optional transform to apply before mixing.
p (float): Probability of applying the mix transformation. Should be in the range [0.0, 1.0].
Examples:
>>> dataset = YOLODataset("path/to/data")
>>> pre_transform = Compose([RandomFlip(), RandomPerspective()])
>>> mix_transform = BaseMixTransform(dataset, pre_transform, p=0.5)
"""
# 将传入的参数分别赋值给类的属性 self.dataset 、 self.pre_transform 和 self.p 。
self.dataset = dataset
self.pre_transform = pre_transform
self.p = p
# 这段代码定义了 BaseMixTransform 类的 __call__ 方法,它实现了混合变换操作的核心逻辑。
# 定义了 BaseMixTransform 类的 __call__ 方法,使得类的实例可以像函数一样被调用。 参数 1.labels 是一个包含图像及其标签的数据结构,通常是一个字典。
def __call__(self, labels):
# 对标签数据应用预处理变换以及 CutMix/Mixup/Mosaic 变换。
# 此方法根据概率因子确定是否应用混合变换。如果应用,它将选择其他图像,应用预变换(如果指定),然后执行混合变换。
"""
Apply pre-processing transforms and cutmix/mixup/mosaic transforms to labels data.
This method determines whether to apply the mix transform based on a probability factor. If applied, it
selects additional images, applies pre-transforms if specified, and then performs the mix transform.
Args:
labels (dict): A dictionary containing label data for an image.
Returns:
(dict): The transformed labels dictionary, which may include mixed data from other images.
Examples:
>>> transform = BaseMixTransform(dataset, pre_transform=None, p=0.5)
>>> result = transform({"image": img, "bboxes": boxes, "cls": classes})
"""
# 使用 random.uniform(0, 1) 生成一个0到1之间的随机数。
# 如果随机数大于 self.p (混合变换的概率),则直接返回原始的 labels ,不进行混合变换。
# 这一步确保了混合变换操作以指定的概率 self.p 发生。
if random.uniform(0, 1) > self.p:
return labels
# Get index of one or three other images
# 调用 self.get_indexes() 方法获取一个或多个其他图像的索引。
indexes = self.get_indexes()
# 如果返回的 indexes 是一个整数,则将其转换为一个包含该整数的列表。 这样可以统一处理单个索引和多个索引的情况。
if isinstance(indexes, int):
indexes = [indexes]
# Get images information will be used for Mosaic, CutMix or MixUp
# 使用列表推导式,通过 self.dataset.get_image_and_label(i) 方法获取每个索引对应的图像及其标签,存储在 mix_labels 列表中。 self.dataset.get_image_and_label(i) 方法通常返回一个包含图像和标签的数据结构。
mix_labels = [self.dataset.get_image_and_label(i) for i in indexes]
# 如果 self.pre_transform 不为 None ,则对 mix_labels 中的每个图像及其标签应用预变换操作。 预变换操作可以在混合之前对图像进行一些处理,如归一化、裁剪等。
if self.pre_transform is not None:
for i, data in enumerate(mix_labels):
mix_labels[i] = self.pre_transform(data)
# 将 mix_labels 添加到 labels 字典中,键为 "mix_labels" 。 这样可以在后续的混合变换操作中访问这些混合标签。
labels["mix_labels"] = mix_labels
# Update cls and texts
# 调用 self._update_label_text(labels) 方法,更新 labels 中的类别和文本标签。 这个方法通常会处理类别ID和文本标签的映射,确保混合后的标签一致。
labels = self._update_label_text(labels)
# Mosaic, CutMix or MixUp
# 调用 self._mix_transform(labels) 方法,执行具体的混合变换操作(如Mosaic、CutMix或MixUp)。 这个方法需要在子类中被具体实现,因为不同的混合变换技术有不同的实现逻辑。
labels = self._mix_transform(labels)
# 从 labels 字典中移除 "mix_labels" 键,避免后续处理中出现不必要的数据。 使用 pop 方法时,如果键不存在,不会抛出异常,而是返回 None 。
labels.pop("mix_labels", None)
# 返回经过混合变换后的 labels 。
return labels
# 这段代码定义了 BaseMixTransform 类的 __call__ 方法,实现了混合变换操作的核心逻辑。该方法的主要步骤包括: 根据给定的概率决定是否进行混合变换。 随机选择一个或多个其他图像及其标签。 对这些图像及其标签应用预变换操作(如果提供了预变换操作)。 更新类别和文本标签,确保类别ID与混合后的文本标签一致。 调用具体的混合变换方法(如Mosaic、CutMix或MixUp)。 移除混合标签,返回最终的标签。这种设计使得 BaseMixTransform 类可以作为各种混合变换技术的基础框架,通过在子类中实现具体的混合变换逻辑,可以方便地扩展和复用。
# 定义了一个抽象方法 _mix_transform ,用于实现具体的混合变换逻辑。 这个方法在子类中需要被具体实现。
def _mix_transform(self, labels):
# 将 CutMix、MixUp 或 Mosaic 增强应用于标签字典。
# 此方法应由子类实现,以执行特定的混合转换,例如 CutMix、MixUp 或 Mosaic。它会使用增强数据就地修改输入标签字典。
"""
Apply CutMix, MixUp or Mosaic augmentation to the label dictionary.
This method should be implemented by subclasses to perform specific mix transformations like CutMix, MixUp or
Mosaic. It modifies the input label dictionary in-place with the augmented data.
Args:
labels (dict): A dictionary containing image and label data. Expected to have a 'mix_labels' key
with a list of additional image and label data for mixing.
Returns:
(dict): The modified labels dictionary with augmented data after applying the mix transform.
Examples:
>>> transform = BaseMixTransform(dataset)
>>> labels = {"image": img, "bboxes": boxes, "mix_labels": [{"image": img2, "bboxes": boxes2}]}
>>> augmented_labels = transform._mix_transform(labels)
"""
# 抛出一个 NotImplementedError ,表示这个方法需要在子类中被实现。
raise NotImplementedError
# 定义了一个方法 get_indexes ,用于获取一个或多个其他图像的索引。
def get_indexes(self):
# 获取用于马赛克增强的混洗索引列表。
"""
Get a list of shuffled indexes for mosaic augmentation.
Returns:
(List[int]): A list of shuffled indexes from the dataset.
Examples:
>>> transform = BaseMixTransform(dataset)
>>> indexes = transform.get_indexes()
>>> print(indexes) # [3, 18, 7, 2]
"""
# 使用 random.randint 生成一个随机索引,范围从0到 len(self.dataset) - 1 。
return random.randint(0, len(self.dataset) - 1)
# 这段代码定义了 BaseMixTransform 类中的一个静态方法 _update_label_text ,用于更新类别和文本标签,确保混合后的标签一致。
# 使用 @staticmethod 装饰器定义了一个静态方法 _update_label_text 。
@staticmethod
# 静态方法不需要访问类或实例的属性,因此不接收 self 参数。
# 参数 1.labels 是一个包含图像及其标签的数据结构,通常是一个字典。
def _update_label_text(labels):
# 更新图像增广中混合标签的标签文本和类别 ID。
# 此方法处理输入标签字典和任何混合标签的“texts”和“cls”字段,创建一组统一的文本标签并相应地更新类别 ID。
# 参数:
# labels (dict):包含标签信息的字典,包括“texts”和“cls”字段,以及可选的“mix_labels”字段,其中包含其他标签字典。
# 返回:
# (dict):更新后的标签字典,包含统一的文本标签和更新的类别 ID。
"""
Update label text and class IDs for mixed labels in image augmentation.
This method processes the 'texts' and 'cls' fields of the input labels dictionary and any mixed labels,
creating a unified set of text labels and updating class IDs accordingly.
Args:
labels (dict): A dictionary containing label information, including 'texts' and 'cls' fields,
and optionally a 'mix_labels' field with additional label dictionaries.
Returns:
(dict): The updated labels dictionary with unified text labels and updated class IDs.
Examples:
>>> labels = {
... "texts": [["cat"], ["dog"]],
... "cls": torch.tensor([[0], [1]]),
... "mix_labels": [{"texts": [["bird"], ["fish"]], "cls": torch.tensor([[0], [1]])}],
... }
>>> updated_labels = self._update_label_text(labels)
>>> print(updated_labels["texts"])
[['cat'], ['dog'], ['bird'], ['fish']]
>>> print(updated_labels["cls"])
tensor([[0],
[1]])
>>> print(updated_labels["mix_labels"][0]["cls"])
tensor([[2],
[3]])
"""
# 检查 labels 字典中是否包含 "texts" 键。 如果不包含 "texts" 键,直接返回原始的 labels ,不进行任何处理。
if "texts" not in labels:
return labels
# 使用列表推导式,将 labels["texts"] 和 labels["mix_labels"] 中每个混合标签的 "texts" 合并成一个列表。 sum 函数用于将这些列表合并成一个大列表。
mix_texts = sum([labels["texts"]] + [x["texts"] for x in labels["mix_labels"]], [])
# 使用集合去重,确保 mix_texts 中的文本标签是唯一的。 将每个文本标签转换为元组(因为列表不可哈希,不能直接作为集合的元素),然后转换回列表。
mix_texts = list({tuple(x) for x in mix_texts})
# 创建一个从文本标签到唯一ID的映射 text2id 。 使用 enumerate 函数遍历 mix_texts ,为每个文本标签分配一个唯一的ID。
text2id = {text: i for i, text in enumerate(mix_texts)}
# 遍历 labels 和 labels["mix_labels"] 中的每个标签。
for label in [labels] + labels["mix_labels"]:
# 对于每个标签,遍历其类别ID列表 label["cls"] :
# 使用 label["cls"].squeeze(-1).tolist() 将类别ID张量转换为列表。
# 对于每个类别ID,通过 label["texts"][int(cls)] 获取对应的文本标签。
# 使用 text2id[tuple(text)] 将文本标签映射为新的类别ID,并更新 label["cls"] 。
# 将去重后的 mix_texts 赋值给每个标签的 "texts" 键。
for i, cls in enumerate(label["cls"].squeeze(-1).tolist()):
text = label["texts"][int(cls)]
label["cls"][i] = text2id[tuple(text)]
label["texts"] = mix_texts
# 返回更新后的 labels 字典。
return labels
# 这段代码定义了一个静态方法 _update_label_text ,用于更新类别和文本标签,确保混合后的标签一致。具体步骤包括: 检查 labels 中是否包含 "texts" 键,如果不包含则直接返回。 合并 labels["texts"] 和 labels["mix_labels"] 中的文本标签。 使用集合去重,确保文本标签是唯一的。 创建一个从文本标签到唯一ID的映射。 遍历 labels 和 labels["mix_labels"] ,更新每个标签的类别ID,确保类别ID与去重后的文本标签一致。 将去重后的文本标签赋值给每个标签的 "texts" 键。 返回更新后的 labels 。这种设计确保了在混合变换操作中,类别ID和文本标签的一致性和唯一性,适用于各种数据增强任务。
# 这段代码定义了一个名为 BaseMixTransform 的类,用于实现混合变换操作,如Mosaic、CutMix或MixUp。该类通过以下步骤实现混合变换: 根据给定的概率决定是否进行混合变换。 随机选择一个或多个其他图像及其标签。 对这些图像及其标签应用预变换操作(如果提供了预变换操作)。 更新类别和文本标签,确保类别ID与混合后的文本标签一致。 调用具体的混合变换方法(如Mosaic、CutMix或MixUp)。 移除混合标签,返回最终的标签。这种设计使得 BaseMixTransform 类可以作为各种混合变换技术的基础框架,通过在子类中实现具体的混合变换逻辑,可以方便地扩展和复用。
5.class CutMix(BaseMixTransform):
# 这段代码定义了一个名为 CutMix 的类,它继承自 BaseMixTransform ,用于实现CutMix数据增强技术。CutMix通过将一张图像的某个区域替换为另一张图像的相应区域,来生成新的训练样本。
# 定义了一个名为 CutMix 的类,继承自 BaseMixTransform 。
class CutMix(BaseMixTransform):
# 按照论文 https://arxiv.org/abs/1905.04899 中的说明,将 CutMix 增强应用于图像数据集。
# CutMix 通过将一幅图像中的随机矩形区域替换为另一幅图像中的相应区域来合并两幅图像,并根据混合区域的面积按比例调整标签。
# 方法:
# _mix_transform:将 CutMix 增强应用于输入标签。
# _rand_bbox:为剪切区域生成随机边界框坐标。
"""
Apply CutMix augmentation to image datasets as described in the paper https://arxiv.org/abs/1905.04899.
CutMix combines two images by replacing a random rectangular region of one image with the corresponding region from another image,
and adjusts the labels proportionally to the area of the mixed region.
Attributes:
dataset (Any): The dataset to which CutMix augmentation will be applied.
pre_transform (Callable | None): Optional transform to apply before CutMix.
p (float): Probability of applying CutMix augmentation.
beta (float): Beta distribution parameter for sampling the mixing ratio.
num_areas (int): Number of areas to try to cut and mix.
Methods:
_mix_transform: Apply CutMix augmentation to the input labels.
_rand_bbox: Generate random bounding box coordinates for the cut region.
Examples:
>>> from ultralytics.data.augment import CutMix
>>> dataset = YourDataset(...) # Your image dataset
>>> cutmix = CutMix(dataset, p=0.5)
>>> augmented_labels = cutmix(original_labels)
"""
# 定义了 CutMix 类的初始化方法 __init__ ,接收以下参数:
# 1.dataset :数据集对象,用于获取图像及其标签。
# 2.pre_transform :可选的预变换操作,用于在混合之前对图像进行预处理。
# 3.p :混合变换的概率,默认为0.0。
# 4.beta :Beta分布的参数,用于控制CutMix区域的大小。
# 5.num_areas :尝试生成的CutMix区域数量。
def __init__(self, dataset, pre_transform=None, p=0.0, beta=1.0, num_areas=3) -> None:
# 初始化 CutMix 增强对象。
"""
Initialize the CutMix augmentation object.
Args:
dataset (Any): The dataset to which CutMix augmentation will be applied.
pre_transform (Callable | None): Optional transform to apply before CutMix.
p (float): Probability of applying CutMix augmentation.
beta (float): Beta distribution parameter for sampling the mixing ratio.
num_areas (int): Number of areas to try to cut and mix.
"""
# 调用父类 BaseMixTransform 的初始化方法,初始化 dataset 、 pre_transform 和 p 。
super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
# 初始化 self.beta 和 self.num_areas 。
self.beta = beta
self.num_areas = num_areas
# 这段代码定义了 CutMix 类中的一个方法 _rand_bbox ,用于生成随机的CutMix边界框。
# 定义了一个名为 _rand_bbox 的方法,用于生成随机的CutMix边界框。 参数 1.width 和 2.height 分别表示图像的宽度和高度。
def _rand_bbox(self, width, height):
# 为剪切区域生成随机边界框坐标。
"""
Generate random bounding box coordinates for the cut region.
Args:
width (int): Width of the image.
height (int): Height of the image.
Returns:
(tuple): (x1, y1, x2, y2) coordinates of the bounding box.
"""
# Sample mixing ratio from Beta distribution
# 使用 np.random.beta 从Beta分布中采样一个混合比例 lam 。 Beta分布的参数为 self.beta ,这决定了CutMix区域的大小分布。 通常,Beta分布的参数 alpha 和 beta 相等,表示对称分布。
lam = np.random.beta(self.beta, self.beta)
# 计算CutMix区域的大小比例 cut_ratio ,取 1.0 - lam 的平方根。
cut_ratio = np.sqrt(1.0 - lam)
# 根据 cut_ratio 计算CutMix区域的宽度 cut_w 和高度 cut_h 。 使用 int 函数将宽度和高度转换为整数。
cut_w = int(width * cut_ratio)
cut_h = int(height * cut_ratio)
# Random center
# 使用 np.random.randint 随机选择CutMix区域的中心点 cx 和 cy 。 cx 的范围是 [0, width) , cy 的范围是 [0, height) 。
cx = np.random.randint(width)
cy = np.random.randint(height)
# Bounding box coordinates
# 计算CutMix区域的边界框坐标 x1 、 y1 、 x2 和 y2 。
# x1 和 y1 分别是边界框的左上角坐标, x2 和 y2 分别是边界框的右下角坐标。
# 使用 np.clip 确保边界框坐标在图像范围内,即 x1 和 y1 不小于0, x2 和 y2 不大于图像的宽度和高度。
x1 = np.clip(cx - cut_w // 2, 0, width)
y1 = np.clip(cy - cut_h // 2, 0, height)
x2 = np.clip(cx + cut_w // 2, 0, width)
y2 = np.clip(cy + cut_h // 2, 0, height)
# 返回计算得到的边界框坐标 x1 、 y1 、 x2 和 y2 。
return x1, y1, x2, y2
# 这段代码定义了一个方法 _rand_bbox ,用于生成随机的CutMix边界框。具体步骤包括: 从Beta分布中采样一个混合比例 lam 。 根据混合比例计算CutMix区域的大小比例 cut_ratio 。 根据 cut_ratio 计算CutMix区域的宽度 cut_w 和高度 cut_h 。 随机选择CutMix区域的中心点 cx 和 cy 。 计算CutMix区域的边界框坐标 x1 、 y1 、 x2 和 y2 ,并确保这些坐标在图像范围内。 返回计算得到的边界框坐标。这种方法确保了CutMix区域的大小和位置是随机的,从而增加了数据增强的多样性,有助于提高模型的泛化能力。
# 这段代码定义了 CutMix 类中的 _mix_transform 方法,用于实现CutMix混合变换逻辑。CutMix通过将一张图像的某个区域替换为另一张图像的相应区域,来生成新的训练样本。
# 定义了 _mix_transform 方法,用于实现CutMix混合变换逻辑。 参数 1.labels 是一个包含图像及其标签的数据结构,通常是一个字典。
def _mix_transform(self, labels):
# 对输入标签应用 CutMix 增强。
"""
Apply CutMix augmentation to the input labels.
Args:
labels (dict): A dictionary containing the original image and label information.
Returns:
(dict): A dictionary containing the mixed image and adjusted labels.
Examples:
>>> cutter = CutMix(dataset)
>>> mixed_labels = cutter._mix_transform(labels)
"""
# Get a random second image
# 获取图像的高度 h 和宽度 w 。
h, w = labels["img"].shape[:2]
# 使用列表推导式,调用 self._rand_bbox(w, h) 生成 self.num_areas 个CutMix区域。 将这些区域存储在 cut_areas 数组中,数据类型为 np.float32 。
cut_areas = np.asarray([self._rand_bbox(w, h) for _ in range(self.num_areas)], dtype=np.float32)
# 使用 bbox_ioa 函数计算每个CutMix区域与目标边界框的交并比(IOA)。 ioa1 的形状为 (self.num_areas, num_boxes) ,表示每个CutMix区域与每个目标边界框的交并比。
ioa1 = bbox_ioa(cut_areas, labels["instances"].bboxes) # (self.num_areas, num_boxes)
# 使用 np.nonzero 选择那些与目标边界框不重叠的CutMix区域(即交并比之和为0的区域)。
idx = np.nonzero(ioa1.sum(axis=1) <= 0)[0]
# 如果没有找到这样的区域,则直接返回原始的 labels 。
if len(idx) == 0:
return labels
# 从 labels 中获取混合图像及其标签,存储在 labels2 中。 使用 pop 方法移除 labels["mix_labels"] ,避免后续处理中出现不必要的数据。
labels2 = labels.pop("mix_labels")[0]
# 从不与目标边界框重叠的区域中随机选择一个CutMix区域。
area = cut_areas[np.random.choice(idx)] # randomly select one
# 使用 bbox_ioa 函数计算选定的CutMix区域与混合图像的边界框的交并比。 area[None] 将 area 扩展为一个二维数组,以便与 labels2["instances"].bboxes 进行计算。 使用 squeeze(0) 将结果压缩为一维数组。
ioa2 = bbox_ioa(area[None], labels2["instances"].bboxes).squeeze(0)
# 选择那些与CutMix区域有足够重叠的实例(交并比大于等于0.01或0.1)。
indexes2 = np.nonzero(ioa2 >= (0.01 if len(labels["instances"].segments) else 0.1))[0]
# 如果没有找到这样的实例,则直接返回原始的 labels 。
if len(indexes2) == 0:
return labels
# 提取混合图像中与CutMix区域有足够重叠的实例。
instances2 = labels2["instances"][indexes2]
# 将实例的边界框转换为 xyxy 格式,并反归一化到原始图像尺寸。
instances2.convert_bbox("xyxy")
instances2.denormalize(w, h)
# Apply CutMix
# 将混合图像的CutMix区域替换到目标图像的相应区域。
# 使用 astype(np.int32) 将边界框坐标转换为整数。
x1, y1, x2, y2 = area.astype(np.int32)
labels["img"][y1:y2, x1:x2] = labels2["img"][y1:y2, x1:x2]
# Restrain instances2 to the random bounding border
# 调整混合实例的边界框,使其适应CutMix区域。
# 使用 add_padding 方法将实例的边界框平移到CutMix区域的局部坐标系。
# 使用 clip 方法将实例的边界框裁剪到CutMix区域的范围内。
# 再次使用 add_padding 方法将实例的边界框平移到目标图像的全局坐标系。
instances2.add_padding(-x1, -y1)
instances2.clip(x2 - x1, y2 - y1)
instances2.add_padding(x1, y1)
# 更新目标标签,将混合实例的类别和实例信息添加到目标标签中。
# 使用 np.concatenate 将类别ID合并。
# 使用 Instances.concatenate 将实例信息合并。
labels["cls"] = np.concatenate([labels["cls"], labels2["cls"][indexes2]], axis=0)
labels["instances"] = Instances.concatenate([labels["instances"], instances2], axis=0)
# 返回更新后的 labels 。
return labels
# 这段代码定义了 CutMix 类中的 _mix_transform 方法,用于实现CutMix混合变换逻辑。CutMix通过将一张图像的某个区域替换为另一张图像的相应区域,来生成新的训练样本。该方法的主要步骤包括: 生成多个CutMix区域。 选择与目标边界框不重叠的CutMix区域。 选择与混合边界框有足够重叠的实例。 将混合图像的CutMix区域替换到目标图像的相应区域。 调整混合实例的边界框,使其适应CutMix区域。 更新目标标签,将混合实例的类别和实例信息添加到目标标签中。这种设计使得 CutMix 类可以方便地实现CutMix数据增强技术,适用于各种计算机视觉任务,有助于提高模型的泛化能力。
# 这段代码定义了一个名为 CutMix 的类,用于实现CutMix数据增强技术。CutMix通过将一张图像的某个区域替换为另一张图像的相应区域,来生成新的训练样本。该类的主要步骤包括: 生成随机的CutMix区域。 选择与目标边界框不重叠的CutMix区域。 选择与混合边界框有足够重叠的实例。 将混合图像的CutMix区域替换到目标图像的相应区域。 更新目标标签,将混合实例的类别和实例信息添加到目标标签中。这种设计使得 CutMix 类可以方便地实现CutMix数据增强技术,适用于各种计算机视觉任务。
6.class CopyPaste(BaseMixTransform):
# 这段代码定义了一个名为 CopyPaste 的类,它继承自 BaseMixTransform ,用于实现Copy-Paste数据增强技术。Copy-Paste通过将一张图像中的某些实例复制到另一张图像中,来生成新的训练样本。
# 定义了一个名为 CopyPaste 的类,继承自 BaseMixTransform 。
class CopyPaste(BaseMixTransform):
# 用于将复制粘贴增强应用于图像数据集的 CopyPaste 类。
# 该类实现了论文“简单的复制粘贴是一种强大的实例分割数据增强方法”(https://arxiv.org/abs/2012.07177) 中描述的复制粘贴增强技术。它将来自不同图像的对象组合起来以创建新的训练样本。
# 方法:
# _mix_transform:将复制粘贴增强应用于输入标签。
# __call__:将复制粘贴转换应用于图像和注释。
"""
CopyPaste class for applying Copy-Paste augmentation to image datasets.
This class implements the Copy-Paste augmentation technique as described in the paper "Simple Copy-Paste is a Strong
Data Augmentation Method for Instance Segmentation" (https://arxiv.org/abs/2012.07177). It combines objects from
different images to create new training samples.
Attributes:
dataset (Any): The dataset to which Copy-Paste augmentation will be applied.
pre_transform (Callable | None): Optional transform to apply before Copy-Paste.
p (float): Probability of applying Copy-Paste augmentation.
Methods:
_mix_transform: Apply Copy-Paste augmentation to the input labels.
__call__: Apply the Copy-Paste transformation to images and annotations.
Examples:
>>> from ultralytics.data.augment import CopyPaste
>>> dataset = YourDataset(...) # Your image dataset
>>> copypaste = CopyPaste(dataset, p=0.5)
>>> augmented_labels = copypaste(original_labels)
"""
# 定义了 CopyPaste 类的初始化方法 __init__ ,接收以下参数:
# 1.dataset :数据集对象,用于获取图像及其标签。
# 2.pre_transform :可选的预变换操作,用于在混合之前对图像进行预处理。
# 3.p :混合变换的概率,默认为0.5。
# 4.mode :Copy-Paste的模式,可以是 "flip" 或 "mixup" 。
def __init__(self, dataset=None, pre_transform=None, p=0.5, mode="flip") -> None:
# 使用数据集、pre_transform 和应用 MixUp 的概率初始化 CopyPaste 对象。
"""Initialize CopyPaste object with dataset, pre_transform, and probability of applying MixUp."""
# 调用父类 BaseMixTransform 的初始化方法,初始化 dataset 、 pre_transform 和 p 。
super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
# 使用 assert 语句检查 mode 是否为 "flip" 或 "mixup" ,如果不是则抛出异常。
assert mode in {"flip", "mixup"}, f"Expected `mode` to be `flip` or `mixup`, but got {mode}." # 预期“模式”为“翻转”或“混合”,但得到的是{mode}。
# 初始化 self.mode 。
self.mode = mode
# 定义了 _mix_transform 方法,用于实现Copy-Paste混合变换逻辑。
def _mix_transform(self, labels):
# 应用复制粘贴增强功能将另一幅图像中的对象组合到当前图像中。
"""Apply Copy-Paste augmentation to combine objects from another image into the current image."""
# 从 labels 中获取混合图像及其标签,存储在 labels2 中。
labels2 = labels["mix_labels"][0]
# 调用 self._transform(labels, labels2) 方法,执行具体的Copy-Paste变换逻辑。
return self._transform(labels, labels2)
# 这段代码定义了 CopyPaste 类的 __call__ 方法,使得类的实例可以像函数一样被调用。这个方法实现了Copy-Paste数据增强技术的主逻辑,根据指定的模式( mode )选择不同的处理方式。
# 定义了 CopyPaste 类的 __call__ 方法,使得类的实例可以像函数一样被调用。 参数 1.labels 是一个包含图像及其标签的数据结构,通常是一个字典。
def __call__(self, labels):
# 对图像及其标签应用复制粘贴增强功能。
"""Apply Copy-Paste augmentation to an image and its labels."""
# 如果目标图像中没有实例( labels["instances"].segments 为空)或者混合变换的概率 self.p 为0,则直接返回原始的 labels ,不进行任何处理。
if len(labels["instances"].segments) == 0 or self.p == 0:
return labels
# 如果 self.mode 为 "flip" ,则调用 self._transform(labels) 方法,直接对目标图像进行Copy-Paste变换。 这种模式下,混合图像使用目标图像的水平翻转版本。
if self.mode == "flip":
return self._transform(labels)
# Get index of one or three other images
# 调用 self.get_indexes() 方法获取一个或多个其他图像的索引。
indexes = self.get_indexes()
# 如果返回的 indexes 是一个整数,则将其转换为一个包含该整数的列表,以统一处理。
if isinstance(indexes, int):
indexes = [indexes]
# Get images information will be used for Mosaic or MixUp
# 使用列表推导式,通过 self.dataset.get_image_and_label(i) 方法获取每个索引对应的图像及其标签,存储在 mix_labels 列表中。
mix_labels = [self.dataset.get_image_and_label(i) for i in indexes]
# 如果 self.pre_transform 不为 None ,则对 mix_labels 中的每个图像及其标签应用预变换操作。预变换操作可以在混合之前对图像进行一些处理,如归一化、裁剪等。
if self.pre_transform is not None:
for i, data in enumerate(mix_labels):
mix_labels[i] = self.pre_transform(data)
# 将 mix_labels 添加到 labels 字典中,键为 "mix_labels" 。 这样可以在后续的混合变换操作中访问这些混合标签。
labels["mix_labels"] = mix_labels
# Update cls and texts
# 调用 self._update_label_text(labels) 方法,更新 labels 中的类别和文本标签。 这个方法通常会处理类别ID和文本标签的映射,确保混合后的标签一致。
labels = self._update_label_text(labels)
# Mosaic or MixUp
# 调用 self._mix_transform(labels) 方法,执行具体的混合变换操作。 在 CopyPaste 类中, _mix_transform 方法会调用 _transform 方法来完成Copy-Paste变换。
labels = self._mix_transform(labels)
# 从 labels 字典中移除 "mix_labels" 键,避免后续处理中出现不必要的数据。 使用 pop 方法时,如果键不存在,不会抛出异常,而是返回 None 。
labels.pop("mix_labels", None)
# 返回更新后的 labels ,包含新的图像及其标签。
return labels
# 这段代码定义了 CopyPaste 类的 __call__ 方法,实现了Copy-Paste数据增强技术的主逻辑。具体步骤包括: 检查是否跳过处理,如果目标图像中没有实例或混合变换的概率为0,则直接返回原始的 labels 。 根据 mode 参数选择不同的处理方式: 如果 mode 为 "flip" ,则直接对目标图像进行Copy-Paste变换。 如果 mode 为 "mixup" ,则从数据集中获取其他图像及其标签,进行混合变换。 获取其他图像的索引,并获取这些图像及其标签。 对这些图像及其标签应用预变换操作(如果提供了预变换操作)。 将混合标签添加到原始标签中,并更新类别和文本标签。 执行混合变换操作,将选择的实例从混合图像复制到目标图像中。 移除混合标签,返回更新后的标签。这种设计使得 CopyPaste 类可以灵活地实现Copy-Paste数据增强技术,适用于各种计算机视觉任务,有助于提高模型的泛化能力。
# 这段代码定义了 CopyPaste 类中的 _transform 方法,用于实现Copy-Paste数据增强技术的核心逻辑。Copy-Paste通过将一张图像中的某些实例复制到另一张图像中,来生成新的训练样本。
# 定义了 _transform 方法,用于实现Copy-Paste变换逻辑。 参数 1.labels1 是目标图像及其标签, 2.labels2 是混合图像及其标签(默认为空字典)。
def _transform(self, labels1, labels2={}):
# 应用复制粘贴增强功能将另一幅图像中的对象组合到当前图像中。
"""Apply Copy-Paste augmentation to combine objects from another image into the current image."""
# 从 labels1 字典中提取目标图像 im 。 im 是一个NumPy数组,表示图像的像素数据。
im = labels1["img"]
# 从 labels1 字典中提取类别信息 cls 。 cls 是一个NumPy数组,表示每个实例的类别ID。
cls = labels1["cls"]
# 获取目标图像的高度 h 和宽度 w 。 im.shape[:2] 返回图像的前两个维度,即高度和宽度。
h, w = im.shape[:2]
# 从 labels1 字典中提取实例信息 instances 。 使用 pop 方法移除 labels1["instances"] ,避免后续处理中出现不必要的数据。 instances 是一个包含实例信息的对象,通常包含边界框、分割掩码等。
instances = labels1.pop("instances")
# 调用 instances.convert_bbox 方法,将边界框的格式转换为 xyxy 格式。 xyxy 格式表示边界框的左上角坐标 (x1, y1) 和右下角坐标 (x2, y2) 。
instances.convert_bbox(format="xyxy")
# 调用 instances.denormalize 方法,将边界框的坐标从归一化坐标转换为绝对坐标。 w 和 h 分别是图像的宽度和高度,用于将归一化坐标转换为像素坐标。
instances.denormalize(w, h)
# 创建一个新的图像 im_new ,其形状与目标图像 im 相同,但所有像素值为0。 使用 np.uint8 数据类型,确保图像数据的类型一致。
im_new = np.zeros(im.shape, np.uint8)
# 从 labels2 字典中提取混合图像的实例信息 instances2 。 使用 pop 方法移除 labels2["instances"] ,避免后续处理中出现不必要的数据。 如果 labels2 中没有 "instances" 键,则 instances2 为 None 。
instances2 = labels2.pop("instances", None)
# 如果 instances2 为空,则创建一个目标实例 instances 的副本,并将其水平翻转。
if instances2 is None:
# 使用 deepcopy 确保创建的是一个独立的副本,避免修改原始实例。
instances2 = deepcopy(instances)
# 调用 instances2.fliplr(w) 方法,将实例的边界框和分割掩码水平翻转。
instances2.fliplr(w)
# 使用 bbox_ioa 函数计算混合实例 instances2 与目标实例 instances 的交并比(IOA)。 ioa 的形状为 (N, M) ,表示每个混合实例与每个目标实例的交并比。
ioa = bbox_ioa(instances2.bboxes, instances.bboxes) # intersection over area, (N, M)
# 使用 np.nonzero 选择那些与目标实例不重叠(交并比小于0.30)的混合实例。 (ioa < 0.30).all(1) 确保每个混合实例与所有目标实例的交并比都小于0.30。 np.nonzero 返回满足条件的实例索引。
indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, )
# 获取选择的实例数量 n 。
n = len(indexes)
# 对选择的实例进行排序,确保交并比最小的实例优先。
# ioa.max(1) 计算每个混合实例与目标实例的最大交并比。
# np.argsort 返回排序后的索引。
# 使用排序后的索引重新排列 indexes 。
sorted_idx = np.argsort(ioa.max(1)[indexes])
indexes = indexes[sorted_idx]
# 遍历选择的实例索引 indexes ,但只处理前 round(self.p * n) 个实例。 self.p 是复制实例的概率, n 是选择的实例数量。 round(self.p * n) 确保根据概率 self.p 选择一定数量的实例。
for j in indexes[: round(self.p * n)]:
# 将选择的实例 j 的类别信息添加到目标图像的类别信息 cls 中。
# labels2.get("cls", cls) 从 labels2 中获取类别信息,如果 labels2 中没有 "cls" 键,则使用目标图像的类别信息 cls 。
# 使用 [[j]] 获取第 j 个实例的类别信息,并使用 np.concatenate 将其添加到 cls 中。
cls = np.concatenate((cls, labels2.get("cls", cls)[[j]]), axis=0)
# 将选择的实例 j 的实例信息添加到目标图像的实例信息 instances 中。 instances2[[j]] 获取第 j 个实例的信息。 使用 Instances.concatenate 将实例信息合并。
instances = Instances.concatenate((instances, instances2[[j]]), axis=0)
# 使用 cv2.drawContours 将选择的实例 j 的分割掩码绘制到 im_new 中。
# instances2.segments[[j]] 获取第 j 个实例的分割掩码。
# 使用 astype(np.int32) 将分割掩码转换为整数类型,因为 cv2.drawContours 需要整数类型的轮廓。
# (1, 1, 1) 是绘制的颜色,这里使用白色。
# cv2.FILLED 表示填充轮廓。
cv2.drawContours(im_new, instances2.segments[[j]].astype(np.int32), -1, (1, 1, 1), cv2.FILLED)
# 从 labels2 字典中获取混合图像 labels2["img"] 。 如果 labels2 中没有 "img" 键,则使用目标图像 im 的水平翻转版本 cv2.flip(im, 1) 。 这一步是为了提供更多的数据增强选项,例如通过翻转图像来增加多样性。
result = labels2.get("img", cv2.flip(im, 1)) # augment segments
# 检查 result 的维度。如果 result 是二维的(即灰度图像),则添加一个维度,使其成为三维数组。 这是因为 cv2.flip 操作会移除灰度图像的最后一个维度,而后续操作需要三维数组。
if result.ndim == 2: # cv2.flip would eliminate the last dimension for grayscale images
result = result[..., None]
# 将 im_new 转换为布尔掩码 i 。 im_new 是一个掩码图像,其中选择的实例的区域被标记为1,其他区域为0。 使用 astype(bool) 将 im_new 转换为布尔类型,方便后续的索引操作。
i = im_new.astype(bool)
# 使用布尔掩码 i 将 result 中对应区域的像素值复制到目标图像 im 中。 这一步实现了将混合图像中的实例复制到目标图像中。
im[i] = result[i]
# 更新目标图像 labels1["img"] 为最终的图像 im 。
labels1["img"] = im
# 更新目标图像的类别信息 labels1["cls"] 为合并后的类别信息 cls 。
labels1["cls"] = cls
# 更新目标图像的实例信息 labels1["instances"] 为合并后的实例信息 instances 。
labels1["instances"] = instances
# 返回更新后的 labels1 ,包含新的图像及其标签。
return labels1
# 这段代码定义了 CopyPaste 类中的 _transform 方法,用于实现Copy-Paste数据增强技术的核心逻辑。Copy-Paste通过将一张图像中的某些实例复制到另一张图像中,来生成新的训练样本。该方法的主要步骤包括: 提取目标图像及其实例信息。 提取混合图像及其实例信息(如果为空,则使用目标实例的翻转版本)。 计算混合实例与目标实例的交并比,选择不重叠的实例。 选择部分实例,将它们的类别和实例信息添加到目标图像中。 使用掩码将混合图像的相应部分复制到目标图像中。 更新目标图像及其标签。这种设计使得 CopyPaste 类可以方便地实现Copy-Paste数据增强技术,适用于各种计算机视觉任务,有助于提高模型的泛化能力。
# CopyPaste 类是一个实现Copy-Paste数据增强技术的工具,它通过将一张图像中的某些实例复制到另一张图像中,生成新的训练样本。该类继承自 BaseMixTransform ,并根据指定的模式(如水平翻转或混合其他图像)选择不同的处理方式。在实现过程中,它首先检查目标图像中是否有实例以及混合变换的概率,然后根据模式选择合适的混合图像,提取其实例信息,并将这些实例复制到目标图像中。此外,它还更新了目标图像的类别和实例信息,确保新的训练样本在类别和实例标注上保持一致。这种数据增强方法可以增加训练数据的多样性,有助于提高模型对不同场景和实例组合的泛化能力。
7.def v8_transforms(dataset, imgsz, hyp, stretch=False):
# 这段代码定义了一个名为 v8_transforms 的函数,用于创建一个综合的数据增强流程,适用于目标检测和关键点检测任务。这个函数根据传入的参数和配置,组合了多种数据增强技术,以提高模型的泛化能力和鲁棒性。
# 定义了一个函数 v8_transforms ,接收以下参数:
# 1.dataset :数据集对象,用于获取图像及其标签。
# 2.imgsz :目标图像的尺寸。
# 3.hyp :超参数对象,包含各种数据增强的配置。
# 4.stretch :布尔值,表示是否进行拉伸变换。
def v8_transforms(dataset, imgsz, hyp, stretch=False):
# 应用一系列图像变换进行训练。
# 此函数创建一组图像增强技术组合,用于准备用于 YOLO 训练的图像。它包含马赛克、复制粘贴、随机透视、混合以及各种颜色调整等操作。
# 返回:
# (Compose):要应用于数据集的图像变换组合。
"""
Apply a series of image transformations for training.
This function creates a composition of image augmentation techniques to prepare images for YOLO training.
It includes operations such as mosaic, copy-paste, random perspective, mixup, and various color adjustments.
Args:
dataset (Dataset): The dataset object containing image data and annotations.
imgsz (int): The target image size for resizing.
hyp (Namespace): A dictionary of hyperparameters controlling various aspects of the transformations.
stretch (bool): If True, applies stretching to the image. If False, uses LetterBox resizing.
Returns:
(Compose): A composition of image transformations to be applied to the dataset.
Examples:
>>> from ultralytics.data.dataset import YOLODataset
>>> from ultralytics.utils import IterableSimpleNamespace
>>> dataset = YOLODataset(img_path="path/to/images", imgsz=640)
>>> hyp = IterableSimpleNamespace(mosaic=1.0, copy_paste=0.5, degrees=10.0, translate=0.2, scale=0.9)
>>> transforms = v8_transforms(dataset, imgsz=640, hyp=hyp)
>>> augmented_data = transforms(dataset[0])
"""
# 创建一个 Mosaic 变换对象,用于将多张图像拼接成一张大图像。 dataset 是数据集对象, imgsz 是目标图像的尺寸, p 是Mosaic变换的概率。
mosaic = Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic)
# 创建一个 RandomPerspective 变换对象,用于对图像进行随机透视变换。
# 参数包括旋转角度 degrees 、平移比例 translate 、缩放比例 scale 、剪切角度 shear 和透视变换比例 perspective 。
# 如果 stretch 为 False ,则在透视变换前应用 LetterBox 变换,将图像调整为目标尺寸。
affine = RandomPerspective(
degrees=hyp.degrees,
translate=hyp.translate,
scale=hyp.scale,
shear=hyp.shear,
perspective=hyp.perspective,
pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
)
# 使用 Compose 将 Mosaic 和 RandomPerspective 变换组合成一个预变换流程。
pre_transform = Compose([mosaic, affine])
# 检查 hyp.copy_paste_mode 是否为 "flip" 。 hyp.copy_paste_mode 是一个超参数,决定了Copy-Paste变换的模式。
if hyp.copy_paste_mode == "flip":
# 如果 hyp.copy_paste_mode 为 "flip" ,则在 pre_transform 的第1个位置插入一个 CopyPaste 变换。
# CopyPaste(p=hyp.copy_paste, mode=hyp.copy_paste_mode) 创建一个 CopyPaste 对象,其中 p 是Copy-Paste变换的概率, mode 是变换模式。
# 使用 insert 方法将 CopyPaste 对象插入到 pre_transform 的第1个位置。
pre_transform.insert(1, CopyPaste(p=hyp.copy_paste, mode=hyp.copy_paste_mode))
# 如果 hyp.copy_paste_mode 不是 "flip" ,则将 CopyPaste 变换添加到 pre_transform 的末尾。
# 创建一个 CopyPaste 对象,其中:
# dataset 是数据集对象。
# pre_transform 是一个 Compose 对象,包含 Mosaic 和 RandomPerspective 变换。
# p 是Copy-Paste变换的概率。
# mode 是变换模式。
# 使用 append 方法将 CopyPaste 对象添加到 pre_transform 的末尾。
else:
pre_transform.append(
CopyPaste(
dataset,
pre_transform=Compose([Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic), affine]),
p=hyp.copy_paste,
mode=hyp.copy_paste_mode,
)
)
# 从数据集的配置中获取 flip_idx ,这是一个用于关键点翻转的索引数组。 如果 flip_idx 不存在,则默认为空列表 [] 。 flip_idx 用于定义在水平翻转图像时,关键点如何重新映射。
flip_idx = dataset.data.get("flip_idx", []) # for keypoints augmentation
# 检查数据集是否使用关键点检测。如果使用关键点检测,则需要进一步检查 flip_idx 的配置。
if dataset.use_keypoints:
# 从数据集的配置中获取 kpt_shape ,这是关键点的形状信息。 如果 kpt_shape 不存在,则默认为 None 。
kpt_shape = dataset.data.get("kpt_shape", None)
# 如果 flip_idx 为空且 hyp.fliplr (水平翻转的概率)大于0,则将 hyp.fliplr 设置为0。
# 这是因为没有定义 flip_idx 时,无法正确处理关键点的水平翻转。
# 使用 LOGGER.warning 记录警告信息,提示用户 flip_idx 未定义,因此禁用了水平翻转增强。
if len(flip_idx) == 0 and hyp.fliplr > 0.0:
hyp.fliplr = 0.0
LOGGER.warning("No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'") # data.yaml 中未定义“flip_idx”数组,设置增强“fliplr=0.0”。
# 如果 flip_idx 不为空,但其长度与 kpt_shape[0] (关键点的数量)不匹配,则抛出 ValueError 。
# 这是因为 flip_idx 的长度必须与关键点的数量一致,以确保每个关键点都有对应的翻转索引。
# 错误信息中包含了 flip_idx 和 kpt_shape[0] 的值,方便用户调试。
elif flip_idx and (len(flip_idx) != kpt_shape[0]):
raise ValueError(f"data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}") # data.yaml flip_idx={flip_idx} 长度必须等于 kpt_shape[0]={kpt_shape[0]}。
# 使用 Compose 类将多个变换方法组合成一个完整的变换流程。 Compose 类允许将多个变换方法按顺序应用到输入数据上。
return Compose(
[
# 将之前定义的 pre_transform (包含 Mosaic 和 RandomPerspective 等变换)作为第一个变换方法。 pre_transform 已经包含了如Mosaic混合、随机透视变换等操作。
pre_transform,
# 添加 MixUp 变换,它通过混合两张图像及其标签来生成新的训练样本。 dataset 是数据集对象, pre_transform 是应用在混合图像之前的变换流程, p 是MixUp变换的概率。
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
# 添加 CutMix 变换,它通过将一张图像的某个区域替换为另一张图像的相应区域来生成新的训练样本。 dataset 是数据集对象, pre_transform 是应用在混合图像之前的变换流程, p 是CutMix变换的概率。
CutMix(dataset, pre_transform=pre_transform, p=hyp.cutmix),
# 添加 Albumentations 变换,这是一个强大的图像增强库,可以应用多种复杂的图像变换。 p=1.0 表示这些变换总是被应用。
Albumentations(p=1.0),
# 添加 RandomHSV 变换,用于随机调整图像的色调(H)、饱和度(S)和明度(V)。 hgain 、 sgain 和 vgain 分别控制色调、饱和度和明度的调整范围。
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
# 添加 RandomFlip 变换,用于随机垂直翻转图像。 direction="vertical" 指定翻转方向为垂直, p 是垂直翻转的概率。
RandomFlip(direction="vertical", p=hyp.flipud),
# 添加 RandomFlip 变换,用于随机水平翻转图像。 direction="horizontal" 指定翻转方向为水平, p 是水平翻转的概率。 flip_idx 是关键点翻转索引,用于在水平翻转时正确调整关键点的位置。
RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx),
]
# 关闭 Compose 的括号,表示变换流程定义完成。 注释 # transforms 说明这是一个变换流程。
) # transforms
# 这段代码定义了一个名为 v8_transforms 的函数,用于创建一个综合的数据增强流程。该流程结合了多种数据增强技术,如Mosaic、MixUp、CutMix、随机透视变换、HSV调整和随机翻转等,以提高模型的泛化能力和鲁棒性。函数根据传入的参数和配置动态构建变换流程,适用于目标检测和关键点检测任务。
8.class ToTensor:
# NOTE: keep this class for backward compatibility 注意:保留此类是为了向后兼容。
# 这段代码定义了一个名为 ToTensor 的类,用于将图像从NumPy数组转换为PyTorch张量,并进行一些预处理操作,如改变数据类型和归一化。
# 定义了一个名为 ToTensor 的类,用于将图像从NumPy数组转换为PyTorch张量。
class ToTensor:
# 将图像从 NumPy 数组转换为 PyTorch 张量。
# 此类旨在作为转换流程的一部分,例如 T.Compose([LetterBox(size), ToTensor()])。
# 方法:
# __call__:将张量转换应用于输入图像。
# 备注:
# 输入图像预计为 BGR 格式,形状为 (H, W, C)。
# 输出张量将为 RGB 格式,形状为 (C, H, W),并归一化到 [0, 1]。
"""
Convert an image from a numpy array to a PyTorch tensor.
This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]).
Attributes:
half (bool): If True, converts the image to half precision (float16).
Methods:
__call__: Apply the tensor conversion to an input image.
Examples:
>>> transform = ToTensor(half=True)
>>> img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
>>> tensor_img = transform(img)
>>> print(tensor_img.shape, tensor_img.dtype)
torch.Size([3, 640, 640]) torch.float16
Notes:
The input image is expected to be in BGR format with shape (H, W, C).
The output tensor will be in RGB format with shape (C, H, W), normalized to [0, 1].
"""
# 定义了类的初始化方法 __init__ ,接收一个参数 1.half ,默认值为 False 。 half 参数用于指定是否将图像数据转换为半精度浮点数( torch.float16 )。
def __init__(self, half=False):
# 初始化 ToTensor 对象,用于将图像转换为 PyTorch 张量。
# 此类旨在用作 Ultralytics YOLO 框架中图像预处理的转换流程的一部分。它将 NumPy 数组或 PIL 图像转换为 PyTorch 张量,并可选择进行半精度 (float16) 转换。
"""
Initialize the ToTensor object for converting images to PyTorch tensors.
This class is designed to be used as part of a transformation pipeline for image preprocessing in the
Ultralytics YOLO framework. It converts numpy arrays or PIL Images to PyTorch tensors, with an option
for half-precision (float16) conversion.
Args:
half (bool): If True, converts the tensor to half precision (float16).
Examples:
>>> transform = ToTensor(half=True)
>>> img = np.random.rand(640, 640, 3)
>>> tensor_img = transform(img)
>>> print(tensor_img.dtype)
torch.float16
"""
# 调用父类的初始化方法(虽然这里没有显式继承父类,但这是Python类的常规写法)。
super().__init__()
# 将 half 参数赋值给类的属性 self.half 。
self.half = half
# 定义了类的 __call__ 方法,使得类的实例可以像函数一样被调用。 参数 1.im 是一个NumPy数组,表示图像数据。
def __call__(self, im):
# 将图像从 NumPy 数组转换为 PyTorch 张量。
# 此方法将输入图像从 NumPy 数组转换为 PyTorch 张量,并应用可选的半精度转换和归一化。图像从 HWC 格式转置为 CHW 格式,颜色通道从 BGR 反转为 RGB。
# 参数:
# im (numpy.ndarray):输入图像为 NumPy 数组,形状为 (H, W, C),按 BGR 顺序排列。
# 返回:
# (torch.Tensor):转换后的图像为 float32 或 float16 类型的 PyTorch 张量,归一化到 [0, 1] 区间,形状为 (C, H, W),按 RGB 顺序排列。
"""
Transform an image from a numpy array to a PyTorch tensor.
This method converts the input image from a numpy array to a PyTorch tensor, applying optional
half-precision conversion and normalization. The image is transposed from HWC to CHW format and
the color channels are reversed from BGR to RGB.
Args:
im (numpy.ndarray): Input image as a numpy array with shape (H, W, C) in BGR order.
Returns:
(torch.Tensor): The transformed image as a PyTorch tensor in float32 or float16, normalized
to [0, 1] with shape (C, H, W) in RGB order.
Examples:
>>> transform = ToTensor(half=True)
>>> img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
>>> tensor_img = transform(img)
>>> print(tensor_img.shape, tensor_img.dtype)
torch.Size([3, 640, 640]) torch.float16
"""
# 使用 im.transpose((2, 0, 1)) 将图像从HWC(高度、宽度、通道)格式转换为CHW(通道、高度、宽度)格式。 使用 [::-1] 将BGR格式的图像转换为RGB格式。 使用 np.ascontiguousarray 确保数组在内存中是连续的,这对于后续的张量操作是有益的。
im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
# 使用 torch.from_numpy 将NumPy数组转换为PyTorch张量。
im = torch.from_numpy(im) # to torch
# 如果 self.half 为 True ,则将张量的数据类型转换为半精度浮点数( torch.float16 )。 否则,将张量的数据类型转换为单精度浮点数( torch.float32 )。
im = im.half() if self.half else im.float() # uint8 to fp16/32
# 将图像数据归一化到0.0到1.0的范围,通过将每个像素值除以255.0。
im /= 255.0 # 0-255 to 0.0-1.0
# 返回转换后的PyTorch张量。
return im
# 这段代码定义了一个名为 ToTensor 的类,用于将图像从NumPy数组转换为PyTorch张量,并进行以下预处理操作: 将图像从HWC格式转换为CHW格式。 将BGR格式的图像转换为RGB格式。 确保数组在内存中是连续的。 将NumPy数组转换为PyTorch张量。 根据 half 参数,将张量的数据类型转换为半精度或单精度浮点数。 将图像数据归一化到0.0到1.0的范围。这种设计使得 ToTensor 类可以方便地集成到数据预处理流程中,适用于各种深度学习任务,特别是在使用PyTorch框架时。