plotting.py
ultralytics\utils\plotting.py
目录
4.def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
8.def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
9.def plot_tune_results(csv_file="tune_results.csv"):
10.def output_to_target(output, max_det=300):
11.def output_to_rotated_target(output, max_det=300):
12.def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
1.所需的库和模块
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import math
import warnings
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from PIL import __version__ as pil_version
from ultralytics.utils import LOGGER, TryExcept, ops, plt_settings, threaded
from ultralytics.utils.checks import check_font, check_version, is_ascii
from ultralytics.utils.files import increment_path
2.class Colors:
# 这段代码定义了一个名为 Colors 的类,用于管理和转换颜色代码。
# 类定义和初始化。
# Colors 类有一个初始化方法 __init__ ,它在创建类的实例时被自动调用。
class Colors:
# Ultralytics 调色板 https://docs.ultralytics.com/reference/utils/plotting/#ultralytics.utils.plotting.Colors。
# 此类提供使用 Ultralytics 调色板的方法,包括将十六进制颜色代码转换为 RGB 值。
# Ultralytics 调色板
# | 索引 | 颜色 | HEX | RGB |
# 姿势调色板
# | 索引 | 颜色 | HEX | RGB |
"""
Ultralytics color palette https://docs.ultralytics.com/reference/utils/plotting/#ultralytics.utils.plotting.Colors.
This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
RGB values.
Attributes:
palette (list of tuple): List of RGB color values.
n (int): The number of colors in the palette.
pose_palette (np.ndarray): A specific color palette array with dtype np.uint8.
## Ultralytics Color Palette
| Index | Color | HEX | RGB |
|-------|-------------------------------------------------------------------|-----------|-------------------|
| 0 | <i class="fa-solid fa-square fa-2xl" style="color: #042aff;"></i> | `#042aff` | (4, 42, 255) |
| 1 | <i class="fa-solid fa-square fa-2xl" style="color: #0bdbeb;"></i> | `#0bdbeb` | (11, 219, 235) |
| 2 | <i class="fa-solid fa-square fa-2xl" style="color: #f3f3f3;"></i> | `#f3f3f3` | (243, 243, 243) |
| 3 | <i class="fa-solid fa-square fa-2xl" style="color: #00dfb7;"></i> | `#00dfb7` | (0, 223, 183) |
| 4 | <i class="fa-solid fa-square fa-2xl" style="color: #111f68;"></i> | `#111f68` | (17, 31, 104) |
| 5 | <i class="fa-solid fa-square fa-2xl" style="color: #ff6fdd;"></i> | `#ff6fdd` | (255, 111, 221) |
| 6 | <i class="fa-solid fa-square fa-2xl" style="color: #ff444f;"></i> | `#ff444f` | (255, 68, 79) |
| 7 | <i class="fa-solid fa-square fa-2xl" style="color: #cced00;"></i> | `#cced00` | (204, 237, 0) |
| 8 | <i class="fa-solid fa-square fa-2xl" style="color: #00f344;"></i> | `#00f344` | (0, 243, 68) |
| 9 | <i class="fa-solid fa-square fa-2xl" style="color: #bd00ff;"></i> | `#bd00ff` | (189, 0, 255) |
| 10 | <i class="fa-solid fa-square fa-2xl" style="color: #00b4ff;"></i> | `#00b4ff` | (0, 180, 255) |
| 11 | <i class="fa-solid fa-square fa-2xl" style="color: #dd00ba;"></i> | `#dd00ba` | (221, 0, 186) |
| 12 | <i class="fa-solid fa-square fa-2xl" style="color: #00ffff;"></i> | `#00ffff` | (0, 255, 255) |
| 13 | <i class="fa-solid fa-square fa-2xl" style="color: #26c000;"></i> | `#26c000` | (38, 192, 0) |
| 14 | <i class="fa-solid fa-square fa-2xl" style="color: #01ffb3;"></i> | `#01ffb3` | (1, 255, 179) |
| 15 | <i class="fa-solid fa-square fa-2xl" style="color: #7d24ff;"></i> | `#7d24ff` | (125, 36, 255) |
| 16 | <i class="fa-solid fa-square fa-2xl" style="color: #7b0068;"></i> | `#7b0068` | (123, 0, 104) |
| 17 | <i class="fa-solid fa-square fa-2xl" style="color: #ff1b6c;"></i> | `#ff1b6c` | (255, 27, 108) |
| 18 | <i class="fa-solid fa-square fa-2xl" style="color: #fc6d2f;"></i> | `#fc6d2f` | (252, 109, 47) |
| 19 | <i class="fa-solid fa-square fa-2xl" style="color: #a2ff0b;"></i> | `#a2ff0b` | (162, 255, 11) |
## Pose Color Palette
| Index | Color | HEX | RGB |
|-------|-------------------------------------------------------------------|-----------|-------------------|
| 0 | <i class="fa-solid fa-square fa-2xl" style="color: #ff8000;"></i> | `#ff8000` | (255, 128, 0) |
| 1 | <i class="fa-solid fa-square fa-2xl" style="color: #ff9933;"></i> | `#ff9933` | (255, 153, 51) |
| 2 | <i class="fa-solid fa-square fa-2xl" style="color: #ffb266;"></i> | `#ffb266` | (255, 178, 102) |
| 3 | <i class="fa-solid fa-square fa-2xl" style="color: #e6e600;"></i> | `#e6e600` | (230, 230, 0) |
| 4 | <i class="fa-solid fa-square fa-2xl" style="color: #ff99ff;"></i> | `#ff99ff` | (255, 153, 255) |
| 5 | <i class="fa-solid fa-square fa-2xl" style="color: #99ccff;"></i> | `#99ccff` | (153, 204, 255) |
| 6 | <i class="fa-solid fa-square fa-2xl" style="color: #ff66ff;"></i> | `#ff66ff` | (255, 102, 255) |
| 7 | <i class="fa-solid fa-square fa-2xl" style="color: #ff33ff;"></i> | `#ff33ff` | (255, 51, 255) |
| 8 | <i class="fa-solid fa-square fa-2xl" style="color: #66b2ff;"></i> | `#66b2ff` | (102, 178, 255) |
| 9 | <i class="fa-solid fa-square fa-2xl" style="color: #3399ff;"></i> | `#3399ff` | (51, 153, 255) |
| 10 | <i class="fa-solid fa-square fa-2xl" style="color: #ff9999;"></i> | `#ff9999` | (255, 153, 153) |
| 11 | <i class="fa-solid fa-square fa-2xl" style="color: #ff6666;"></i> | `#ff6666` | (255, 102, 102) |
| 12 | <i class="fa-solid fa-square fa-2xl" style="color: #ff3333;"></i> | `#ff3333` | (255, 51, 51) |
| 13 | <i class="fa-solid fa-square fa-2xl" style="color: #99ff99;"></i> | `#99ff99` | (153, 255, 153) |
| 14 | <i class="fa-solid fa-square fa-2xl" style="color: #66ff66;"></i> | `#66ff66` | (102, 255, 102) |
| 15 | <i class="fa-solid fa-square fa-2xl" style="color: #33ff33;"></i> | `#33ff33` | (51, 255, 51) |
| 16 | <i class="fa-solid fa-square fa-2xl" style="color: #00ff00;"></i> | `#00ff00` | (0, 255, 0) |
| 17 | <i class="fa-solid fa-square fa-2xl" style="color: #0000ff;"></i> | `#0000ff` | (0, 0, 255) |
| 18 | <i class="fa-solid fa-square fa-2xl" style="color: #ff0000;"></i> | `#ff0000` | (255, 0, 0) |
| 19 | <i class="fa-solid fa-square fa-2xl" style="color: #ffffff;"></i> | `#ffffff` | (255, 255, 255) |
!!! note "Ultralytics Brand Colors"
For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand). Please use the official Ultralytics colors for all marketing materials.
"""
# 颜色初始化。
def __init__(self):
# 将颜色初始化为 hex = matplotlib.colors.TABLEAU_COLORS.values()。
"""Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
# hexs 包含了一系列的十六进制颜色代码。
hexs = (
"042AFF",
"0BDBEB",
"F3F3F3",
"00DFB7",
"111F68",
"FF6FDD",
"FF444F",
"CCED00",
"00F344",
"BD00FF",
"00B4FF",
"DD00BA",
"00FFFF",
"26C000",
"01FFB3",
"7D24FF",
"7B0068",
"FF1B6C",
"FC6D2F",
"A2FF0B",
)
# self.palette 是一个列表,包含将 hexs 中的颜色代码转换成 RGB 值后的结果。
self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
# self.n 是颜色板中颜色的数量。
self.n = len(self.palette)
# self.pose_palette 是一个 NumPy 数组,包含了用于姿态估计的特定颜色代码。
self.pose_palette = np.array(
[
[255, 128, 0],
[255, 153, 51],
[255, 178, 102],
[230, 230, 0],
[255, 153, 255],
[153, 204, 255],
[255, 102, 255],
[255, 51, 255],
[102, 178, 255],
[51, 153, 255],
[255, 153, 153],
[255, 102, 102],
[255, 51, 51],
[153, 255, 153],
[102, 255, 102],
[51, 255, 51],
[0, 255, 0],
[0, 0, 255],
[255, 0, 0],
[255, 255, 255],
],
dtype=np.uint8,
)
# 调用方法。这个方法是类的一个特殊方法,通过 __call__ 实现,使得类的实例可以像函数一样被调用。
# 定义了一个名为 __call__ 的方法,这个方法接受三个参数。
# 1.self :指向类的实例的引用。
# 2.i :一个整数,代表要转换的颜色的索引或代码。
# 3.bgr :一个布尔值,默认为 False ,用来指定返回的颜色格式是BGR还是RGB。
def __call__(self, i, bgr=False):
"""Converts hex color codes to RGB values."""
# 这一行代码从实例的 palette 属性中获取颜色值。 palette 是一个列表,其中包含了颜色的RGB值。
# int(i) % self.n 计算索引 i 对 self.n 取余的结果,这样可以确保索引 i 在 palette 的范围内,即使传入的 i 值大于 palette 的长度。
c = self.palette[int(i) % self.n]
# 这一行代码根据 bgr 参数的值返回不同的颜色格式。如果 bgr 为 True ,则返回BGR格式的颜色值(即蓝色、绿色、红色),否则返回RGB格式的颜色值(即红色、绿色、蓝色)。
return (c[2], c[1], c[0]) if bgr else c
# 这段代码定义了一个名为 hex2rgb 的静态方法,它的作用是将十六进制颜色代码转换为RGB值。这个方法不需要类的实例就可以被调用,因此被定义为静态方法。
# 是一个装饰器,它指示下面的 hex2rgb 方法是一个静态方法。
@staticmethod
# 定义了一个名为 hex2rgb 的方法,这个方法接受一个参数。
# 1.h :它应该是一个字符串,代表十六进制颜色代码。
def hex2rgb(h):
"""Converts hex color codes to RGB values (i.e. default PIL order)."""
# num = int(x, base=10)
# int() 是 Python 中的一个内置函数,用于将一个整数或字符串转换成一个整数类型。如果转换成功,返回整数类型的结果;如果转换失败,会抛出一个 ValueError 异常。
# 参数 :
# x :要转换的值,可以是一个字符串或者一个数字。
# base :(可选)进制基数,用于字符串转换时指定数字的进制,默认为 10 进制。
# 返回值 :
# 返回一个整数类型的值。
# 注意事项 :
# 如果字符串包含非数字字符(除了首字符为正负号外), int() 函数将抛出 ValueError 。
# 如果传入的是浮点数, int() 函数会直接去掉小数部分,只保留整数部分。
# 如果 base 参数指定了进制基数, x 必须是一个符合该进制表示的有效字符串。
# int() 函数是处理数字和字符串转换时的基础工具,常用于数据类型转换、解析用户输入、文件解析等多种场景。
# 这一行代码是方法的主体,它使用列表推导式来转换十六进制颜色代码为RGB值,并返回一个元组。
# h[1 + i : 1 + i + 2] :这部分代码从字符串 h 中提取每两个字符,分别对应十六进制颜色代码的红色、绿色和蓝色分量。字符串索引 1 是因为十六进制颜色代码通常以 # 开头,所以实际的颜色代码从第二个字符开始。 i 的值是 0 、 2 、 4 ,分别对应红色、绿色、蓝色分量在十六进制字符串中的位置。
# int(..., 16) :这部分代码将提取的每两个字符的十六进制字符串转换为十进制整数。
# tuple(...) :这部分代码将列表推导式的结果转换为元组,因为RGB值通常以元组的形式表示。
return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
# 这个类可以用于任何需要颜色管理的应用程序,特别是在图像处理和计算机视觉领域。通过提供一个颜色板和转换方法,它简化了颜色代码的使用和管理。
3.class Annotator:
# Annotator 对象,用于在画布上添加注释。
class Annotator:
# Ultralytics Annotator 用于训练/验证马赛克和 JPG 以及预测注释。
"""
Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
Attributes:
im (Image.Image or numpy array): The image to annotate.
pil (bool): Whether to use PIL or cv2 for drawing annotations.
font (ImageFont.truetype or ImageFont.load_default): Font used for text annotations.
lw (float): Line width for drawing.
skeleton (List[List[int]]): Skeleton structure for keypoints.
limb_color (List[int]): Color palette for limbs.
kpt_color (List[int]): Color palette for keypoints.
"""
# 这段代码是 Annotator 类的构造函数 __init__ 的定义,它用于初始化类的实例。
# 1.im : 要注释的图像,可以是PIL图像对象或NumPy数组。
# 2.line_width : 绘制线条的宽度,默认为 None 。
# 3.font_size : 字体大小,默认为 None 。
# 4.font : 字体文件路径,默认为 "Arial.ttf" 。
# 5.pil : 布尔值,指示是否强制使用PIL库,默认为 False 。
# 6.example : 用于检查是否包含非ASCII字符的字符串,默认为 "abc" 。
def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):
# 使用图像和线宽以及关键点和肢体的调色板初始化 Annotator 类。
"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
# 非ASCII检查。检查 example 字符串是否包含非ASCII字符。
# def is_ascii(s) -> bool: -> 用于检查传入的参数 s 是否只包含 ASCII 字符。如果字符串 s 中的所有字符都是 ASCII 字符, all() 函数将返回 True ;如果至少有一个字符不是 ASCII 字符, all() 函数将返回 False 。 -> return all(ord(c) < 128 for c in s)
non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
# 输入图像类型检查。检查输入图像是否为PIL图像对象。
input_is_pil = isinstance(im, Image.Image)
# PIL使用决策。根据是否需要非ASCII支持、输入是否为PIL图像或用户指定,决定是否使用PIL库。
self.pil = pil or non_ascii or input_is_pil
# 线条宽度设置。如果没有指定线条宽度,则根据图像尺寸自动计算。
self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)
# PIL库使用。
# 如果 self.pil 为 True ,则使用PIL库进行图像处理。
if self.pil: # use PIL
# Image.fromarray(arr, mode=None)
# Image.fromarray 是 Python Imaging Library (PIL) 中的一个函数,它用于将一个数组(通常是 NumPy 数组)转换成 PIL 图像对象。这个函数非常有用,因为它允许你在图像处理和分析中轻松地在数组和图像对象之间转换。
# 参数说明 :
# arr :一个数组对象,通常是 NumPy 数组。这个数组包含了图像的像素数据。
# mode :(可选)一个字符串,指定图像模式。如果未指定, fromarray 将根据数组的形状和数据类型自动选择一个模式。常见的模式包括 "L"(灰度图),"RGB"(真彩色图像),"RGBA"(带有透明度通道的真彩色图像)等。
# 返回值 :
# 返回一个 PIL 图像对象,你可以使用 PIL 提供的方法对这个对象进行操作,比如旋转、缩放、裁剪等。
# 确保图像是PIL图像对象。
self.im = im if input_is_pil else Image.fromarray(im)
# 创建一个用于绘制的PIL ImageDraw 对象。
self.draw = ImageDraw.Draw(self.im)
# 字体设置。
# 尝试加载指定的字体文件,如果失败则使用默认字体。
try:
# 根据是否需要非ASCII支持选择字体文件。
# def check_font(font="Arial.ttf"):
# -> 检查指定的字体文件是否存在于用户配置目录中。如果文件存在,则返回该文件路径。如果找到了匹配的系统字体路径,返回第一个匹配项。返回下载的字体文件路径。
# -> return file / return matches[0] / return file
font = check_font("Arial.Unicode.ttf" if non_ascii else font)
# 如果没有指定字体大小,则根据图像尺寸自动计算。
size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
# 创建PIL字体对象。
self.font = ImageFont.truetype(str(font), size)
# 异常处理。
# 这一行代码捕获了在尝试加载自定义字体时可能发生的任何异常。如果在加载字体时遇到问题(例如,字体文件不存在或损坏),则执行以下代码。
except Exception:
# 如果在加载自定义字体时发生异常,这段代码将使用PIL库的默认字体作为后备方案。这样可以确保即使自定义字体加载失败,程序也能继续运行,而不会因为字体问题而崩溃。
self.font = ImageFont.load_default()
# Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string)
# 版本特定的弃用修复。
# 这一行代码检查PIL库的版本是否为9.2.0或更高版本。
if check_version(pil_version, "9.2.0"):
# 从PIL 9.2.0版本开始, ImageFont 对象的 getsize 方法被弃用,并被 getbbox 方法替代。
# getbbox 方法返回一个四元组,包含文本的边界框信息,其中第三个和第四个元素分别是文本的宽度和高度。
# 这个lambda函数是一个快捷方式,用于获取文本的 宽度 和 高度 ,与旧版本的 getsize 方法功能相同。
self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height
# OpenCV库使用。
# 如果不使用PIL库,则使用OpenCV库进行图像处理。
else: # use cv2
# 确保图像数据是连续的。
assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." # 图像不连续。将 np.ascontiguousarray(im) 应用于 Annotator 输入图像。
# 确保图像是可写的。
self.im = im if im.flags.writeable else im.copy()
# 设置字体宽度。
self.tf = max(self.lw - 1, 1) # font thickness
# 设置字体比例。
self.sf = self.lw / 3 # font scale
# Pose
# 人体骨骼连接定义。定义了人体骨骼的连接方式,是一个包含关键点索引的列表。
self.skeleton = [
[16, 14],
[14, 12],
[17, 15],
[15, 13],
[12, 13],
[6, 12],
[7, 13],
[6, 7],
[6, 8],
[7, 9],
[8, 10],
[9, 11],
[2, 3],
[1, 2],
[1, 3],
[2, 4],
[3, 5],
[4, 6],
[5, 7],
]
# 颜色设置。
# self.limb_color 和 self.kpt_color 定义了骨骼和关键点的颜色,使用了一个名为 colors.pose_palette 的颜色调色板。
self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
# self.dark_colors 和 self.light_colors 定义了在不同背景下使用的颜色集合,以确保关键点和骨骼的可见性。
self.dark_colors = {
(235, 219, 11),
(243, 243, 243),
(183, 223, 0),
(221, 111, 255),
(0, 237, 204),
(68, 243, 0),
(255, 255, 0),
(179, 255, 1),
(11, 255, 162),
}
self.light_colors = {
(255, 42, 4),
(79, 68, 255),
(255, 0, 189),
(255, 180, 0),
(186, 0, 221),
(0, 192, 38),
(255, 36, 125),
(104, 0, 123),
(108, 27, 255),
(47, 109, 252),
(104, 31, 17),
}
# 这个构造函数为 Annotator 类提供了初始化设置,包括图像处理库的选择、字体和颜色的设置,以及人体骨骼连接的定义。这些设置为后续的图像注释提供了必要的基础。
# 这段代码定义了一个名为 get_txt_color 的方法,它是 Annotator 类的一个成员函数。这个方法的目的是确定在给定背景颜色下,文本应该使用的颜色,以确保文本的可见性。
# 参数说明。
# 1.color : 背景颜色,以RGB元组的形式给出,默认值为 (128, 128, 128) 。
# 2.txt_color : 如果背景颜色不属于预定义的深色或浅色集合,将使用的文本颜色,默认值为 (255, 255, 255) ,即白色。
def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):
# 根据背景颜色分配文本颜色。
"""Assign text color based on background color."""
# 检查背景颜色是否为深色。如果传入的背景颜色 color 在 self.dark_colors 集合中,这意味着背景颜色较深。
if color in self.dark_colors:
# 返回深色背景的文本颜色。如果背景颜色是深色,返回一个浅色的文本颜色 (104, 31, 17) ,以确保文本在深色背景上可见。
return 104, 31, 17
# 检查背景颜色是否为浅色。如果传入的背景颜色 color 在 self.light_colors 集合中,这意味着背景颜色较浅。
elif color in self.light_colors:
# 返回浅色背景的文本颜色。如果背景颜色是浅色,返回一个深色的文本颜色 (255, 255, 255) ,即白色,以确保文本在浅色背景上可见。
return 255, 255, 255
# 返回自定义文本颜色。如果背景颜色既不在 self.dark_colors 也不在 self.light_colors 集合中,返回用户自定义的文本颜色 txt_color 。
else:
return txt_color
# 这个方法通过简单的条件判断,根据背景颜色的深浅自动选择合适的文本颜色,以提高文本的可读性。这是一种常见的设计模式,用于确保在不同背景下文本的可见性。
# 这段代码定义了 Annotator 类中的 circle_label 方法,该方法用于在图像上绘制一个圆形标签,并在圆内显示文本。
# 1.box : 一个四元组,表示要绘制圆形标签的边界框的坐标 (x1, y1, x2, y2) 。
# 2.label : 要显示在圆内的文本标签,默认为空字符串。
# 3.color : 圆形标签的背景颜色,默认为 (128, 128, 128) 。
# 4.txt_color : 文本的颜色,默认为 (255, 255, 255) ,即白色。
# 5.margin : 文本与圆边缘的间距,默认为 2 。
def circle_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=2):
# 绘制一个带有背景圆的标签,背景圆位于给定边界框的中心。
"""
Draws a label with a background circle centered within a given bounding box.
Args:
box (tuple): The bounding box coordinates (x1, y1, x2, y2).
label (str): The text label to be displayed.
color (tuple, optional): The background color of the rectangle (B, G, R).
txt_color (tuple, optional): The color of the text (R, G, B).
margin (int, optional): The margin between the text and the rectangle border.
"""
# If label have more than 3 characters, skip other characters, due to circle size 如果标签有 3 个以上字符,则由于圆圈大小而跳过其他字符。
# 标签字符数限制。
# 如果 label 的长度超过3个字符,只保留前3个字符,并打印一条消息说明超出的字符将被忽略。
if len(label) > 3:
print(
f"Length of label is {len(label)}, initial 3 label characters will be considered for circle annotation!" # 标签的长度为 {len(label)},最初的 3 个标签字符将被视为圆形注释!
)
label = label[:3]
# Calculate the center of the box 计算边界框的中心。
# 计算圆心坐标。使用边界框的坐标计算圆心的 x 和 y 坐标。
x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
# (cv2.Size(width, height), baseline) = cv2.getTextSize(text, fontFace, fontScale, thickness)
# cv2.getTextSize 函数是 OpenCV 库中的一个函数,它用于计算给定文本的尺寸(宽度和高度)。这个函数在绘制文本之前非常有用,因为它可以帮助你确定文本的尺寸,从而可以进行适当的定位和布局。
# 参数 :
# text :要测量的文本字符串。
# fontFace :字体类型,可以是 OpenCV 预定义的字体,如 cv2.FONT_HERSHEY_SIMPLEX 、 cv2.FONT_HERSHEY_PLAIN 等。
# fontScale :字体缩放因子,用于调整字体大小。
# thickness :字体线条的厚度。
# 返回值 :
# cv2.Size(width, height) :一个 cv2.Size 对象,包含文本的宽度和高度。
# baseline :文本基线(即文本底部到基线的距离),这个值可以用来确定文本的位置,以确保基线对齐。
# Get the text size 获取文本大小。
# 获取文本尺寸。使用 cv2.getTextSize 函数获取文本的尺寸,这个尺寸将用于确定圆的大小。
text_size = cv2.getTextSize(str(label), cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0]
# Calculate the required radius to fit the text with the margin 计算所需的半径以使文本与边距相符。
# 计算所需半径。根据文本尺寸和间距计算所需的圆的半径。
required_radius = int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin
# cv2.circle(img, center, radius, color[, thickness[, lineType[, shift]]])
# cv2.circle() 是 OpenCV 库中的一个函数,用于在图像上绘制圆形。
# 参数说明 :
# img : 要绘制圆形的图像,必须是8位或浮点数的单通道或三通道图像。
# center : 圆形的中心点坐标,以 (x, y) 形式表示。
# radius : 圆形的半径。
# color : 圆形的颜色,以 BGR 格式表示(对于彩色图像),如果图像是灰度的,则为灰度值。
# thickness : 线条的厚度。如果是正数,则表示实线;如果是 0 或 -1 ,则表示填充圆。
# lineType : 线条的类型,可以是以下值之一 : cv2.LINE_4 :4连通性线(默认)。 cv2.LINE_8 :8连通性线。 cv2.LINE_AA :抗锯齿线。
# shift : 圆心坐标和半径的缩放因子,可以是0或1。如果 shift=1 ,则坐标以像素为单位;如果 shift=0 ,则坐标以图像尺寸的比例为单位。
# 返回值 :
# 该函数没有返回值,它直接在输入图像 img 上进行绘制。
# Draw the circle with the required radius 按照所需半径绘制圆。
# 绘制圆形。使用 cv2.circle 函数在计算出的圆心位置绘制一个填充的圆形,半径为 required_radius ,颜色为 color 。
cv2.circle(self.im, (x_center, y_center), required_radius, color, -1)
# Calculate the position for the text 计算文本的位置。
# 计算文本位置。计算文本在圆内的位置,确保文本水平居中,垂直居中。
text_x = x_center - text_size[0] // 2
text_y = y_center + text_size[1] // 2
# cv2.putText(img, text, org, fontFace, fontScale, color[, thickness[, lineType[, bottomLeftOrigin]]])
# cv2.putText() 是 OpenCV 库中的一个函数,用于在图像上绘制文本。
# 参数说明 :
# img : 要绘制文本的图像。
# text : 要绘制的文本字符串。
# org : 文本的起始坐标(左下角),以 (x, y) 形式表示。
# fontFace : 字体类型,OpenCV 提供了几种字体如 cv2.FONT_HERSHEY_SIMPLEX 、 cv2.FONT_HERSHEY_PLAIN 等。
# fontScale : 字体缩放因子,用于调整字体大小。
# color : 文本的颜色,以 BGR 格式表示。
# thickness : 文本线条的厚度,默认值为 1。
# lineType : 线条的类型,默认值为 cv2.LINE_8 ,可以是 cv2.LINE_4 、 cv2.LINE_8 或 cv2.LINE_AA 。
# bottomLeftOrigin : 可选参数,如果设置为 True ,则图像数据的原点在左下角,否则原点在左上角,默认为 False 。
# 返回值 :该函数没有返回值,它直接在输入图像 img 上进行绘制。
# Draw the text 绘制文本。
# 绘制文本。使用 cv2.putText 函数在圆内绘制文本,使用之前计算的文本位置、字体、缩放比例、颜色和线型。
cv2.putText(
self.im,
str(label),
(text_x, text_y),
cv2.FONT_HERSHEY_SIMPLEX,
self.sf - 0.15,
# def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):
# -> 确定在给定背景颜色下,文本应该使用的颜色,以确保文本的可见性。如果背景颜色是深色,返回一个浅色的文本颜色。如果背景颜色是浅色,返回一个深色的文本颜色。返回自定义文本颜色。
# -> return 104, 31, 17 / return 255, 255, 255 / return txt_color
self.get_txt_color(color, txt_color),
self.tf,
lineType=cv2.LINE_AA,
)
# 这个方法结合了图像处理和文本显示的功能,通过在图像上绘制圆形标签并添加文本,提供了一种视觉上吸引人的方式来标注图像中的特定区域。这种方法在需要突出显示图像中的特定对象或区域时非常有用,例如在目标检测或人体姿态估计的应用中。
# 这段代码定义了 Annotator 类中的 text_label 方法,该方法用于在图像上绘制带有背景矩形的文本标签。
# 1.box : 一个四元组,表示要绘制文本标签的边界框的坐标 (x1, y1, x2, y2) 。
# 2.label : 要显示的文本标签,默认为空字符串。
# 3.color : 背景矩形的颜色,默认为 (128, 128, 128) 。
# 4.txt_color : 文本的颜色,默认为 (255, 255, 255) ,即白色。
# 5.margin : 背景矩形边缘与文本之间的间距,默认为 5 。
def text_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=5):
# 绘制一个带有背景矩形的标签,该矩形位于给定边界框的中心。
"""
Draws a label with a background rectangle centered within a given bounding box.
Args:
box (tuple): The bounding box coordinates (x1, y1, x2, y2).
label (str): The text label to be displayed.
color (tuple, optional): The background color of the rectangle (B, G, R).
txt_color (tuple, optional): The color of the text (R, G, B).
margin (int, optional): The margin between the text and the rectangle border.
"""
# Calculate the center of the bounding box
# 计算边界框中心。使用边界框的坐标计算文本应该居中的 x 和 y 坐标。
x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
# Get the size of the text
# 获取文本尺寸。使用 cv2.getTextSize 函数获取文本的尺寸,这个尺寸将用于确定背景矩形的大小。
text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.1, self.tf)[0]
# Calculate the top-left corner of the text (to center it)
# 计算文本位置。计算文本在图像上的左上角坐标,使得文本在水平方向上居中。
text_x = x_center - text_size[0] // 2
text_y = y_center + text_size[1] // 2
# Calculate the coordinates of the background rectangle
# 计算背景矩形坐标。根据文本尺寸和间距计算背景矩形的左上角和右下角坐标。
rect_x1 = text_x - margin
rect_y1 = text_y - text_size[1] - margin
rect_x2 = text_x + text_size[0] + margin
rect_y2 = text_y + margin
# Draw the background rectangle
# 绘制背景矩形。使用 cv2.rectangle 函数在计算出的坐标位置绘制一个填充的矩形,颜色为 color 。
cv2.rectangle(self.im, (rect_x1, rect_y1), (rect_x2, rect_y2), color, -1)
# Draw the text on top of the rectangle
# 绘制文本。使用 cv2.putText 函数在背景矩形上方绘制文本,使用之前计算的文本位置、字体、缩放比例、颜色和线型。
cv2.putText(
self.im,
label,
(text_x, text_y),
cv2.FONT_HERSHEY_SIMPLEX,
self.sf - 0.1,
self.get_txt_color(color, txt_color),
self.tf,
lineType=cv2.LINE_AA,
)
# 这个方法结合了图像处理和文本显示的功能,通过在图像上绘制带有背景矩形的文本标签,提供了一种视觉上吸引人的方式来标注图像中的特定区域。这种方法在需要突出显示图像中的特定对象或区域时非常有用,例如在目标检测或图像标注的应用中。
# 通过调整 color 、 txt_color 和 margin 参数,可以改变标签的外观和样式。
# 这段代码定义了 Annotator 类中的 box_label 方法,该方法用于在图像上绘制带有文本标签的边界框。这个方法支持绘制普通矩形边界框和旋转的多边形边界框。
# 1.box : 边界框的坐标,可以是四元组(对于矩形)或多边形顶点列表(对于旋转的多边形)。
# 2.label : 要显示在边界框上的文本标签,默认为空字符串。
# 3.color : 边界框的颜色,默认为 (128, 128, 128) 。
# 4.txt_color : 文本的颜色,默认为 (255, 255, 255) ,即白色。
# 5.rotated : 布尔值,指示是否绘制旋转的多边形边界框,默认为 False 。
def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
# 使用标签在图像上绘制边界框。
"""
Draws a bounding box to image with label.
Args:
box (tuple): The bounding box coordinates (x1, y1, x2, y2).
label (str): The text label to be displayed.
color (tuple, optional): The background color of the rectangle (B, G, R).
txt_color (tuple, optional): The color of the text (R, G, B).
rotated (bool, optional): Variable used to check if task is OBB
"""
# 调用 get_txt_color 方法来确定基于背景颜色 color 的文本颜色 txt_color 。
# def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):
# -> 确定在给定背景颜色下,文本应该使用的颜色,以确保文本的可见性。如果背景颜色是深色,返回一个浅色的文本颜色。如果背景颜色是浅色,返回一个深色的文本颜色。返回自定义文本颜色。
# -> return 104, 31, 17 / return 255, 255, 255 / return txt_color
txt_color = self.get_txt_color(color, txt_color)
# 检查 box 是否是 PyTorch 张量,如果是,则将其转换为列表。
if isinstance(box, torch.Tensor):
box = box.tolist()
# 检查是否使用 PIL 库或者标签 label 包含非ASCII字符。
# def is_ascii(s) -> bool: -> 用于检查传入的参数 s 是否只包含 ASCII 字符。如果字符串 s 中的所有字符都是 ASCII 字符, all() 函数将返回 True ;如果至少有一个字符不是 ASCII 字符, all() 函数将返回 False 。 -> return all(ord(c) < 128 for c in s)
if self.pil or not is_ascii(label):
# 如果边界框是旋转的,使用 PIL 的 draw.polygon 方法绘制多边形边界框。
if rotated:
p1 = box[0]
# ImageDraw.polygon(outline, fill=None, width=0)
# 在Python的PIL库(现在更多的是其分支Pillow)中, ImageDraw.polygon() 函数用于在图像上绘制一个多边形。
# 参数说明 :
# outline : 多边形顶点的序列,以坐标对 (x, y) 的形式给出。这些坐标对可以是元组列表、列表列表或者其他可迭代对象。
# fill : 填充颜色(可选)。如果为 None ,则多边形不会填充。
# width : 多边形边界线的宽度,默认为0,即不绘制边界线。
# 返回值 :该函数没有返回值,它直接在 ImageDraw 对象关联的图像上进行绘制。
# ImageDraw.polygon() 函数是Pillow库中绘制多边形的常用方法,适用于需要绘制复杂形状或自定义多边形图形的场景。
self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) # PIL requires tuple box
# 如果边界框不是旋转的,使用 PIL 的 draw.rectangle 方法绘制矩形边界框。
else:
p1 = (box[0], box[1])
self.draw.rectangle(box, width=self.lw, outline=color) # box
# 如果提供了标签 label ,则继续执行。
if label:
# 使用 PIL 的 font.getsize 方法获取文本的宽度和高度。
w, h = self.font.getsize(label) # text width, height
# 确定文本是否适合放在边界框的上方。
outside = p1[1] >= h # label fits outside box
# 检查文本是否超出图像右侧边界,如果是,则调整文本的起始 x 坐标。
if p1[0] > self.im.size[0] - w: # size is (w, h), check if label extend beyond right side of image
p1 = self.im.size[0] - w, p1[1]
# 在边界框旁边绘制一个填充的矩形,用于放置文本。
self.draw.rectangle(
(p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1),
fill=color,
)
# self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
# 在矩形内绘制文本。
self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font)
# 这段代码是 box_label 方法中使用 OpenCV 库绘制边界框和文本的部分。
# 表示如果不用 PIL 库或者标签包含非ASCII字符,则使用 OpenCV 库进行绘制。
else: # cv2
# 检查边界框是否为旋转的多边形。
if rotated:
# 如果边界框是旋转的,将边界框的第一个顶点( box[0] )的坐标转换为整数列表。
p1 = [int(b) for b in box[0]]
# cv2.polylines(img, pts, isClosed, color[, thickness[, lineType[, shift]]])
# cv2.polylines() 函数是 OpenCV 库中用于绘制一系列多边形线条的函数。这个函数可以绘制一个或多个多边形的轮廓,可以是开放的也可以是闭合的。
# 参数 :
# img :目标图像,必须是 8 位或浮点型(单通道或3通道)。
# pts :一个列表,其中每个元素是一个 NumPy 数组,代表多边形的顶点。如果 isClosed 参数为 True ,则每个数组代表一个闭合的多边形;如果为 False ,则代表一系列线段。
# isClosed :布尔值,指示每个多边形是否闭合。
# color :线条的颜色,在 BGR 格式下。对于灰度图像,只需提供一个值。
# thickness :可选参数,指定线条的粗细。正数表示线条的粗细, cv2.FILLED 表示填充多边形。
# lineType :可选参数,指定线条的类型,可以是以下值之一 : cv2.LINE_4 - 4连通性线(8位宽度)。 cv2.LINE_8 - 8连通性线(8位宽度)。 cv2.LINE_AA - 抗锯齿线。
# shift :可选参数,表示顶点坐标的小数点位数。如果顶点坐标是浮点数,这个参数可以用来确定坐标的小数位数。
# 返回值 :该函数没有返回值,它直接在输入图像 img 上进行绘制。
# 说明 :
# pts 参数是一个列表,列表中的每个元素是一个 NumPy 数组,代表多边形的顶点。每个数组的形状应该是 (n, 1, 2) ,其中 n 是顶点的数量。
# 如果 isClosed 参数为 True ,则每个多边形的最后一个顶点会与第一个顶点连接起来。
# thickness 参数如果为负值(例如 -1 ),则多边形会被填充。
# lineType 参数影响线条的绘制方式, cv2.LINE_AA 提供了抗锯齿的效果,使得线条更加平滑。
# shift 参数用于处理浮点数坐标,如果顶点坐标是浮点数,可以通过这个参数来确定坐标的小数位数,然后向下取整到最近的整数。
# 这个函数常用于在图像上绘制形状,如在计算机视觉任务中标记检测到的对象轮廓。
# 使用 OpenCV 的 cv2.polylines 函数在图像上绘制多边形边界框。 np.asarray(box, dtype=int) 将边界框坐标转换为 NumPy 数组,并且坐标类型为整数,因为 OpenCV 要求坐标为整数。
cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw) # cv2 requires nparray box
# 如果边界框不是旋转的,即是一个矩形。
else:
# 将边界框的左上角和右下角坐标转换为整数,并分别存储在 p1 和 p2 中。
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
# cv2.rectangle(img, pt1, pt2, color[, thickness[, lineType[, shift]]])
# cv2.rectangle() 函数是 OpenCV 库中用于在图像上绘制矩形的函数。
# 参数 :
# img :目标图像,必须是 8 位或浮点型(单通道或3通道)。
# pt1 和 pt2 :矩形的两个对角点的坐标。这两个参数是包含两个点坐标的元组或数组,每个点的坐标由 (x, y) 组成。 pt1 是矩形的一个角点, pt2 是对角点。
# color :矩形边框的颜色,在 BGR 格式下。对于灰度图像,只需提供一个值。
# thickness :可选参数,指定边框的粗细。正数表示边框的粗细, cv2.FILLED (或 -1 )表示填充矩形内部。
# lineType :可选参数,指定线条的类型,可以是以下值之一 : cv2.LINE_4 - 4连通性线(8位宽度)。 cv2.LINE_8 - 8连通性线(8位宽度)。 cv2.LINE_AA - 抗锯齿线。
# shift :可选参数,表示坐标值的小数点位数。如果坐标是浮点数,这个参数可以用来确定坐标的小数位数。
# 返回值 :
# 该函数没有返回值,它直接在输入图像 img 上进行绘制。
# 说明 :
# pt1 和 pt2 参数定义了矩形的两个对角点,OpenCV 会自动计算出矩形的其他两个角点。
# 如果 thickness 参数为正数,则绘制的是一个只有边框的矩形;如果为 cv2.FILLED 或 -1 ,则矩形内部会被填充。
# lineType 参数影响线条的绘制方式, cv2.LINE_AA 提供了抗锯齿的效果,使得线条更加平滑。
# shift 参数用于处理浮点数坐标,如果坐标是浮点数,可以通过这个参数来确定坐标的小数位数,然后向下取整到最近的整数。
# 这个函数常用于在图像上标记区域,例如在计算机视觉任务中标记检测到的对象的边界框。
# 使用 OpenCV 的 cv2.rectangle 函数在图像上绘制矩形边界框。
cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
# 如果提供了文本标签,则继续执行。
if label:
# 使用 OpenCV 的 cv2.getTextSize 函数获取文本的宽度和高度。
w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height
# 在文本高度上增加3个像素,以便在文本周围添加一些填充。
h += 3 # add pixels to pad text
# 确定文本是否适合放在边界框的上方。
outside = p1[1] >= h # label fits outside box
# 检查文本是否超出图像右侧边界。
if p1[0] > self.im.shape[1] - w: # shape is (h, w), check if label extend beyond right side of image
# 如果文本超出右侧边界,则调整文本的起始 x 坐标。
p1 = self.im.shape[1] - w, p1[1]
# 计算文本背景矩形的右下角坐标。
p2 = p1[0] + w, p1[1] - h if outside else p1[1] + h
# 使用 OpenCV 的 cv2.rectangle 函数绘制一个填充的矩形,用于放置文本。
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
# 使用 OpenCV 的 cv2.putText 函数在填充的矩形上绘制文本。文本的位置根据 outside 变量的值进行调整,以确保文本不会与矩形边缘重叠。 fontScale 设置为 self.sf ,文本颜色为 txt_color ,厚度为 self.tf 。
cv2.putText(
self.im,
label,
(p1[0], p1[1] - 2 if outside else p1[1] + h - 1),
0,
self.sf,
txt_color,
thickness=self.tf,
lineType=cv2.LINE_AA,
)
# 绘制文本标签。在两种情况下(PIL和OpenCV),都会计算文本尺寸,并根据边界框的位置和文本尺寸调整文本位置,以确保文本不会超出图像边界,并且能够正确地显示在边界框的上方或下方。
# 这个方法提供了灵活性,可以处理不同类型的边界框和文本标签,同时确保在不同情况下都能正确地绘制。通过调整 color 、 txt_color 和 rotated 参数,可以改变边界框的外观和样式。
# 这段代码定义了 Annotator 类中的 masks 方法,该方法用于将多个掩码(masks)应用到图像上,并用指定的颜色进行着色。
# 这是 masks 方法的声明,它接受以下参数。
# 1.masks :一个包含多个掩码的张量。
# 2.colors :一个包含对应掩码颜色的列表或张量。
# 3.im_gpu :一个存储在GPU上的图像张量。
# 4.alpha :掩码的透明度,默认为0.5。
# 5.retina_masks :一个布尔值,指示是否对掩码进行缩放,默认为False。
def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
# 在图像上绘制掩膜
"""
Plot masks on image.
Args:
masks (tensor): Predicted masks on cuda, shape: [n, h, w]
colors (List[List[Int]]): Colors for predicted masks, [[r, g, b] * n]
im_gpu (tensor): Image is in cuda, shape: [3, h, w], range: [0, 1]
alpha (float): Mask transparency: 0.0 fully transparent, 1.0 opaque
retina_masks (bool): Whether to use high resolution masks or not. Defaults to False.
"""
# 如果图像是PIL格式,将其转换为numpy数组。
if self.pil:
# Convert to numpy first
# numpy.asarray(a, dtype=None, order=None)
# np.asarray() 是 NumPy 库中的一个函数,它将输入的数据转换为一个 NumPy 数组。
# 参数 :
# a :要转换为数组的对象。可以是列表、元组、另一个 NumPy 数组,或者是任何其他复合数据类型。
# dtype :可选参数,指定数组元素的期望数据类型。如果未指定,则 NumPy 会根据输入数据自动推断数据类型。
# order :可选参数,指定数组的内存布局。可以是 :
# 'C' :C 风格的行主序(默认)。
# 'F' :Fortran 风格的列主序。
# 'A' :任意顺序(如果 a 是数组,则保持原来的顺序)。
# 'K' :保持输入数组的形状(如果 a 是数组,则保持原来的顺序)。
# 返回值 :
# 返回一个与输入数据具有相同形状和数据类型的 NumPy 数组。
# 说明 :
# np.asarray() 函数主要用于确保输入数据是一个 NumPy 数组。如果输入数据已经是一个 NumPy 数组,并且不需要改变其数据类型或顺序,则该函数不会复制数据,而是直接返回输入数组。
# 如果输入数据不是 NumPy 数组, np.asarray() 会创建一个新的数组,并复制数据。
# dtype 参数允许你指定数组的数据类型,这在处理不同数据类型之间的转换时非常有用。
# order 参数允许你控制数组的内存布局,这在处理大型数组或进行性能优化时可能很重要。
# 将图像转换为numpy数组并复制。
self.im = np.asarray(self.im).copy()
# 如果没有掩码,直接将GPU上的图像数据复制到 self.im 。
if len(masks) == 0:
# 将GPU上的图像数据转换为CPU上的numpy数组,并乘以255(从0-1范围转换到0-255)。
self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
# 如果图像和掩码不在同一个设备上,将图像移动到掩码所在的设备。
if im_gpu.device != masks.device:
# 将图像移动到掩码所在的设备。
im_gpu = im_gpu.to(masks.device)
# 将颜色数组转换为PyTorch张量,并除以255(从0-255范围转换到0-1)。
colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)
# 为颜色张量增加两个维度,以匹配掩码的形状。
colors = colors[:, None, None] # shape(n,1,1,3)
# 为掩码张量增加一个维度,以匹配颜色张量的形状。
masks = masks.unsqueeze(3) # shape(n,h,w,1)
# 计算掩码的颜色,通过将掩码与颜色和透明度相乘。
masks_color = masks * (colors * alpha) # shape(n,h,w,3)
# numpy.cumprod(a, axis=None, dtype=None, out=None, *, where=None)
# np.cumprod() 是 NumPy 库中的一个函数,它用于计算输入数组中元素的累积乘积。累积乘积是指从数组的第一个元素开始,逐个将元素与之前所有元素的乘积相乘的结果。
# 参数说明 :
# a : 输入数组。
# axis : 沿哪个轴计算累积乘积。如果为 None ,则计算扁平化的数组的累积乘积。
# dtype : 输出数组的数据类型。如果为 None ,则数据类型与输入数组相同。
# out : 用于存放结果的输出数组。
# where : 一个布尔数组,与输入数组形状相同,用来选择性地计算累积乘积。
# 返回值 :
# 返回一个新的数组,其形状与输入数组相同,包含了输入数组沿指定轴的累积乘积。
# 计算累积的反透明度掩码。
inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
# 计算所有掩码颜色的最大值。
mcs = masks_color.max(dim=0).values # shape(n,h,w,3)
# np.flipud(m)
# np.flipud() 是 NumPy 库中的一个函数,用于将数组沿垂直轴(即沿着行)翻转。这个函数的名称来源于 "flip up-down",意味着它会将数组的行顺序颠倒,使得最后一行变成第一行,倒数第二行变成第二行,以此类推。
# 参数 :
# m ( ndarray ) : 需要被翻转的数组。
# 返回值 :
# 返回一个新的数组,它是输入数组 m 沿垂直轴翻转的结果。
# 将图像通道翻转。
im_gpu = im_gpu.flip(dims=[0]) # flip channel
# 调整图像张量的形状,并确保数据是连续的。
im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
# 将反透明度掩码应用到图像上,并加上最大颜色掩码。
im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
# 将图像乘以255,转换回0-255的范围。
im_mask = im_gpu * 255
# 将图像张量转换为NumPy数组,并确保数据类型为字节。
im_mask_np = im_mask.byte().cpu().numpy()
# 如果是视网膜掩码,则直接使用掩码,否则对掩码进行缩放。
# def scale_image(masks, im0_shape, ratio_pad=None): -> 它用于将图像中的掩码(masks)从一个尺寸( im1_shape )缩放到另一个尺寸( im0_shape )。返回缩放后的掩码。 -> return masks
self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
# 如果图像是PIL格式,将其转换回PIL格式并更新。
if self.pil:
# Convert im back to PIL and update draw
# 将numpy数组转换回PIL图像。
self.fromarray(self.im)
# 这段代码的主要功能是在图像上绘制掩码,并根据掩码的透明度和颜色进行着色。代码中使用了PyTorch的张量操作来处理图像和掩码,以及PIL库来处理图像的格式转换。
# 这段代码定义了一个名为 kpts 的方法,用于在图像上绘制关键点(keypoints)和关键点之间的连线,通常用于表示人体姿态估计的结果。
# 1.kpts :接受关键点数组。
# 2.shape :图像尺寸。
# 3.radius :圆点半径。
# 4.kpt_line :是否绘制关键点连线。
# 5.conf_thres :置信度阈值。
# 6.kpt_color :关键点颜色。
def kpts(self, kpts, shape=(640, 640), radius=None, kpt_line=True, conf_thres=0.25, kpt_color=None):
# 在图像上绘制关键点。
"""
Plot keypoints on the image.
Args:
kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
shape (tuple, optional): Image shape (h, w). Defaults to (640, 640).
radius (int, optional): Keypoint radius. Defaults to 5.
kpt_line (bool, optional): Draw lines between keypoints. Defaults to True.
conf_thres (float, optional): Confidence threshold. Defaults to 0.25.
kpt_color (tuple, optional): Keypoint color (B, G, R). Defaults to None.
Note:
- `kpt_line=True` currently only supports human pose plotting.
- Modifies self.im in-place.
- If self.pil is True, converts image to numpy array and back to PIL.
"""
# 如果没有指定半径,则使用成员变量 self.lw 作为半径。
radius = radius if radius is not None else self.lw
# 如果图像是PIL格式,将其转换为numpy数组。
if self.pil:
# Convert to numpy first
# 将图像转换为numpy数组并复制。
self.im = np.asarray(self.im).copy()
# 获取关键点数组的形状, nkpt 是关键点的数量, ndim 是每个关键点的维度。
nkpt, ndim = kpts.shape
# 判断是否是人体姿态估计的关键点,通常人体姿态估计有17个关键点,每个关键点有2D或3D坐标。
is_pose = nkpt == 17 and ndim in {2, 3}
# 只有是人体姿态估计时,才考虑绘制关键点连线。
kpt_line &= is_pose # `kpt_line=True` for now only supports human pose plotting
# 遍历每个关键点。
for i, k in enumerate(kpts):
# 确定关键点的颜色。
color_k = kpt_color or (self.kpt_color[i].tolist() if is_pose else colors(i))
# 获取关键点的x和y坐标。
x_coord, y_coord = k[0], k[1]
# 排除位于图像边缘的关键点,即坐标值不是图像尺寸的整数倍,这可以防止坐标超出图像边界。
if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
# 如果关键点有3个维度,即包含置信度。
if len(k) == 3:
# 获取关键点的置信度。
conf = k[2]
# 如果置信度低于阈值,则跳过该关键点。
if conf < conf_thres:
continue
# 在图像上绘制关键点。
cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA)
# 如果需要绘制关键点连线。
if kpt_line:
# 获取关键点的维度。
ndim = kpts.shape[-1]
# 遍历骨骼结构, self.skeleton 定义了关键点之间的连接关系。
for i, sk in enumerate(self.skeleton):
# 获取连线的起点。
pos1 = (int(kpts[(sk[0] - 1), 0]), int(kpts[(sk[0] - 1), 1]))
# 获取连线的终点。
pos2 = (int(kpts[(sk[1] - 1), 0]), int(kpts[(sk[1] - 1), 1]))
# 如果关键点有3个维度。
if ndim == 3:
# 获取起点的置信度。
conf1 = kpts[(sk[0] - 1), 2]
# 获取终点的置信度。
conf2 = kpts[(sk[1] - 1), 2]
# 如果任一端点的置信度低于阈值,则跳过该连线。
if conf1 < conf_thres or conf2 < conf_thres:
continue
# 确保起点在图像范围内。
if pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0:
continue
# 确保终点在图像范围内。
if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0:
continue
# 在图像上绘制关键点连线。
cv2.line(
self.im,
pos1,
pos2,
kpt_color or self.limb_color[i].tolist(),
thickness=int(np.ceil(self.lw / 2)),
lineType=cv2.LINE_AA,
)
# 如果图像是PIL格式。
if self.pil:
# Convert im back to PIL and update draw
# 将numpy数组转换回PIL图像并更新。
self.fromarray(self.im)
# 这个方法主要用于在图像上绘制关键点和关键点连线,通常用于展示人体姿态估计的结果。代码中使用了OpenCV库来绘制圆点和线条。
# 这段代码定义了一个名为 rectangle 的方法,它是用于在图像上绘制矩形的。
# 函数定义。
# 1.self :方法的第一个参数,指向类的一个实例。
# 2.xy :一个四元组,表示矩形的左上角和右下角的坐标,形式为 (x1, y1, x2, y2) 。
# 3.fill :可选参数,指定用于填充矩形内部的颜色或画笔。如果为 None ,则矩形内部不会被填充。
# 4.outline :可选参数,指定用于绘制矩形边框的颜色或画笔。如果为 None ,则矩形不会有边框。
# 5.width :可选参数,指定矩形边框的宽度,默认为 1 。
def rectangle(self, xy, fill=None, outline=None, width=1):
# 向图像添加矩形(仅限 PIL)。
"""Add rectangle to image (PIL-only)."""
# 这行代码调用了PIL图像对象的 draw 属性中的 rectangle 方法,实际上是在图像上绘制矩形。 self.draw 是一个 ImageDraw 对象,它提供了绘制形状的方法。
self.draw.rectangle(xy, fill, outline, width)
# 这段代码定义了一个名为 text 的方法,用于在图像上添加文本。这个方法可以与PIL(Python Imaging Library)或OpenCV库一起使用。
# 1.self :方法的第一个参数,指向类的一个实例。
# 2.xy :一个元组,表示文本开始绘制的坐标,形式为 (x, y) 。
# 3.text :要绘制的文本字符串。
# 4.txt_color :可选参数,指定文本颜色,默认为白色 (255, 255, 255) 。
# 5.anchor :可选参数,指定文本锚点的位置,默认为 "top"。
# 6.box_style :可选参数,指定是否以带框的样式绘制文本,默认为 False 。
def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False):
# 使用 PIL 或 cv2 向图像添加文本。
"""Adds text to an image using PIL or cv2."""
# 这段代码是 text 方法中的一部分,它处理文本锚点为 "bottom" 的情况。
# 如果锚点是 "bottom",则根据字体大小调整 y 坐标,使文本从字体底部开始绘制。
if anchor == "bottom": # start y from font bottom
# 这行代码调用 self.font.getsize(text) 方法来获取文本的宽度和高度。 self.font 是一个字体对象,它提供了 getsize 方法来测量给定文本的尺寸。 w 和 h 分别存储文本的宽度和高度。
w, h = self.font.getsize(text) # text width, height
# 这行代码调整文本的 y 坐标。由于锚点是 "bottom",所以需要将 y 坐标向下移动文本的高度减去 1 像素,以确保文本从字体的底部开始绘制。这样做是为了确保文本的底部与 xy 指定的位置对齐。
xy[1] += 1 - h
# 这部分代码的作用是,当文本的锚点设置为 "bottom" 时,计算文本的尺寸,并相应地调整 y 坐标,使得文本的底部与指定的位置对齐。
# 这段代码是 text 方法中的一部分,它处理使用 PIL(Python Imaging Library)库在图像上绘制文本的情况。
# 这是一个条件判断语句,检查是否应该使用 PIL 库来处理图像。如果 self.pil 为 True ,则执行以下代码块。
if self.pil:
# 这是一个条件判断语句,检查是否需要在文本周围绘制一个边框(即 box_style 参数为 True )。
if box_style:
# 这行代码调用 self.font.getsize(text) 方法来获取文本的宽度和高度。 self.font 是一个 PIL Font 对象,它提供了 getsize 方法来测量给定文本的尺寸。 w 和 h 分别存储文本的宽度和高度。
w, h = self.font.getsize(text)
# 如果 box_style 为 True ,这行代码使用 self.draw.rectangle 方法在文本周围绘制一个矩形框。
# 矩形框的左上角坐标是 xy ,右下角坐标是 (xy[0] + w + 1, xy[1] + h + 1) ,以确保文本完全被框包围。 fill=txt_color 参数设置矩形框的填充颜色。
self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
# Using `txt_color` for background and draw fg with white color
# 这行代码将文本颜色设置为白色,因为背景色已经设置为 txt_color ,所以文本颜色改为白色以确保可读性。
txt_color = (255, 255, 255)
# 这是一个条件判断语句,检查文本中是否包含换行符。
if "\n" in text:
# 如果文本中包含换行符,这行代码将文本分割成多行。
lines = text.split("\n")
# 这行代码再次获取文本的高度,但这次是为了计算每行文本的高度,因为文本被分割成了多行。
_, h = self.font.getsize(text)
# 这行代码遍历每一行文本。
for line in lines:
# 对于每一行文本,这行代码使用 self.draw.text 方法在图像上绘制文本。 xy 是文本的起始坐标, line 是当前行的文本, fill=txt_color 设置文本颜色, font=self.font 设置字体。
self.draw.text(xy, line, fill=txt_color, font=self.font)
# 在绘制每一行文本后,这行代码更新 y 坐标,使其向下移动一行文本的高度,为下一行文本的绘制做准备。
xy[1] += h
# 如果文本中不包含换行符,这个 else 代码块将被执行。
else:
# 这行代码使用 self.draw.text 方法在图像上绘制单行文本。参数与上面的相同。
self.draw.text(xy, text, fill=txt_color, font=self.font)
# 总结来说,这部分代码的作用是,当使用 PIL 库时,根据是否需要绘制文本框和文本中是否包含换行符,来在图像上绘制文本。如果需要绘制文本框,则先绘制一个矩形框,并将文本颜色设置为白色以确保可读性。如果文本包含换行符,则逐行绘制文本;如果不包含,则直接绘制整个文本。
# 这段代码是 text 方法中的一部分,它处理使用 OpenCV 库在图像上绘制文本的情况。
# 这个 else 代码块与前面的 if self.pil 对应,表示如果不使用 PIL 库,则使用 OpenCV 库来处理图像。
else:
# 这是一个条件判断语句,检查是否需要在文本周围绘制一个边框(即 box_style 参数为 True )。
if box_style:
# 如果 box_style 为 True ,这行代码使用 OpenCV 的 getTextSize 函数来获取文本的宽度和高度。 text 是要绘制的文本, 0 是字体类型(这里使用默认字体), fontScale=self.sf 是字体缩放比例, thickness=self.tf 是字体厚度。
w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height
# 这行代码在文本高度上增加 3 个像素,为文本添加一些填充。
h += 3 # add pixels to pad text
# 这行代码检查文本的 y 坐标是否大于文本高度,以确定文本标签是否适合在边框外面。
outside = xy[1] >= h # label fits outside box
# 这行代码计算矩形的对角点坐标。如果 outside 为 True ,则矩形的下边沿在文本下方;否则,在文本上方。
p2 = xy[0] + w, xy[1] - h if outside else xy[1] + h
# 这行代码使用 OpenCV 的 rectangle 函数在文本周围绘制一个矩形框。 self.im 是图像, xy 是矩形的左上角坐标, p2 是矩形的右下角坐标, txt_color 是填充颜色, -1 表示填充矩形, cv2.LINE_AA 表示抗锯齿线条。
cv2.rectangle(self.im, xy, p2, txt_color, -1, cv2.LINE_AA) # filled
# Using `txt_color` for background and draw fg with white color
# 这行代码将文本颜色设置为白色,因为背景色已经设置为 txt_color ,所以文本颜色改为白色以确保可读性。
txt_color = (255, 255, 255)
# 这行代码使用 OpenCV 的 putText 函数在图像上绘制文本。 self.im 是图像, text 是要绘制的文本, xy 是文本的起始坐标, 0 是字体类型(这里使用默认字体), self.sf 是字体缩放比例, txt_color 是文本颜色, thickness=self.tf 是字体厚度, cv2.LINE_AA 表示抗锯齿线条。
cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)
# 总结来说,这部分代码的作用是,当使用 OpenCV 库时,根据是否需要绘制文本框来在图像上绘制文本。如果需要绘制文本框,则先绘制一个矩形框,并将文本颜色设置为白色以确保可读性。然后,使用 putText 函数在图像上绘制文本。
# 这个方法提供了一个灵活的方式来在图像上添加文本,支持不同的库和样式选项。
# 这段代码定义了一个名为 fromarray 的方法,它用于更新类实例中的图像数据。这个方法可以接受一个 NumPy 数组,并将其转换为 PIL 图像对象,然后使用这个图像对象进行绘制操作。
# 1.self :方法的第一个参数,指向类的一个实例。
# 2.im :要转换的 NumPy 数组或 PIL 图像对象。
def fromarray(self, im):
# 从 numpy 数组更新 self.im。
"""Update self.im from a numpy array."""
# 这行代码检查传入的 im 参数是否已经是一个 PIL Image.Image 对象。如果是,它直接将 im 赋值给 self.im 。如果不是,它使用 Image.fromarray 方法将 NumPy 数组转换为 PIL 图像对象,并将结果赋值给 self.im 。
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
# 这行代码创建一个 ImageDraw.Draw 对象,它允许在 self.im 图像上进行绘制操作,并将这个绘制对象赋值给 self.draw 。
self.draw = ImageDraw.Draw(self.im)
# 这个方法的主要作用是将一个 NumPy 数组或一个 PIL 图像对象转换为 PIL 图像对象,并更新类实例中的 self.im 属性。同时,它还创建了一个 ImageDraw.Draw 对象,用于后续的图像绘制操作。这使得类实例可以在 PIL 图像上进行各种绘制操作,如绘制文本、线条、矩形等。
# 这段代码定义了一个名为 result 的方法,它用于返回经过注释或处理后的图像,作为 NumPy 数组。
# 1.self 方法的第一个参数,指向类的一个实例。
def result(self):
# 以数组形式返回带注释的图像。
"""Return annotated image as array."""
# 这行代码调用 NumPy 库中的 asarray 函数,将 self.im (一个 PIL Image.Image 对象)转换为一个 NumPy 数组。
# np.asarray 函数确保输入的数据被转换为一个 NumPy 数组,如果输入数据已经是一个 NumPy 数组,它将直接返回这个数组而不进行复制。
# 方法返回这个数组,允许调用者获取经过处理的图像数据。
return np.asarray(self.im)
# 这个方法的主要作用是提供一个方便的方式来获取经过类方法处理后的图像数据,使其可以以 NumPy 数组的形式被进一步处理或分析。这在图像处理和计算机视觉任务中非常有用,因为 NumPy 数组是这些领域中常用的数据格式。
# 这段代码定义了一个名为 show 的方法,它用于显示经过注释或处理后的图像。
# 1.self :方法的第一个参数,指向类的一个实例。
# 2.title :可选参数,用于指定显示图像时的窗口标题。
def show(self, title=None):
# 显示带注释的图像。
"""Show the annotated image."""
# np.asarray(self.im) 将 self.im (一个 PIL Image.Image 对象)转换为一个 NumPy 数组。
# [..., ::-1] 这是一个高级索引操作,用于反转 NumPy 数组的最后一个维度。在图像处理中,这通常用于将颜色通道从 RGB 转换为 BGR,因为 PIL 使用 RGB 格式,而 OpenCV 使用 BGR 格式。这一步确保了如果图像之后要用于 OpenCV 函数,颜色通道是正确的。
# Image.fromarray(...) 将上述 NumPy 数组再转换回一个 PIL Image.Image 对象。
# .show(title) 调用 PIL Image.Image 对象的 show 方法来显示图像。如果提供了 title 参数,它将被用作显示图像的窗口标题。
Image.fromarray(np.asarray(self.im)[..., ::-1]).show(title)
# 这个方法的主要作用是提供一个简单的方式直接在屏幕上显示经过类方法处理后的图像。通过将图像数据在 RGB 和 BGR 格式之间转换,它确保了图像在不同的图像处理库之间可以无缝地使用。这个方法对于快速检查图像处理结果非常有用。
# 这段代码定义了一个名为 save 的方法,它用于将经过注释或处理后的图像保存到文件。
# 1.self :方法的第一个参数,指向类的一个实例。
# 2.filename :可选参数,用于指定保存图像的文件名,默认为 "image.jpg" 。
def save(self, filename="image.jpg"):
# 将带注释的图像保存为“文件名”。
"""Save the annotated image to 'filename'."""
# np.asarray(self.im) 将 self.im (一个 PIL Image.Image 对象)转换为一个 NumPy 数组。这是因为 OpenCV 的 imwrite 函数需要一个 NumPy 数组作为输入。
# cv2.imwrite(filename, ...) 调用 OpenCV 的 imwrite 函数,将图像保存到指定的文件名 filename 。这个函数接受文件名和图像数据(作为一个 NumPy 数组),并将图像数据保存到磁盘上的文件中。
cv2.imwrite(filename, np.asarray(self.im))
# 这个方法的主要作用是提供一个方便的方式来保存经过类方法处理后的图像到文件系统中。它使用 OpenCV 的 imwrite 函数来执行实际的保存操作,这使得它能够保存多种格式的图像文件,并且非常高效。默认情况下,如果没有指定文件名,它将图像保存为 "image.jpg" 。
# 这段代码定义了一个名为 get_bbox_dimension 的方法,它用于计算并返回边界框(bounding box)的尺寸,包括宽度、高度和面积。
# 1.self :方法的第一个参数,指向类的一个实例。
# 2.bbox :可选参数,表示边界框的坐标。如果未提供,则默认为 None 。
def get_bbox_dimension(self, bbox=None):
# 计算边界框的面积。
"""
Calculate the area of a bounding box.
Args:
bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
Returns:
angle (degree): Degree value of angle between three points
"""
# 这行代码解包 bbox 元组,将边界框的左上角坐标( x_min, y_min )和右下角坐标( x_max, y_max )分别赋值给相应的变量。
x_min, y_min, x_max, y_max = bbox
# 计算边界框的宽度,即右下角的 x 坐标减去左上角的 x 坐标。
width = x_max - x_min
# 计算边界框的高度,即右下角的 y 坐标减去左上角的 y 坐标。
height = y_max - y_min
# 返回边界框的宽度、高度和面积。面积是通过宽度乘以高度计算得到的。
return width, height, width * height
# 这个方法的主要作用是提供边界框的尺寸信息,包括宽度、高度和面积。这些信息在图像处理和计算机视觉任务中非常有用,例如在对象检测和图像分析中评估边界框的大小。
# 如果调用时没有提供 bbox 参数,这个方法将无法正确执行,因为它依赖于 bbox 来计算尺寸。因此,确保在调用 get_bbox_dimension 方法时提供有效的边界框坐标。
# 这段代码定义了一个名为 draw_region 的方法,它用于在图像上绘制一个区域,通常是由一系列点定义的多边形区域。
# 1.self :方法的第一个参数,指向类的一个实例。
# 2.reg_pts :可选参数,表示区域的顶点坐标列表。如果未提供,则默认为 None 。
# 3.color :可选参数,用于绘制区域的颜色,默认为绿色 (0, 255, 0) 。
# 4.thickness :可选参数,用于绘制区域的线条粗细,默认为 5 。
def draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5):
# 绘制区域线。
"""
Draw region line.
Args:
reg_pts (list): Region Points (for line 2 points, for region 4 points)
color (tuple): Region Color value
thickness (int): Region area thickness value
"""
# 这行代码使用 OpenCV 的 polylines 函数在图像 self.im 上绘制多边形区域。 reg_pts 被转换为 NumPy 数组并传递给 polylines 函数, isClosed=True 表示多边形是闭合的, color 参数设置线条颜色, thickness 参数设置线条粗细。
cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)
# Draw small circles at the corner points
# 这行代码遍历 reg_pts 中的每个点,并使用 OpenCV 的 circle 函数在每个顶点处绘制一个小圆圈。圆圈的中心是 (point[0], point[1]) ,半径是 thickness * 2 , color 参数设置圆圈的颜色, -1 表示填充圆圈。
for point in reg_pts:
cv2.circle(self.im, (point[0], point[1]), thickness * 2, color, -1) # -1 fills the circle
# 这个方法的主要作用是在图像上绘制一个由一系列点定义的闭合区域,并在每个顶点处绘制一个小圆圈,以突出显示区域的边界。这种类型的绘制通常用于图像标注任务,如在图像中标记出感兴趣的区域。通过调整 color 和 thickness 参数,可以自定义绘制区域的外观。
# 这段代码定义了一个名为 draw_centroid_and_tracks 的方法,它用于在图像上绘制轨迹的中心点(质心)和轨迹线。
# 1.self :方法的第一个参数,指向类的一个实例。
# 2.track :表示轨迹点的列表,每个元素是一个包含坐标的元组或数组。
# 3.color :可选参数,用于绘制轨迹和中心点的颜色,默认为品红色 (255, 0, 255) 。
# 4.track_thickness :可选参数,用于绘制轨迹的线条粗细,默认为 2 。
def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2):
# 绘制质心点和轨迹。
"""
Draw centroid point and track trails.
Args:
track (list): object tracking points for trails display
color (tuple): tracks line color
track_thickness (int): track line thickness value
"""
# numpy.hstack(tup)
# np.hstack() 是 NumPy 库中的一个函数,用于水平(按列顺序)堆叠数组。
# 参数 :
# tup :一个元组或列表,包含要水平堆叠的数组。这些数组必须有相同的形状,除了第二维(列)。
# 返回值 :
# 返回一个数组,它是输入数组水平堆叠的结果。
# 说明 :
# np.hstack() 函数将多个数组水平(沿着第二维)堆叠起来。这意味着所有输入数组的第一维(行)必须相同,而第二维(列)可以不同。
# 如果输入数组的维度大于2,那么除了第一维和第二维之外,其他维度的大小必须相同。
# 该函数常用于将具有相同行数的多个数组合并为一个更宽的数组。
# 这行代码将轨迹中的所有点水平堆叠起来,转换为 int32 类型,并重新塑形为 (-1, 1, 2) 的形状,以满足 OpenCV polylines 函数的输入要求。
# polylines 函数的输入 pts 参数是一个列表,列表中的每个元素是一个 NumPy 数组,代表多边形的顶点。每个数组的形状应该是 (n, 1, 2) ,其中 n 是顶点的数量。
points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
# 这行代码使用 OpenCV 的 polylines 函数在图像 self.im 上绘制轨迹线。 points 是轨迹点的数组, isClosed=False 表示轨迹线不闭合, color 参数设置线条颜色, thickness 参数设置线条粗细。
cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness)
# 这行代码使用 OpenCV 的 circle 函数在轨迹的最后一个点(即中心点或质心)处绘制一个填充的圆圈。圆圈的中心是 (track[-1][0], track[-1][1]) ,半径是 track_thickness * 2 , color 参数设置圆圈的颜色, -1 表示填充圆圈。
cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)
# 这个方法的主要作用是在图像上绘制轨迹线和轨迹的中心点。轨迹线由一系列点组成,中心点通常是轨迹的最后一个点。通过调整 color 和 track_thickness 参数,可以自定义绘制轨迹和中心点的外观。这种类型的绘制通常用于图像处理和计算机视觉任务中,如目标跟踪和运动分析。
# 这段代码定义了一个名为 queue_counts_display 的方法,它用于在图像上显示一个标签(label),并将其放置在一系列点的中心位置,同时在标签周围绘制一个矩形框。
# 这是函数的定义行。函数名为 queue_counts_display ,它接受四个参数。
# 1.self :方法的第一个参数,指向类的一个实例。
# 2.label :要显示的文本标签。
# 3.points :可选参数,表示一系列点的列表,文本将被放置在这些点的中心位置。如果未提供,则默认为 None 。
# 4.region_color :可选参数,用于绘制矩形框的颜色,默认为白色 (255, 255, 255) 。
# 5.txt_color :可选参数,用于绘制文本的颜色,默认为黑色 (0, 0, 0) 。
def queue_counts_display(self, label, points=None, region_color=(255, 255, 255), txt_color=(0, 0, 0)):
# 在以点为中心的图像上显示队列计数,字体大小和颜色可自定义。
"""
Displays queue counts on an image centered at the points with customizable font size and colors.
Args:
label (str): queue counts label
points (tuple): region points for center point calculation to display text
region_color (RGB): queue region color
txt_color (RGB): text display color
"""
# 这行代码通过列表推导式创建一个包含所有点的 x 坐标的列表。
x_values = [point[0] for point in points]
# 这行代码创建一个包含所有点的 y 坐标的列表。
y_values = [point[1] for point in points]
# 计算所有点的x坐标的平均值,作为矩形框的 中心x 坐标。
center_x = sum(x_values) // len(points)
# 计算所有点的y坐标的平均值,作为矩形框的 中心y 坐标。
center_y = sum(y_values) // len(points)
# 使用 cv2.getTextSize 函数计算文本的尺寸。
text_size = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0]
# 获取文本的宽度。
text_width = text_size[0]
# 获取文本的高度。
text_height = text_size[1]
# 计算矩形框的宽度,比文本宽度多20像素。
rect_width = text_width + 20
# 计算矩形框的高度,比文本高度多20像素。
rect_height = text_height + 20
# 计算矩形框左上角的坐标。
rect_top_left = (center_x - rect_width // 2, center_y - rect_height // 2)
# 计算矩形框右下角的坐标。
rect_bottom_right = (center_x + rect_width // 2, center_y + rect_height // 2)
# 使用 cv2.rectangle 函数在图像上绘制一个填充的矩形框。
cv2.rectangle(self.im, rect_top_left, rect_bottom_right, region_color, -1)
# 计算文本的x坐标,使其居中于矩形框。
text_x = center_x - text_width // 2
# 计算文本的y坐标,使其在矩形框的底部居中。
text_y = center_y + text_height // 2
# Draw text
# 使用 cv2.putText 函数在图像上绘制文本。参数包括图像、文本内容、坐标、字体、字体缩放比例、颜色、厚度和线型。
cv2.putText(
self.im,
label,
(text_x, text_y),
0,
fontScale=self.sf,
color=txt_color,
thickness=self.tf,
lineType=cv2.LINE_AA,
)
# 这个方法的主要作用是在图像上显示一个文本标签,并在标签周围绘制一个矩形框。文本标签被放置在提供的点列表的中心位置,矩形框的尺寸根据文本的尺寸自动调整。通过调整 region_color 和 txt_color 参数,可以自定义矩形框和文本的颜色。
# 这段代码定义了一个名为 display_objects_labels 的函数,它的作用是在图像上显示带有背景色的文本标签。
# 这是函数的定义行。函数接受以下参数 :
# 1.self :类的实例。
# 2.im0 :要在其上绘制文本和矩形的图像。
# 3.text :要显示的文本标签。
# 4.txt_color :文本的颜色。
# 5.bg_color :背景矩形的颜色。
# 6.x_center :文本水平居中的x坐标。
# 7.y_center :文本垂直居中的y坐标。
# 8.margin :文本与矩形边缘之间的间距。
def display_objects_labels(self, im0, text, txt_color, bg_color, x_center, y_center, margin):
# 在停车管理应用中显示边界框标签。
"""
Display the bounding boxes labels in parking management app.
Args:
im0 (ndarray): inference image
text (str): object/class name
txt_color (bgr color): display color for text foreground
bg_color (bgr color): display color for text background
x_center (float): x position center point for bounding box
y_center (float): y position center point for bounding box
margin (int): gap between text and rectangle for better display
"""
# 使用 cv2.getTextSize 函数计算文本的尺寸,返回值是一个包含宽度和高度的元组。
text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]
# 计算文本的x坐标,使其在 x_center 指定的水平线上居中。
text_x = x_center - text_size[0] // 2
# 计算文本的y坐标,使其在 y_center 指定的垂直线上居中,并且稍微偏下(因为文本的基线通常在底部)。
text_y = y_center + text_size[1] // 2
# 计算矩形左上角的x坐标。
rect_x1 = text_x - margin
# 计算矩形左上角的y坐标,确保文本在矩形内部居中,并且在上方留有 margin 间距。
rect_y1 = text_y - text_size[1] - margin
# 计算矩形右下角的x坐标。
rect_x2 = text_x + text_size[0] + margin
# 计算矩形右下角的y坐标,确保文本在矩形内部居中,并且在下方留有 margin 间距。
rect_y2 = text_y + margin
# 使用 cv2.rectangle 函数在图像上绘制一个填充的矩形,其颜色为 bg_color ,坐标由上一步计算得出。
cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)
# 使用 cv2.putText 函数在图像上绘制文本。参数包括图像、文本内容、坐标、字体、字体缩放比例、颜色、厚度和线型。
cv2.putText(im0, text, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)
# 这个函数通常用于在图像上标注检测到的对象,并为其添加一个背景色以突出显示。通过调整 margin 参数,可以控制文本与背景矩形之间的间距。
# 这段代码定义了一个名为 display_analytics 的函数,它的作用是在图像上显示一系列分析结果,每个结果都是一个标签和值的组合。这些文本将被放置在图像的右侧,并伴有背景色。
# 这是函数的定义行。函数接受以下参数。
# 1.self :类的实例。
# 2.im0 :要在其上绘制文本和矩形的图像。
# 3.text :一个字典,包含要显示的标签和值。
# 4.txt_color :文本的颜色。
# 5.bg_color :背景矩形的颜色。
# 6.margin :文本与矩形边缘之间的间距。
def display_analytics(self, im0, text, txt_color, bg_color, margin):
# 显示停车场的总体统计数据。
"""
Display the overall statistics for parking lots.
Args:
im0 (ndarray): inference image
text (dict): labels dictionary
txt_color (bgr color): display color for text foreground
bg_color (bgr color): display color for text background
margin (int): gap between text and rectangle for better display
"""
# 计算水平间隙,为图像宽度的2%。
horizontal_gap = int(im0.shape[1] * 0.02)
# 计算垂直间隙,为图像高度的1%。
vertical_gap = int(im0.shape[0] * 0.01)
# 初始化文本的垂直偏移量。
text_y_offset = 0
# 遍历 text 字典中的每个标签和值。
for label, value in text.items():
# 格式化要显示的文本,包括标签和值。
txt = f"{label}: {value}"
# 使用 cv2.getTextSize 函数计算文本的尺寸。
text_size = cv2.getTextSize(txt, 0, self.sf, self.tf)[0]
# 如果计算出的文本尺寸小于5像素,则设置为(5, 5),以避免绘制过小的矩形。
if text_size[0] < 5 or text_size[1] < 5:
text_size = (5, 5)
# 计算文本的x坐标,使其靠右对齐,并留有间隙。
text_x = im0.shape[1] - text_size[0] - margin * 2 - horizontal_gap
# 计算文本的y坐标,考虑之前的文本高度和间隙。
text_y = text_y_offset + text_size[1] + margin * 2 + vertical_gap
# 计算矩形左上角的x坐标。
rect_x1 = text_x - margin * 2
# 计算矩形左上角的y坐标。
rect_y1 = text_y - text_size[1] - margin * 2
# 计算矩形右下角的x坐标。
rect_x2 = text_x + text_size[0] + margin * 2
# 计算矩形右下角的y坐标。
rect_y2 = text_y + margin * 2
# 使用 cv2.rectangle 函数在图像上绘制一个填充的矩形。
cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)
# 使用 cv2.putText 函数在图像上绘制文本。
cv2.putText(im0, txt, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)
# 更新文本的垂直偏移量,以便下一个文本可以在当前文本的下方绘制。
text_y_offset = rect_y2
# 这个函数通常用于在图像上显示一系列分析结果,例如在视频监控或图像识别应用中显示检测到的对象的统计信息。通过遍历 text 字典,它将每个标签和值作为一对显示在图像的右侧。
# 这段代码定义了一个名为 estimate_pose_angle 的静态方法,它的作用是估计由三个点 a 、 b 和 c 形成的角度。
# 这是一个装饰器,表示下面定义的方法是一个静态方法,它不需要类的实例就可以被调用。
@staticmethod
# 这是方法的定义行。方法接受三个参数。
# 1.a :第一个点的坐标。
# 2.b :第二个点的坐标,通常是角度的顶点。
# 3.c :第三个点的坐标。
def estimate_pose_angle(a, b, c):
# 计算物体的姿势角度。
"""
Calculate the pose angle for object.
Args:
a (float) : The value of pose point a
b (float): The value of pose point b
c (float): The value o pose point c
Returns:
angle (degree): Degree value of angle between three points
"""
# 将输入的点坐标转换为NumPy数组,以便进行向量运算。
a, b, c = np.array(a), np.array(b), np.array(c)
# numpy.arctan2(y, x, *, out=None, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
# np.arctan2() 是 NumPy 库中的一个函数,它计算两个给定值 y 和 x 的反正切(即 arctangent 函数),返回的值是角度,单位是弧度。这个函数特别有用,因为它可以返回正确的象限角,这是标准的 arctan 函数所做不到的。
# 参数说明 :
# y :垂直坐标值。
# x :水平坐标值。
# out :(可选)输出数组。
# where :(可选)条件数组,决定 y 和 x 中哪些元素被计算。
# casting :(可选)类型转换的行为。
# order :(可选)指定元素的遍历顺序。
# dtype :(可选)输出数组的数据类型。
# subok :(可选)如果为 True,则返回的数据类型可以是子类。
# signature 和 extobj :(可选)具体用法较复杂,一般使用中不需要。
# 返回值 :
# 返回 y/x 的反正切值,结果范围是 [-pi, pi] 。
# 这个函数返回的角度是 y 和 x 形成的向量与正x轴之间的角度。
# 当 x 为 0 时,如果 y 为正,返回 pi/2 ;如果 y 为负,返回 -pi/2 。
# 当 x 为正时,返回的角度在 [-pi/2, pi/2] 范围内;当 x 为负时,返回的角度在 [-3pi/2, 3pi/2] 范围内,但会根据 y 的符号调整到正确的象限。
# 使用 np.arctan2 函数计算向量 bc 和向量 ba 与x轴的夹角,然后求这两个角度的差值,得到向量 ba 和向量 bc 之间的夹角(以弧度为单位)。
radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
# 将弧度转换为度数,并取绝对值,因为角度通常是正数。
angle = np.abs(radians * 180.0 / np.pi)
# 如果计算出的角度大于180度,则需要调整角度,因为一个角度的补角(360度减去该角度)通常更小,且在几何上更常用。
if angle > 180.0:
# 计算角度的补角。
angle = 360 - angle
# 返回计算出的角度。
return angle
# 这个方法可以用于估计人体姿态中的角度,例如在运动分析或机器人学中。通过输入三个点的坐标,它可以计算出由这三个点形成的角的度数。
# 这段代码定义了一个名为 draw_specific_points 的方法,它的作用是在图像上绘制特定关键点。
# 这是方法的定义行。方法接受以下参数。
# 1.self :类的实例。
# 2.keypoints :关键点的列表或数组,每个关键点是一个包含坐标和可能的置信度的元组或列表。
# 3.indices :(可选)要绘制的关键点的索引列表,默认为 None 。
# 4.shape :(可选)图像的尺寸,默认为 (640, 640) 。
# 5.radius :(可选)绘制圆点的半径,默认为 2 。
# 6.conf_thres :(可选)置信度阈值,默认为 0.25 。
def draw_specific_points(self, keypoints, indices=None, shape=(640, 640), radius=2, conf_thres=0.25):
# 绘制特定关键点以进行健身房步数计数。
"""
Draw specific keypoints for gym steps counting.
Args:
keypoints (list): Keypoints data to be plotted.
indices (list, optional): Keypoint indices to be plotted. Defaults to [2, 5, 7].
shape (tuple, optional): Image size for model inference. Defaults to (640, 640).
radius (int, optional): Keypoint radius. Defaults to 2.
conf_thres (float, optional): Confidence threshold for keypoints. Defaults to 0.25.
Returns:
(numpy.ndarray): Image with drawn keypoints.
Note:
Keypoint format: [x, y] or [x, y, confidence].
Modifies self.im in-place.
"""
# 如果 indices 参数为 None ,则设置默认值。
if indices is None:
# 默认要绘制的关键点索引为 2 、 5 和 7 。
indices = [2, 5, 7]
# 遍历 keypoints 列表, i 是索引, k 是关键点。
for i, k in enumerate(keypoints):
# 如果当前关键点的索引在 indices 列表中,则继续处理。
if i in indices:
# 提取关键点的 x 和 y 坐标。
x_coord, y_coord = k[0], k[1]
# 排除位于图像边缘的关键点,即坐标值不是图像尺寸的整数倍,这可以防止坐标超出图像边界。
# x_coord % shape[1] != 0 :这部分检查 x_coord (关键点的 x 坐标)除以 shape[1] (图像的宽度)的余数是否不等于0。如果余数为0,这意味着 x_coord 正好是图像宽度的整数倍,即关键点位于图像的垂直边界上。
# 将这两个条件用 and 连接,意味着只有当 x_coord 不是图像宽度的整数倍,且 y_coord 不是图像高度的整数倍时,条件才为真。换句话说,这个条件确保关键点不在图像的边界上。
# 然而,这个条件的逻辑可能有点反直觉,因为它实际上是用来排除位于图像边界上的关键点的。如果目的是确保关键点在图像内部,通常我们会希望检查坐标是否小于图像的尺寸,即:
# x_coord < shape[1] and y_coord < shape[0]
# 这样的条件会确保关键点的坐标既不是负数,也没有超出图像的边界。所以,原始代码中的条件是为了排除位于图像边缘的关键点,而不是检查它们是否在图像内部。如果目的是确保关键点在图像内部,应该使用上面提到的条件。
if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
# 如果关键点包含三个元素(坐标和置信度),则处理置信度。
if len(k) == 3:
# 提取关键点的置信度。
conf = k[2]
# 如果置信度低于阈值,则跳过绘制该关键点。
if conf < conf_thres:
continue
# 使用 cv2.circle 函数在图像上绘制一个绿色的圆点,表示关键点。
cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, (0, 255, 0), -1, lineType=cv2.LINE_AA)
# 返回绘制了关键点的图像。
return self.im
# 这个方法通常用于在图像上标记特定的关键点,例如在计算机视觉任务中标记人体关节点。通过设置 indices 参数,可以指定要绘制哪些关键点。置信度阈值 conf_thres 用于过滤掉置信度较低的关键点。
# 这段代码定义了一个名为 plot_angle_and_count_and_stage 的方法,它的作用是在图像上绘制 角度 、 计数 和 阶段的文本信息 ,并为每个文本信息绘制一个 背景矩形 。
# 这是方法的定义行。方法接受以下参数。
# 1.self :类的实例。
# 2.angle_text :要显示的角度文本。
# 3.count_text :要显示的计数文本。
# 4.stage_text :要显示的阶段文本。
# 5.center_kpt :文本显示的中心关键点坐标。
# 6.color :背景矩形的颜色,默认为 (104, 31, 17) 。
# 7.txt_color :文本颜色,默认为 (255, 255, 255) 。
def plot_angle_and_count_and_stage(
self, angle_text, count_text, stage_text, center_kpt, color=(104, 31, 17), txt_color=(255, 255, 255)
):
# 绘制姿势角度、计数值和步数阶段。
"""
Plot the pose angle, count value and step stage.
Args:
angle_text (str): angle value for workout monitoring
count_text (str): counts value for workout monitoring
stage_text (str): stage decision for workout monitoring
center_kpt (list): centroid pose index for workout monitoring
color (tuple): text background color for workout monitoring
txt_color (tuple): text foreground color for workout monitoring
"""
# 格式化文本信息,添加空格和冒号等。 angle_text 格式化角度文本,保留两位小数。 count_text 格式化计数文本,添加前缀 "Steps :"。 stage_text 格式化阶段文本,添加空格。
angle_text, count_text, stage_text = (f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}")
# 这段代码是 plot_angle_and_count_and_stage 方法中用于绘制角度文本及其背景矩形的部分。
# Draw angle
# 计算角度文本的尺寸。使用 cv2.getTextSize 函数计算给定文本 angle_text 的宽度和高度。 self.sf 是字体缩放比例, self.tf 是字体厚度。函数返回一个元组,其中包含文本的宽度和高度,这里我们只关心前两个值。
(angle_text_width, angle_text_height), _ = cv2.getTextSize(angle_text, 0, self.sf, self.tf)
# 确定角度文本的位置。将传入的 center_kpt (中心关键点坐标)转换为整数,并用作文本的位置。
angle_text_position = (int(center_kpt[0]), int(center_kpt[1]))
# 确定角度背景的位置。计算背景矩形的左上角位置。这里将背景矩形放置在文本的正上方,距离文本5个像素。
angle_background_position = (angle_text_position[0], angle_text_position[1] - angle_text_height - 5)
# 确定角度背景的尺寸。计算 背景矩形 的 宽度 和 高度 。这里在文本宽度的基础上增加了10个像素(每侧5个像素),在文本高度的基础上增加了10个像素加上字体厚度的两倍,以确保背景矩形能够完全覆盖文本。
angle_background_size = (angle_text_width + 2 * 5, angle_text_height + 2 * 5 + (self.tf * 2))
# 绘制角度背景矩形。使用 cv2.rectangle 函数在图像 self.im 上绘制一个填充的矩形。 color 参数指定矩形的颜色, -1 表示矩形填充。
cv2.rectangle(
self.im,
angle_background_position,
(
angle_background_position[0] + angle_background_size[0],
angle_background_position[1] + angle_background_size[1],
),
color,
-1,
)
# 绘制角度文本。使用 cv2.putText 函数在图像 self.im 上绘制文本。 0 表示字体类型(默认字体), self.sf 是字体缩放比例, txt_color 是文本颜色, self.tf 是字体厚度。
cv2.putText(self.im, angle_text, angle_text_position, 0, self.sf, txt_color, self.tf)
# 这段代码将角度文本绘制在图像上,并为其添加一个背景矩形,使得文本更加突出且易于阅读。
# 这段代码是 plot_angle_and_count_and_stage 方法中用于绘制计数(Counts)文本及其背景矩形的部分。
# Draw Counts
# 计算计数文本的尺寸。使用 cv2.getTextSize 函数计算给定文本 count_text 的宽度和高度。这里 self.sf 是字体缩放比例, self.tf 是字体厚度。
(count_text_width, count_text_height), _ = cv2.getTextSize(count_text, 0, self.sf, self.tf)
# 确定计数文本的位置。确定计数文本的位置,位于角度文本的正下方,间隔20像素。
count_text_position = (angle_text_position[0], angle_text_position[1] + angle_text_height + 20)
# 确定计数背景的位置。计算背景矩形的左上角位置。这里将背景矩形放置在角度背景的正下方,间隔5个像素。
count_background_position = (
angle_background_position[0],
angle_background_position[1] + angle_background_size[1] + 5,
)
# 确定计数背景的尺寸。计算背景矩形的宽度和高度。这里在文本宽度的基础上增加了10个像素,在文本高度的基础上增加了10个像素加上字体厚度,以确保背景矩形能够完全覆盖文本。
count_background_size = (count_text_width + 10, count_text_height + 10 + self.tf)
# 绘制计数背景矩形。使用 cv2.rectangle 函数在图像 self.im 上绘制一个填充的矩形。 color 参数指定矩形的颜色, -1 表示矩形填充。
cv2.rectangle(
self.im,
count_background_position,
(
count_background_position[0] + count_background_size[0],
count_background_position[1] + count_background_size[1],
),
color,
-1,
)
# 绘制计数文本。使用 cv2.putText 函数在图像 self.im 上绘制文本。 0 表示字体类型(默认字体), self.sf 是字体缩放比例, txt_color 是文本颜色, self.tf 是字体厚度。
cv2.putText(self.im, count_text, count_text_position, 0, self.sf, txt_color, self.tf)
# 这段代码将计数文本绘制在图像上,并为其添加一个背景矩形,使得文本更加突出且易于阅读。通过适当的位置和尺寸计算,确保文本和背景矩形在视觉上整洁且有序。
# 这段代码是 plot_angle_and_count_and_stage 方法中用于绘制阶段(Stage)文本及其背景矩形的部分。
# Draw Stage
# 计算阶段文本的尺寸。使用 cv2.getTextSize 函数计算给定文本 stage_text 的宽度和高度。这里 self.sf 是字体缩放比例, self.tf 是字体厚度。
(stage_text_width, stage_text_height), _ = cv2.getTextSize(stage_text, 0, self.sf, self.tf)
# 确定阶段文本的位置。确定阶段文本的位置,位于角度文本和计数文本的正下方,间隔40像素。
stage_text_position = (int(center_kpt[0]), int(center_kpt[1]) + angle_text_height + count_text_height + 40)
# 确定阶段背景的位置。计算背景矩形的左上角位置。这里将背景矩形放置在阶段文本的正上方,距离文本5个像素。
stage_background_position = (stage_text_position[0], stage_text_position[1] - stage_text_height - 5)
# 确定阶段背景的尺寸。计算背景矩形的宽度和高度。这里在文本宽度的基础上增加了10个像素,在文本高度的基础上增加了10个像素,以确保背景矩形能够完全覆盖文本。
stage_background_size = (stage_text_width + 10, stage_text_height + 10)
# 绘制阶段背景矩形。使用 cv2.rectangle 函数在图像 self.im 上绘制一个填充的矩形。 color 参数指定矩形的颜色, -1 表示矩形填充。
cv2.rectangle(
self.im,
stage_background_position,
(
stage_background_position[0] + stage_background_size[0],
stage_background_position[1] + stage_background_size[1],
),
color,
-1,
)
# 绘制阶段文本。使用 cv2.putText 函数在图像 self.im 上绘制文本。 0 表示字体类型(默认字体), self.sf 是字体缩放比例, txt_color 是文本颜色, self.tf 是字体厚度。
cv2.putText(self.im, stage_text, stage_text_position, 0, self.sf, txt_color, self.tf)
# 这段代码将阶段文本绘制在图像上,并为其添加一个背景矩形,使得文本更加突出且易于阅读。通过适当的位置和尺寸计算,确保文本和背景矩形在视觉上整洁且有序。
# 这个方法通常用于在图像上显示与关键点相关的额外信息,例如在运动分析或健康监测应用中显示用户的运动角度、步数和运动阶段。通过绘制背景矩形和文本,可以清晰地展示这些信息,提高图像的可读性和信息的传达效果。
# 这段代码定义了一个名为 seg_bbox 的方法,它的作用是在图像上绘制一个由掩码(mask)定义的多边形区域,并在该区域旁边添加一个标签。
# 这是方法的定义行。方法接受以下参数 :
# 1.self :类的实例。
# 2.mask :定义多边形区域的掩码,应该是一个点的数组。
# 3.mask_color :多边形区域的颜色,默认为 (255, 0, 255) ,即洋红色。
# 4.label :要显示的标签文本,默认为 None 。
# 5.txt_color :文本颜色,默认为 (255, 255, 255) ,即白色。
def seg_bbox(self, mask, mask_color=(255, 0, 255), label=None, txt_color=(255, 255, 255)):
# 用于以边界框形状绘制分割对象的函数。
"""
Function for drawing segmented object in bounding box shape.
Args:
mask (list): masks data list for instance segmentation area plotting
mask_color (RGB): mask foreground color
label (str): Detection label text
txt_color (RGB): text color
"""
# 使用 cv2.polylines 函数在图像 self.im 上绘制多边形区域。 np.int32([mask]) 确保点的坐标是整数类型, isClosed=True 表示多边形是封闭的, color 是多边形的颜色, thickness 是线条的厚度。
cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)
# 使用 cv2.getTextSize 函数计算标签文本的尺寸。 self.sf 是字体缩放比例, self.tf 是字体厚度。
text_size, _ = cv2.getTextSize(label, 0, self.sf, self.tf)
# 使用 cv2.rectangle 函数在图像上绘制一个矩形,该矩形用于作为标签的背景。矩形的位置和尺寸基于 mask 的第一个点的坐标和文本尺寸计算得出。
cv2.rectangle(
self.im,
(int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
(int(mask[0][0]) + text_size[0] // 2 + 10, int(mask[0][1] + 10)),
mask_color,
-1,
)
# 如果提供了 label ,则执行以下操作。
if label:
# 使用 cv2.putText 函数在图像上绘制文本。文本的位置基于 mask 的第一个点的坐标和文本尺寸计算得出, 0 表示字体类型(默认字体), self.sf 是字体缩放比例, txt_color 是文本颜色, self.tf 是字体厚度。
cv2.putText(
self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1])), 0, self.sf, txt_color, self.tf
)
# 这个方法通常用于在图像上标记特定的区域,并为其添加描述性标签。例如,在对象检测或分割任务中,可以用于突出显示检测到的对象,并显示其类别或ID。
# 这段代码定义了一个名为 plot_distance_and_line 的方法,它的作用是在图像上绘制两点之间的距离、一条直线以及这两个点的圆心标记。
# 这是方法的定义行。方法接受以下参数 :
# 1.self :类的实例。
# 2.pixels_distance :两个中心点之间的像素距离。
# 3.centroids :两个中心点的坐标,格式为 (x1, y1), (x2, y2) 。
# 4.line_color :线段的颜色。
# 5.centroid_color :中心点的颜色。
def plot_distance_and_line(self, pixels_distance, centroids, line_color, centroid_color):
# 在框架上绘制距离和线。
"""
Plot the distance and line on frame.
Args:
pixels_distance (float): Pixels distance between two bbox centroids.
centroids (list): Bounding box centroids data.
line_color (RGB): Distance line color.
centroid_color (RGB): Bounding box centroid color.
"""
# Get the text size
# 计算文本尺寸。
# 使用 cv2.getTextSize 函数计算文本 "Pixels Distance: X.XX" 的尺寸,其中 X.XX 是 pixels_distance 的值,保留两位小数。 self.sf 是字体缩放比例, self.tf 是字体厚度。
(text_width_m, text_height_m), _ = cv2.getTextSize(
f"Pixels Distance: {pixels_distance:.2f}", 0, self.sf, self.tf
)
# Define corners with 10-pixel margin and draw rectangle
# 定义矩形角落并绘制矩形。
# 定义文本背景矩形的 左上角 和 右下角坐标。
top_left = (15, 25)
bottom_right = (15 + text_width_m + 20, 25 + text_height_m + 20)
# 用 cv2.rectangle 函数绘制一个填充的矩形,颜色为 centroid_color 。
cv2.rectangle(self.im, top_left, bottom_right, centroid_color, -1)
# Calculate the position for the text with a 10-pixel margin and draw text
# 计算文本位置并绘制文本。
# 计算文本位置。计算文本的位置,使其位于矩形内部,距离矩形左上角10像素,并且距离矩形顶部10像素加上文本高度。
text_position = (top_left[0] + 10, top_left[1] + text_height_m + 10)
# 绘制文本。使用 cv2.putText 函数在图像 self.im 上绘制文本。文本内容为 "Pixels Distance: X.XX",其中 X.XX 是 pixels_distance 的值,保留两位小数。
# 文本颜色为白色 (255, 255, 255) ,字体缩放比例为 self.sf ,字体厚度为 self.tf ,使用抗锯齿线型 cv2.LINE_AA 。
cv2.putText(
self.im,
f"Pixels Distance: {pixels_distance:.2f}",
text_position,
0,
self.sf,
(255, 255, 255),
self.tf,
cv2.LINE_AA,
)
# 绘制线段。
# 使用 cv2.line 函数在图像 self.im 上绘制一条线段,连接两个中心点 centroids[0] 和 centroids[1] ,线段颜色为 line_color ,厚度为3像素。
cv2.line(self.im, centroids[0], centroids[1], line_color, 3)
# 绘制中心点。
# 使用 cv2.circle 函数在两个中心点的位置绘制圆形标记。圆形半径为6像素,颜色为 centroid_color ,填充样式为 -1 (表示完全填充)。
cv2.circle(self.im, centroids[0], 6, centroid_color, -1)
cv2.circle(self.im, centroids[1], 6, centroid_color, -1)
# 这个方法通常用于在图像上显示两个点之间的距离,并用线段和圆形标记这两个点。这在图像分析和处理中非常有用,尤其是在需要可视化两点之间关系的场景中。
# 这段代码定义了一个名为 visioneye 的方法,它的作用是在图像上绘制与一个框(box)相关的中心点和标记。
# 这是方法的定义行。方法接受以下参数 :
# 1.self :类的实例。
# 2.box :定义一个矩形框的坐标,格式为 (x1, y1, x2, y2) 。
# 3.center_point :矩形框的中心点坐标。
# 4.color :绘制中心点和框时使用的颜色,默认为 (235, 219, 11) ,即浅黄色。
# 5.pin_color :绘制中心点标记时使用的颜色,默认为 (255, 0, 255) ,即洋红色。
def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255)):
# 用于精确定位人眼视觉映射和绘图的函数。
"""
Function for pinpoint human-vision eye mapping and plotting.
Args:
box (list): Bounding box coordinates
center_point (tuple): center point for vision eye view
color (tuple): object centroid and line color value
pin_color (tuple): visioneye point color value
"""
# 计算框的中心坐标。计算矩形框的中心坐标,通过取框的左上角坐标 (x1, y1) 和右下角坐标 (x2, y2) 的平均值来得到。
center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
# 绘制中心点标记。使用 cv2.circle 函数在 center_point 坐标处绘制一个圆形标记,颜色为 pin_color ,半径为 self.tf * 2 , -1 表示填充圆。
cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1)
# 绘制框的中心点。使用 cv2.circle 函数在 center_bbox 坐标处绘制一个圆形标记,颜色为 color ,半径为 self.tf * 2 , -1 表示填充圆。
cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)
# 绘制连接线。使用 cv2.line 函数绘制一条线段,连接 center_point 和 center_bbox ,颜色为 color ,线条厚度为 self.tf 。
cv2.line(self.im, center_point, center_bbox, color, self.tf)
# 这个方法通常用于在图像上标记和连接特定的点,例如在视觉跟踪或对象识别任务中,可以用于突出显示对象的中心点和边界框的中心点。通过绘制圆形标记和连接线,可以直观地展示这些点的位置关系。
4.def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
# 这段代码是一个Python函数,名为 plot_labels ,它的作用是绘制训练标签,包括类别直方图和框统计信息。这个函数使用了 matplotlib 和 seaborn 库来生成图表,并且使用了 PIL 库来处理图像。
# 这是一个装饰器,用于捕获并处理函数执行中可能发生的异常。这个装饰器可能是自定义的,用于在发生异常时重试或记录错误信息。
# class TryExcept(contextlib.ContextDecorator): -> 这个类用于在 with 语句块中提供异常处理功能,允许用户自定义错误消息和是否打印这些消息。 -> def __init__(self, msg="", verbose=True):
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
# 这是另一个装饰器,用于设置 matplotlib 的绘图参数,以确保图表的美观和一致性。
# def plt_settings(rcparams=None, backend="Agg"): -> 它用于在函数执行期间临时应用 matplotlib 的配置参数(rc参数)和后端设置,并在函数执行完毕后恢复原来的设置。这个装饰器可以在任何需要临时改变 matplotlib 设置的函数上使用。
@plt_settings()
# 定义了 plot_labels 函数,它接受四个参数。
# 1.boxes :包含边界框信息的数组。
# 2.cls :包含类别标签的数组。
# 3.names :一个元组,包含类别名称(默认为空)。
# 4.save_dir :保存图表的目录路径(默认为空,即当前目录)。
# 5.on_plot :一个可选的回调函数,用于在图表保存后执行(默认为 None )。
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
# 绘制训练标签,包括类别直方图和框统计。
"""Plot training labels including class histograms and box statistics."""
# 导入 pandas 和 seaborn 库,用于数据处理和绘图。
import pandas # scope for faster 'import ultralytics'
import seaborn # scope for faster 'import ultralytics'
# Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
# 过滤掉一些警告信息,以避免在绘图过程中显示不必要的警告。
warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight") # 图形布局已更改为紧密。
warnings.filterwarnings("ignore", category=FutureWarning)
# Plot dataset labels
# 记录日志信息,表示开始绘制标签。
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ") # 将标签绘制到 {save_dir / 'labels.jpg'}...
# 计算类别的数量。
nc = int(cls.max() + 1) # number of classes
# 限制边界框的数量,最多处理100万个。
boxes = boxes[:1000000] # limit to 1M boxes
# pandas.DataFrame(data, index=None, columns=None, dtype=None, copy=False)
# pandas.DataFrame() 是 Python 中 pandas 库的一个核心功能,用于创建一个二维的、表格型的数据结构。DataFrame 可以被看作是由多个 Series(一维数组)组成的(每个 Series 可以看作是 DataFrame 的一列),这些 Series 共享索引(index)。DataFrame 非常适合于处理表格数据,比如 Excel 电子表格或者 SQL 数据库中的表。
# 参数 :
# data :要包含在 DataFrame 中的数据。可以是以下几种类型之一 :
# dict :字典,其中键是列名,值是数据列表或数组。
# list 或 ndarray :列表或 NumPy 数组,将被转换为 DataFrame 的列。
# Series :pandas Series 对象,其索引将成为 DataFrame 的列名。
# 其他 DataFrame :现有的 DataFrame ,可以从中复制数据。
# index :行索引。可以是以下几种类型之一 :
# 列表 :行索引的列表。
# Index :pandas Index 对象。
# None :默认情况下,将创建一个从 0 开始的整数索引。
# columns :列索引。可以是以下几种类型之一 :
# 列表 :列索引的列表。
# Index :pandas Index 对象。
# None :默认情况下,如果 data 是字典,则使用字典的键作为列名;否则,将创建一个从 0 开始的整数索引。
# dtype :DataFrame 中数据的期望类型。可以是以下几种类型之一 :
# numpy.dtype :NumPy 数据类型。
# pandas.CategoricalDtype :分类数据类型。
# None:默认情况下,pandas 将自动推断数据类型。
# copy :布尔值,指示是否复制数据。默认为 False,意味着如果 data 是 pandas 对象,则不会复制数据。
# 返回值 :
# 返回一个新的 DataFrame 对象。
# 描述 :
# pandas.DataFrame() 构造函数提供了多种方式来创建 DataFrame 对象。可以传入不同类型的数据,并指定行索引和列索引。构造函数会根据提供的数据和索引创建 DataFrame,并根据需要推断数据类型。
# 将边界框信息转换为 pandas DataFrame 。
x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])
# Seaborn correlogram
# sns.pairplot(data, hue=None, vars=None, x_vars=None, y_vars=None, kind='scatter', diag_kind='hist', plot_kws=None, diag_kws=None, hue_kws=None, palette=None, dropna=True, vars_order=None, x_vars_order=None, y_vars_order=None, height=2.5, aspect=1, corner=True, join=False, scatter_kws=None, line_kws=None, kind_order=None)
# sns.pairplot 是 Seaborn 库中的一个函数,用于创建一个矩阵式的图表,显示数据集中所有变量之间的成对关系。这个函数特别适合于探索数据集中变量之间的关系,尤其是当你想要快速查看不同变量之间的分布和关系时。
# 参数说明 :
# data :要绘制的数据,通常是一个 pandas DataFrame。
# hue :(可选)指定用于颜色编码的列名,用于在图表中显示不同类别的数据。
# vars :(可选)要绘制的变量列表。
# x_vars 、 y_vars :(可选)分别指定 x 轴和 y 轴上的变量列表。
# kind :(可选)图表的类型,可以是 'scatter'(散点图)、'reg'(回归线图)、'resid'(残差图)、'kde'(核密度估计图)等。
# diag_kind :(可选)对角线上图表的类型,通常是 'hist'(直方图)。
# plot_kws 、 diag_kws 、 hue_kws :(可选)传递给 图表 、 对角线图表 和 颜色编码图表 的关键字参数。
# palette :(可选)颜色映射。
# dropna :(可选)是否丢弃包含缺失值的行。
# vars_order 、 x_vars_order 、 y_vars_order :(可选)变量的顺序。
# height :(可选)每个子图的高度。
# aspect :(可选)每个子图的宽高比。
# corner :(可选)是否在图表的角落显示对角线上的图表。
# join :(可选)是否在散点图中连接点。
# scatter_kws 、 line_kws :(可选)传递给散点图和线图的关键字参数。
# kind_order :(可选)图表类型的顺序。
# 返回值 :
# 返回一个 PairGrid 对象,可以进一步定制图表。
# 使用 seaborn 的 pairplot 函数绘制相关图。
seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
# 保存相关图并关闭图表。
plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
plt.close()
# 这段代码是使用 matplotlib 和 seaborn 库来创建和定制图表的。它涉及到创建子图、绘制直方图、设置颜色和标签等操作。
# Matplotlib labels
# 创建子图布局。
# plt.subplots(2, 2, ...) 创建一个 2x2 的子图布局。
# figsize=(8, 8) 设置整个图形的大小为 8x8 英寸。
# tight_layout=True 自动调整子图参数,以确保子图之间有足够的空间,并且子图的标题和轴标签不会重叠。
# [1].ravel() 返回子图的轴对象数组,并将其展平为一维数组,方便后续索引。
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
# 绘制直方图。
# ax[0].hist(...) 在第一个子图上绘制 cls 数组的直方图。
# bins=np.linspace(0, nc, nc + 1) - 0.5 定义直方图的 bin 边界。这里使用 np.linspace 生成从 0 到 nc 的 nc+1 个点,然后每个点减去 0.5,使得每个 bin 居中于类别索引。
# rwidth=0.8 设置条形的相对宽度为 0.8。
y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
# 设置条形颜色。
# 这段代码遍历每个类别,设置直方图条形的颜色。 colors(i) 是一个返回颜色值的函数或方法,这里假设它返回一个颜色列表, x / 255 将颜色值从 0-255 范围转换为 0-1 范围,以适应 matplotlib 的颜色格式。
for i in range(nc):
# 在 matplotlib 中, patches 是一个包含多个图形元素(如矩形、圆形、椭圆等)的对象集合。当你使用 hist() 函数绘制直方图时,每个条形(bin)实际上是一个 Rectangle 对象,这些对象被存储在 patches 属性中。
# 在 y[2].patches[i].set_color([x / 255 for x in colors(i)]) 中:
# y[2] 指的是由 ax[0].hist(...) 返回的直方图对象中的第三个元素(因为索引从0开始),这个元素包含了直方图中条形(patches)的集合。
# patches 是一个列表,其中包含了直方图中每个条形对应的 Rectangle 对象。
# y[2].patches[i] 指的是这个列表中的第 i 个 Rectangle 对象,即直方图中的第 i 个条形。
# set_color() 方法用于设置这个条形的颜色。
y[2].patches[i].set_color([x / 255 for x in colors(i)])
# 设置轴标签和标题。
# 设置 y 轴的标签为 "instances"。
ax[0].set_ylabel("instances") # 实例。
# 如果 names 的长度在 1 到 30 之间,设置 x 轴的刻度和标签,标签旋转 90 度并设置字体大小为 10。
if 0 < len(names) < 30:
# Axes.set_xticks(self, ticks, labels=None, **kwargs)
# set_xticks() 是 matplotlib 库中 Axes 对象的一个方法,用于设置图表的 x 轴刻度位置。这个方法允许你自定义 x 轴上的刻度,例如,你可以指定刻度的精确位置,或者更改刻度的显示方式。
# 参数 :
# ticks :必须。这是一个数字列表,指定了刻度线应该放置的位置。 也可以是一个空列表,这将清除所有的刻度。
# labels (可选) :这是一个与 ticks 参数同样长度的列表,指定了每个刻度位置的标签。 如果不提供,将使用默认的刻度标签。
# **kwargs (可选) : 这是一些额外的关键字参数,可以用来设置刻度的属性,例如颜色、大小等。
# 返回值 :无。
# 描述 :
# set_xticks() 方法允许你控制 x 轴上的刻度位置和标签。这对于创建更精确的图表非常有用,特别是当你想要突出显示特定数据点或者更改默认的刻度间隔时。
ax[0].set_xticks(range(len(names)))
# Axes.set_xticklabels(self, labels, fontdict=None, fontproperties=None, **kwargs)
# set_xticklabels() 是 matplotlib 库中 Axes 对象的一个方法,用于设置 x 轴刻度的标签。这个方法允许你自定义 x 轴上的刻度标签,例如,你可以指定标签的文本、位置、颜色等。
# 参数 :
# labels :必须。这是一个字符串列表或数组,指定了每个刻度位置的标签文本。
# fontdict (可选) :这是一个字典,用于指定标签的字体属性,如大小、重量、风格等。例如, fontdict={'fontsize': 12, 'fontweight': 'bold'} 。
# fontproperties (可选) : 这是一个 matplotlib.font_manager.FontProperties 对象,也用于指定标签的字体属性。这可以作为 fontdict 的替代方法。
# kwargs (可选) :这是一些额外的关键字参数,可以用来设置刻度标签的属性,如颜色、对齐方式等。
# 返回值 :无。
# 描述 :
# set_xticklabels() 方法允许你控制 x 轴上的刻度标签。这对于创建更具可读性和吸引力的图表非常有用,特别是当你想要改变标签的文本、样式或者格式化标签以适应特定的数据展示需求时。
# 区别总结 :
# set_xticks 仅影响刻度的位置,不涉及标签的文本或样式。
# set_xticklabels 用于设置刻度的标签文本,并允许你指定标签的样式,如字体大小、颜色等。
ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
# 否则,设置 x 轴的标签为 "classes"。
else:
ax[0].set_xlabel("classes")
# 使用 seaborn 绘制直方图。
# seaborn.histplot(...) 使用 seaborn 绘制直方图,这里分别在第三个和第四个子图上绘制 x 和 y 坐标的分布,以及 width 和 height 的分布。
# bins=50 定义直方图的 bin 数量。 pmax=0.9 设置直方图的高度限制为最大值的 90%。
seaborn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
seaborn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
# 这段代码综合使用了 matplotlib 和 seaborn 的功能来创建复杂的图表,展示了数据的分布和统计信息。
# 这段代码是用于在 matplotlib 图表中绘制边界框(bounding boxes)的示例。
# Rectangles
# 将边界框的中心设置为0.5。这行代码将边界框的中心坐标(x, y)设置为0.5,这意味着边界框的中心将位于图像的中心位置。这可能是为了将边界框的坐标从相对坐标转换为绝对坐标。
boxes[:, 0:2] = 0.5 # center
# 转换边界框格式。
# ops.xywh2xyxy 是一个函数,用于将边界框的坐标从 (x, y, w, h) 格式(其中 x, y 是中心点坐标, w, h 是宽度和高度)转换为 (x1, y1, x2, y2) 格式(其中 x1, y1 是左上角坐标, x2, y2 是右下角坐标)。
# 然后乘以1000,可能是为了将坐标从 归一化坐标 转换为 像素坐标 。
boxes = ops.xywh2xyxy(boxes) * 1000
# 创建一个空白图像。这行代码创建了一个1000x1000像素的空白图像,所有像素值都设置为255(白色)。
img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
# 在图像上绘制边界框。
# 这行代码遍历类别和边界框,使用 ImageDraw.Draw 对象在图像上绘制边界框。 cls[:500] 和 boxes[:500] 限制了只处理前500个边界框。 colors(cls) 用于根据类别索引返回相应的颜色。
for cls, box in zip(cls[:500], boxes[:500]):
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
# 在子图上显示图像。
# 将图像显示在第二个子图上。
ax[1].imshow(img)
# 关闭子图的坐标轴。
ax[1].axis("off")
# 在 matplotlib 中, Axes 对象代表图表中的一个轴区域,它是图表的基本组成部分,用于绘制数据和展示信息。 Axes 对象可以包含图表的标题、轴标签、刻度、图例以及数据图形等元素。
# Axes对象的属性 :
# Axes 对象有许多属性,以下是一些常用的属性。
# artists : 一个包含轴上所有艺术家(Artist)对象(如线条、标记、文本等)的列表。
# bbox : 轴的边界框,表示轴在图表中的定位和大小。
# spines : 一个字典,包含轴的四个边框线(spines),即上、下、左、右边框线。
# xaxis 和 yaxis : 分别代表 x 轴和 y 轴的对象,可以用于访问和修改轴的属性。
# title : 轴的标题。
# legend : 轴的图例。
# lines 和 patches : 包含轴上所有线条和图形块(如矩形、圆形等)的列表。
# collections : 包含轴上所有集合对象(如路径集合)的列表。
# images : 包含轴上所有图像对象的列表。
# texts : 包含轴上所有文本对象的列表。
# tables : 包含轴上所有表格对象的列表。
# contour 和 contourf : 用于绘制等高线图的方法。
# imshow : 用于显示图像数据的方法。
# pcolor 和 pcolormesh : 用于绘制伪彩色图的方法。
# scatter : 用于绘制散点图的方法。
# bar 和 barh : 用于绘制条形图的方法。
# plot 、 semilogy 、 semilogx 等: 用于绘制各种类型的线图的方法。
# Axes对象的创建 :
# Axes 对象通常是通过 plt.subplots() 或 plt.axes() 方法创建的 :
# fig, ax = plt.subplots() # 使用 subplots 创建图表和轴对象。
# ax = plt.axes([0.1, 0.1, 0.8, 0.8]) # 直接创建轴对象。
# Axes 对象是 matplotlib 中最核心的对象之一,它提供了丰富的方法和属性来定制和控制图表的各个方面。通过操作 Axes 对象,用户可以实现复杂的图表布局和数据可视化。
# 隐藏子图的边框。这段代码遍历所有四个子图,隐藏它们的边框。
for a in [0, 1, 2, 3]:
for s in ["top", "right", "left", "bottom"]:
# 用于设置图表中特定轴( ax )的边框线( spines )的可见性。
# ax :这是一个 Axes 对象的数组或列表,代表了图表中的一个或多个轴(例如,在一个子图布局中)。
# a :这是一个索引,用于指定 ax 数组中的特定轴对象。
# spines : Axes 对象有一个 spines 属性,它是一个字典,包含了四个条目: 'top' 、 'bottom' 、 'left' 和 'right' ,分别对应轴的上、下、左、右边框线。
# s :这是 spines 字典中的一个键,用于指定要操作的边框线。
# set_visible(False) :这是一个方法,用于设置指定的 spines 的可见性。将 False 作为参数传递给这个方法会隐藏对应的边框线。
ax[a].spines[s].set_visible(False)
# 保存图表。
# save_dir / "labels.jpg" 设置了保存文件的路径和文件名。
fname = save_dir / "labels.jpg"
# plt.savefig(fname, dpi=200) 保存图表为JPEG文件,分辨率为200 DPI。
plt.savefig(fname, dpi=200)
# plt.close() 关闭图表,释放资源。
plt.close()
# 执行回调函数。如果提供了 on_plot 回调函数,则在保存图表后执行该函数,并将文件名传递给它。
if on_plot:
on_plot(fname)
# 这段代码展示了如何在 matplotlib 图表中绘制边界框,并保存图表。它涉及到图像处理、绘图和文件操作等多个步骤。
# 这个函数的主要目的是可视化边界框和类别标签,帮助用户理解数据集中的分布情况。
5.def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
# 这段代码定义了一个名为 save_one_box 的函数,其目的是从一个图像中裁剪出一个边界框(bounding box)所指定的区域,并将其保存为一个 JPEG 文件。
# 函数定义和参数。
# 这行定义了函数 save_one_box ,它接受以下参数。
# 1.xyxy :边界框的坐标,格式为 (x1, y1, x2, y2) 。
# 2.im :要裁剪的原始图像。
# 3.file :保存裁剪图像的文件路径,默认为 "im.jpg" 。
# 4.gain :裁剪区域的缩放因子,默认为 1.02 。
# 5.pad :裁剪区域的填充量,默认为 10 。
# 6.square :布尔值,指示是否将裁剪区域调整为正方形,默认为 False 。
# 7.BGR :布尔值,指示图像通道顺序是否为BGR,默认为 False (即RGB)。
# 8.save :布尔值,指示是否保存裁剪的图像,默认为 True 。
def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
# 将图像裁剪保存为 {file},裁剪大小为 {gain} 和 {pad} 像素的倍数。保存和/或返回裁剪。
# 此函数接受边界框和图像,然后根据边界框保存裁剪的图像部分。裁剪可以是可选的平方,并且该函数允许对边界框进行增益和填充调整。
"""
Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
This function takes a bounding box and an image, and then saves a cropped portion of the image according
to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding
adjustments to the bounding box.
Args:
xyxy (torch.Tensor or list): A tensor or list representing the bounding box in xyxy format.
im (numpy.ndarray): The input image.
file (Path, optional): The path where the cropped image will be saved. Defaults to 'im.jpg'.
gain (float, optional): A multiplicative factor to increase the size of the bounding box. Defaults to 1.02.
pad (int, optional): The number of pixels to add to the width and height of the bounding box. Defaults to 10.
square (bool, optional): If True, the bounding box will be transformed into a square. Defaults to False.
BGR (bool, optional): If True, the image will be saved in BGR format, otherwise in RGB. Defaults to False.
save (bool, optional): If True, the cropped image will be saved to disk. Defaults to True.
Returns:
(numpy.ndarray): The cropped image.
Example:
```python
from ultralytics.utils.plotting import save_one_box
xyxy = [50, 50, 150, 150]
im = cv2.imread("image.jpg")
cropped_im = save_one_box(xyxy, im, file="cropped.jpg", square=True)
```
"""
# 边界框坐标处理。如果 xyxy 不是 PyTorch 张量,则将其转换为张量。这允许函数接受列表或其它形式的输入,并统一处理。
if not isinstance(xyxy, torch.Tensor): # may be list
# torch.stack(tensors, dim=0, *, out=None) -> Tensor
# torch.stack 是 PyTorch 中的一个函数,用于将一系列张量(tensors)沿着一个新的维度连接起来,从而创建一个新的张量。这个新张量是由输入张量堆叠而成的,堆叠的维度由参数 dim 指定。
# 参数 :
# tensors :必须。一个张量序列,可以是列表、元组或任何可迭代的张量对象集合。所有张量必须具有相同的形状。
# dim :可选。指定沿着哪个维度堆叠张量,默认为0。这个维度将被添加到结果张量的维度中。
# out :可选。用于存储输出结果的张量。它必须具有与要创建的堆叠张量相同的形状和类型。
# 返回值 :
# 返回一个新的张量,它是输入张量沿着指定维度 dim 堆叠的结果。
# 描述 :
# torch.stack 函数与 torch.cat 函数类似,但它们的主要区别在于堆叠的维度。 torch.cat 是在已存在的维度上连接张量,而 torch.stack 是创建一个新的维度来堆叠张量。
xyxy = torch.stack(xyxy)
# 边界框坐标转换。将边界框的坐标从 xyxy 格式(左上角和右下角坐标)转换为 xywh 格式(中心点坐标和宽高)。
b = ops.xyxy2xywh(xyxy.view(-1, 4)) # boxes
# 调整边界框大小。
# 如果 square 为 True ,则将边界框的 宽度 和 高度 调整为两者中的最大值,使边界框成为正方形。
if square:
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
# 将边界框的宽度和高度分别乘以 gain 并加上 pad ,以调整裁剪区域的大小。
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
# 裁剪图像。
# 将调整后的边界框坐标从 xywh 格式转换回 xyxy 格式。
xyxy = ops.xywh2xyxy(b).long()
# 确保边界框坐标不超过图像尺寸。
# def clip_boxes(boxes, shape): -> 用于将边界框(boxes)剪辑到特定的图像形状(shape)内,确保边界框的坐标不会超出图像的边界。返回剪辑后的边界框数组。 -> return boxes
xyxy = ops.clip_boxes(xyxy, im.shape)
# 根据边界框坐标裁剪图像。如果 BGR 为 True ,则在裁剪时反转通道顺序(从RGB变为BGR)。
crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)]
# 保存裁剪的图像。
# 如果 save 为 True ,则创建文件目录(如果不存在)。
if save:
file.parent.mkdir(parents=True, exist_ok=True) # make directory
# 生成一个唯一的文件路径,并确保文件扩展名为 .jpg 。
# def increment_path(path, exist_ok=False, sep="", mkdir=False):
# -> 在给定路径已存在的情况下,通过在路径后面添加一个数字后缀来生成一个新的路径,直到找到一个不存在的路径。回最终的路径,可能是原始路径(如果 exist_ok 为 True 且路径不存在),或者是一个增加了数字后缀的新路径。
# -> return path
f = str(increment_path(file).with_suffix(".jpg"))
# cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
# 将裁剪的图像转换为PIL图像对象,并保存为JPEG文件,质量设置为95,禁用子采样以避免色度下采样问题。
Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
# 返回裁剪的图像。函数返回裁剪的图像区域。
return crop
# 这个函数提供了一种方便的方式来从图像中裁剪出特定的区域,并将其保存为高质量的 JPEG 文件。它考虑了边界框的调整、图像通道顺序以及保存时的图像质量。
6.def plot_images(images: Union[torch.Tensor, np.ndarray], batch_idx: Union[torch.Tensor, np.ndarray], cls: Union[torch.Tensor, np.ndarray], bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32), confs: Optional[Union[torch.Tensor, np.ndarray]] = None, masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8), kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32), paths: Optional[List[str]] = None, fname: str = "images.jpg", names: Optional[Dict[int, str]] = None, on_plot: Optional[Callable] = None, max_size: int = 1920, max_subplots: int = 16, save: bool = True, conf_thres: float = 0.25,) -> Optional[np.ndarray]:
# 这段代码定义了一个名为 plot_images 的函数,它用于将一系列图像和相关的边界框、类别、置信度等信息绘制成一幅图像马赛克(mosaic)。
# 这个函数使用了装饰器 @threaded ,这是一个自定义装饰器,用于将函数的执行放在一个单独的线程中,以便不阻塞主线程。这样的设计在处理图像绘制时特别有用,因为它可能需要一些时间来完成,尤其是在处理大量图像时。
# def threaded(func): -> 定义了一个名为 threaded 的装饰器,它使得任何被装饰的函数可以以多线程的方式运行。使用这个装饰器可以轻松地将函数转换为多线程模式。
@threaded
# 1.images : 要绘制的图像,可以是PyTorch张量或NumPy数组。
# 2.batch_idx : 与图像对应的批次索引,可以是PyTorch张量或NumPy数组。
# 3.cls : 每个图像的类别标签,可以是PyTorch张量或NumPy数组。
# 4.bboxes : 每个图像的边界框,可以是PyTorch张量或NumPy数组,默认为一个空的NumPy数组。
# 5.confs : 每个检测的置信度,可以是PyTorch张量或NumPy数组,是一个可选参数。
# 6.masks : 每个图像的掩码,可以是PyTorch张量或NumPy数组,默认为一个空的NumPy数组。
# 7.kpts : 每个图像的关键点,可以是PyTorch张量或NumPy数组,默认为一个空的NumPy数组。
# 8.paths : 图像的文件路径列表,是一个可选参数。
# 9.fname : 保存图像的文件名,默认为"images.jpg"。
# 10.names : 类别名称的字典,键为类别ID,值对应名称,是一个可选参数。
# 11.on_plot : 一个可选的回调函数,用于在绘制每个子图时进行自定义操作。
# 12.max_size : 绘制图像的最大尺寸,默认为1920。
# 13.max_subplots : 绘制的最大子图数量,默认为16。
# 14.save : 是否保存绘制的图像,默认为True。
# 15.conf_thres : 置信度阈值,低于此阈值的检测将不被绘制,默认为0.25。
# 函数返回值是一个可选的NumPy数组,可能是保存的图像数据。
def plot_images(
images: Union[torch.Tensor, np.ndarray],
batch_idx: Union[torch.Tensor, np.ndarray],
cls: Union[torch.Tensor, np.ndarray],
bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32),
confs: Optional[Union[torch.Tensor, np.ndarray]] = None,
masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8),
kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32),
paths: Optional[List[str]] = None,
fname: str = "images.jpg",
names: Optional[Dict[int, str]] = None,
on_plot: Optional[Callable] = None,
max_size: int = 1920,
max_subplots: int = 16,
save: bool = True,
conf_thres: float = 0.25,
) -> Optional[np.ndarray]:
# 绘制带有标签、边界框、掩码和关键点的图像网格。
# 注意:
# 此函数支持张量和 numpy 数组输入。它会自动将张量输入转换为 numpy 数组进行处理。
"""
Plot image grid with labels, bounding boxes, masks, and keypoints.
Args:
images: Batch of images to plot. Shape: (batch_size, channels, height, width).
batch_idx: Batch indices for each detection. Shape: (num_detections,).
cls: Class labels for each detection. Shape: (num_detections,).
bboxes: Bounding boxes for each detection. Shape: (num_detections, 4) or (num_detections, 5) for rotated boxes.
confs: Confidence scores for each detection. Shape: (num_detections,).
masks: Instance segmentation masks. Shape: (num_detections, height, width) or (1, height, width).
kpts: Keypoints for each detection. Shape: (num_detections, 51).
paths: List of file paths for each image in the batch.
fname: Output filename for the plotted image grid.
names: Dictionary mapping class indices to class names.
on_plot: Optional callback function to be called after saving the plot.
max_size: Maximum size of the output image grid.
max_subplots: Maximum number of subplots in the image grid.
save: Whether to save the plotted image grid to a file.
conf_thres: Confidence threshold for displaying detections.
Returns:
np.ndarray: Plotted image grid as a numpy array if save is False, None otherwise.
Note:
This function supports both tensor and numpy array inputs. It will automatically
convert tensor inputs to numpy arrays for processing.
"""
# 这段代码是函数 plot_images 的一部分,它处理输入数据的类型转换和一些预处理步骤
# 类型转换。
# 如果 images 是一个 PyTorch 张量,它会被转换为 CPU 上的浮点数张量,然后转换为 NumPy 数组。
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
# 类似地,如果 cls 、 bboxes 、 masks 、 kpts 和 batch_idx 是 PyTorch 张量,它们也会被转换为 NumPy 数组。对于 masks ,还会额外转换数据类型为整数( int )。
if isinstance(cls, torch.Tensor):
cls = cls.cpu().numpy()
if isinstance(bboxes, torch.Tensor):
bboxes = bboxes.cpu().numpy()
if isinstance(masks, torch.Tensor):
masks = masks.cpu().numpy().astype(int)
if isinstance(kpts, torch.Tensor):
kpts = kpts.cpu().numpy()
if isinstance(batch_idx, torch.Tensor):
batch_idx = batch_idx.cpu().numpy()
# 获取图像尺寸。
# 通过 images.shape 获取图像的批次大小( bs )、通道数、高度( h )和宽度( w )。
bs, _, h, w = images.shape # batch size, _, height, width
# 使用 min 函数限制绘制的图像数量不超过 max_subplots 参数指定的最大子图数量。
bs = min(bs, max_subplots) # limit plot images
# 计算子图数量。使用 np.ceil 函数计算需要多少个子图来排列这些图像,以形成一个接近正方形的布局。
ns = np.ceil(bs**0.5) # number of subplots (square)
# 图像去归一化。
# 检查图像数据的最大值是否小于或等于 1,如果是,则将图像数据乘以 255 进行去归一化。这一步是可选的,通常用于将图像数据从 [0, 1] 范围转换回 [0, 255] 范围,以便正确显示。
if np.max(images[0]) <= 1:
images *= 255 # de-normalise (optional)
# 这段代码的目的是确保所有输入数据都是 NumPy 数组格式,并且图像数据在显示前被正确地处理。去归一化步骤假设图像数据在输入之前已经被归一化到了 [0, 1] 范围内,这是深度学习中常见的预处理步骤。
# 这段代码继续处理 plot_images 函数中的图像绘制逻辑。
# Build Image
# 初始化拼贴画布。
# 使用 np.full 创建一个填充为255(白色)的画布,其尺寸为 (ns * h, ns * w, 3) ,其中 ns 是计算出的子图数量, h 和 w 是每个图像的高度和宽度。这个画布将用于将所有图像拼接成一个大的拼贴画。
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
# 遍历图像批次。通过 for 循环遍历每个图像, i 是当前图像的索引。
for i in range(bs):
# 计算图像位置。
# 对于每个图像,计算其在拼贴画布上的起始 x 和 y 坐标。这是通过将图像索引 i 除以 ns (计算行)和取余 ns (计算列)来实现的,然后将结果乘以每个图像的宽度 w 和高度 h 。
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
# 将图像复制到画布。
# 使用 NumPy 的切片操作将每个图像复制到拼贴画布的相应位置。这里, images[i].transpose(1, 2, 0) 将图像数据从 (C, H, W) 格式转换为 (H, W, C) 格式,这是 OpenCV 和大多数图像处理库所需的格式。
mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)
# 这段代码的目的是将所有图像拼接成一个大的拼贴画布,每个图像占据拼贴的一个区块。以下是一些关键点 :
# mosaic 是最终的拼贴画布,其尺寸根据子图数量和每个图像的尺寸计算得出。
# x 和 y 计算每个图像在拼贴画布上的起始位置。
# mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0) 将每个图像复制到拼贴画布的相应位置。这段代码假设 images 数组中的每个图像都是 (C, H, W) 格式,其中 C 是通道数, H 是高度, W 是宽度。
# 如果图像数据是归一化的,可能需要先进行去归一化处理,如前一段代码所示。最终,这段代码将生成一个包含所有图像的拼贴画布,可以用于显示或保存。
# 这段代码是 plot_images 函数中用于可选调整拼贴画布大小的部分。
# Resize (optional)
# 计算缩放比例。
# 计算缩放比例 scale 。这里, max_size 是允许的最大画布尺寸, ns 是子图数量(即拼贴画布的宽度和高度), max(h, w) 是每个图像的最大尺寸(高度或宽度)。这个比例用于确保拼贴画布的尺寸不会超过 max_size 。
scale = max_size / ns / max(h, w)
# 检查是否需要缩放。如果计算出的缩放比例小于1,即原始拼贴画布的尺寸超过了 max_size ,则需要进行缩放。
if scale < 1:
# 调整图像尺寸。
# 使用 math.ceil 函数向上取整计算缩放后的高度。
h = math.ceil(scale * h)
# 使用 math.ceil 函数向上取整计算缩放后的宽度。
w = math.ceil(scale * w)
# 缩放拼贴画布。
# 使用 OpenCV 的 cv2.resize 函数对拼贴画布进行缩放。 tuple(int(x * ns) for x in (w, h)) 生成一个元组,表示缩放后的画布尺寸,其中 w 和 h 是缩放后的宽度和高度, ns 是子图数量,确保整个拼贴画布的尺寸也被相应缩放。
mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
# 这段代码的目的是确保拼贴画布的尺寸不会超过 max_size ,同时保持图像的纵横比。
# 这段代码是 plot_images 函数中用于在拼贴画布上添加注释的部分。
# Annotate
# 设置字体大小。计算字体大小 fs ,基于图像的高度和宽度以及子 图数量的总和 ,乘以0.01得到一个相对大小。
fs = int((h + w) * ns * 0.01) # font size
# 创建注释器。
# 创建一个 Annotator 对象,用于在画布上添加注释。 line_width 是线条宽度, font_size 是字体大小, pil=True 表示使用 PIL(Python Imaging Library)进行处理, example=names 可能是用于获取颜色或其他样式的示例。
# class Annotator:
# -> Ultralytics Annotator 用于训练/验证马赛克和 JPG 以及预测注释。
# -> def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):
annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
# 遍历图像批次添加注释。
# 遍历每个图像。
for i in range(bs):
# 计算图像位置。计算每个图像在拼贴画布上的起始 x 和 y 坐标。
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
# 绘制边界框。在每个图像周围绘制一个白色的边界框。
# def rectangle(self, xy, fill=None, outline=None, width=1): -> 用于在图像上绘制矩形的。
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
# 添加文件名。
# 如果提供了路径列表,就在每个图像的左上角添加文件名。
if paths:
# 在每个图像的左上角添加文件名,文本颜色为浅灰色。
# def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False): -> 用于在图像上添加文本。
annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
# 处理类别和边界框。
# 如果提供了类别信息,就处理每个图像的类别和边界框。
if len(cls) > 0:
# 找到与当前图像索引匹配的类别和边界框索引。
idx = batch_idx == i
# 获取当前图像的类别。
classes = cls[idx].astype("int")
# 判断是否有置信度信息,如果没有,则为标签。
labels = confs is None
# 处理边界框。
# 如果提供了边界框信息,就处理每个边界框。
if len(bboxes):
# 获取当前图像的边界框。
boxes = bboxes[idx]
# 获取当前图像的置信度(如果有)。
conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
# 检查边界框是否已归一化,并在需要时将其缩放到像素值。
if len(boxes):
if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1
# 调整边界框的位置,使其相对于拼贴画布的起始位置。
boxes[..., [0, 2]] *= w # scale to pixels
boxes[..., [1, 3]] *= h
elif scale < 1: # absolute coords need scale if image scales
boxes[..., :4] *= scale
boxes[..., 0] += x
boxes[..., 1] += y
# 绘制边界框和标签。
# 检查边界框是否是方向边界框(Oriented Bounding Box)。
is_obb = boxes.shape[-1] == 5 # xywhr
# 将边界框格式转换为适合绘制的格式。
boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
# 遍历每个边界框。
for j, box in enumerate(boxes.astype(np.int64).tolist()):
# 获取边界框对应的类别。
c = classes[j]
# 获取类别对应的颜色。
color = colors(c)
# 获取类别名称。
c = names.get(c, c) if names else c
# 如果只有标签或置信度高于阈值,则绘制边界框和标签。
if labels or conf[j] > conf_thres:
# 构建标签文本。
label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
# 在画布上绘制边界框和标签。
# def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): -> 用于在图像上绘制带有文本标签的边界框。
annotator.box_label(box, label, color=color, rotated=is_obb)
# 处理没有边界框的类别。
# 如果没有边界框但有类别信息,就在图像中心绘制类别名称。
elif len(classes):
# 遍历每个类别。
for c in classes:
# 获取类别对应的颜色。
color = colors(c)
# 获取类别名称。
c = names.get(c, c) if names else c
# 在图像中心绘制类别名称。
# def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False): -> 用于在图像上添加文本。
annotator.text((x, y), f"{c}", txt_color=color, box_style=True)
# 这段代码的目的是将类别、边界框和文件名等信息添加到拼贴画布上,以便更直观地展示图像内容和检测结果。
# 这段代码是 plot_images 函数中用于在拼贴画布上绘制关键点(keypoints)的部分。
# Plot keypoints
# 检查关键点数据。
# 如果提供了关键点信息,就处理每个图像的关键点。
if len(kpts):
# 复制并检查关键点。
# 获取当前图像的关键点,并复制一份以避免修改原始数据。
kpts_ = kpts[idx].copy()
# 检查复制后的关键点列表是否非空。
if len(kpts_):
# 关键点坐标归一化或缩放。
# 检查关键点坐标是否已经归一化,并允许一定的容差(0.01)。
if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01: # if normalized with tolerance .01
# 如果关键点坐标归一化,则将其转换为像素坐标。
kpts_[..., 0] *= w # scale to pixels
kpts_[..., 1] *= h
# 如果图像被缩放,则关键点坐标也需要相应缩放。
elif scale < 1: # absolute coords need scale if image scales
# 对关键点坐标进行缩放。
kpts_ *= scale
# 调整关键点位置。
# 调整关键点的位置,使其相对于拼贴画布的起始位置。
kpts_[..., 0] += x
kpts_[..., 1] += y
# 绘制关键点。
# 遍历每个关键点。
for j in range(len(kpts_)):
# 如果只有标签或置信度高于阈值,则绘制关键点。
if labels or conf[j] > conf_thres:
# 使用 Annotator 对象的 kpts 方法在画布上绘制关键点。
# def kpts(self, kpts, shape=(640, 640), radius=None, kpt_line=True, conf_thres=0.25, kpt_color=None): -> 用于在图像上绘制关键点(keypoints)和关键点之间的连线,通常用于表示人体姿态估计的结果。
annotator.kpts(kpts_[j], conf_thres=conf_thres)
# 这段代码的目的是将关键点信息添加到拼贴画布上,以便更直观地展示图像中的关键部位。
# 这段代码是 plot_images 函数中用于在拼贴画布上绘制掩码(masks)的部分。
# Plot masks
# 检查掩码数据。
# 如果提供了掩码信息,就处理每个图像的掩码。
if len(masks):
# 处理掩码重叠。
# 如果索引的数量与掩码的数量相等,说明每个掩码对应一个对象,没有重叠。
if idx.shape[0] == masks.shape[0]: # overlap_masks=False
# 直接使用索引获取对应的掩码。
image_masks = masks[idx]
# 如果索引的数量不等于掩码的数量,说明掩码可能需要重叠处理。
else: # overlap_masks=True
# 获取掩码。
image_masks = masks[[i]] # (1, 640, 640)
# 计算掩码的数量。
nl = idx.sum()
# 创建一个与掩码数量相同的索引数组。
# np.arange(nl) :这个函数生成一个从 0 到 nl-1 的一维数组(即等差数列),包含 nl 个元素。
# .reshape((nl, 1, 1)) : reshape 方法将上述一维数组重新塑形为一个三维数组,具体形状为 (nl, 1, 1) 。这意味着每个元素都被放置在一个单独的三维空间中的“盒子”里。
# + 1 :最后,对数组中的每个元素加 1,使得数组中的数字从 1 开始,而不是从 0 开始。
# 这个操作通常用于创建一个可以用于索引或标记不同对象的数组。例如,在处理图像掩码时,如果每个掩码对应一个不同的对象,这个数组可以用来标记每个掩码属于哪个对象。通过这种方式,可以轻松地将掩码与它们对应的对象关联起来,尤其是在处理重叠掩码时。
index = np.arange(nl).reshape((nl, 1, 1)) + 1
# 重复掩码以匹配索引数量。
# 这个操作通常用于以下情况 :
# 当你有一组掩码,每个掩码对应一个不同的对象,但是你想要将这些掩码复制多次,以便每个复制的掩码都能与一个特定的索引或标签关联起来。
# 在某些图像处理任务中,比如实例分割,你可能需要为每个对象创建一个唯一的掩码。通过重复掩码并给每个重复的掩码分配一个唯一的标识符,你可以确保每个掩码都能正确地映射到它所代表的对象上。
# 例如,如果 image_masks 原本是一个形状为 (1, H, W) 的数组,其中包含了一个掩码,而 nl 是 5,那么 np.repeat(image_masks, nl, axis=0) 将会生成一个新的数组,形状为 (5, H, W) ,其中包含了原始掩码的 5 个副本。
image_masks = np.repeat(image_masks, nl, axis=0)
# 根据索引将掩码设置为1或0。
# 这个操作通常用于以下情况 :
# 当你需要将一个包含多个掩码的数组转换为一个二进制掩码数组,其中每个掩码只包含一个对象的掩码信息。
# image_masks 数组中的每个元素可能代表了不同对象的掩码,通过与 index 比较, np.where 函数可以创建一个新的数组,其中每个位置的值表示该位置是否属于对应的对象。
# 例如,如果 image_masks 是一个形状为 (5, H, W) 的数组,而 index 是一个形状为 (5, 1, 1) 的数组,包含了值 [1, 2, 3, 4, 5] ,那么 np.where(image_masks == index, 1.0, 0.0) 将会生成一个新的数组,其中每个掩码只包含对应索引位置的值为 1.0,其余位置的值为 0.0。
# 这样,每个掩码都被转换成了一个二进制掩码,只包含一个对象的信息。
image_masks = np.where(image_masks == index, 1.0, 0.0)
# 复制当前图像。
# 复制 Annotator 对象中的图像数据。
im = np.asarray(annotator.im).copy()
# 遍历掩码并绘制。
# 遍历每个掩码。
for j in range(len(image_masks)):
# 如果只有标签或置信度高于阈值,则绘制掩码。
if labels or conf[j] > conf_thres:
# 获取类别对应的颜色。
color = colors(classes[j])
# 获取掩码的高度和宽度。
mh, mw = image_masks[j].shape
# 如果掩码的尺寸与图像的尺寸不匹配,需要调整掩码尺寸。
if mh != h or mw != w:
# 将掩码转换为 uint8 类型。
mask = image_masks[j].astype(np.uint8)
# 调整掩码尺寸。
mask = cv2.resize(mask, (w, h))
# 将掩码转换为布尔类型。
mask = mask.astype(bool)
# 如果尺寸匹配,直接将掩码转换为布尔类型。
else:
mask = image_masks[j].astype(bool)
# 捕获并忽略异常,用于处理绘制掩码时可能发生的错误。
with contextlib.suppress(Exception):
# 将掩码应用到图像上,通过混合原图像和掩码颜色实现半透明效果。
im[y : y + h, x : x + w, :][mask] = (
im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
)
# 更新 Annotator 对象。将处理后的图像数据更新回 Annotator 对象。
# def fromarray(self, im):
# -> 它用于更新类实例中的图像数据。这个方法的主要作用是将一个 NumPy 数组或一个 PIL 图像对象转换为 PIL 图像对象,并更新类实例中的 self.im 属性。
# -> 同时,它还创建了一个 ImageDraw.Draw 对象,用于后续的图像绘制操作。这使得类实例可以在 PIL 图像上进行各种绘制操作,如绘制文本、线条、矩形等。
annotator.fromarray(im)
# 这段代码的目的是将掩码信息添加到拼贴画布上,以便更直观地展示图像中的目标区域
# 这段代码是 plot_images 函数的最后部分,它处理图像的保存和回调函数的执行。
# 检查是否需要保存图像。
# 如果 save 参数为 False ,即不需要保存图像。
if not save:
# 返回图像数据。返回 Annotator 对象中的图像数据,以 NumPy 数组的形式。这样,即使不保存到文件系统,也可以在内存中使用或进一步处理图像。
return np.asarray(annotator.im)
# 保存图像。
# 如果 save 参数为 True ,即需要保存图像,使用 Annotator 对象的 im 属性(图像)并调用其 save 方法,将图像保存到文件系统,文件名为 fname 。
annotator.im.save(fname) # save
# 执行回调函数。 on_plot 回调函数提供了额外的自定义能力,允许用户在图像保存后执行特定的操作。
# 如果提供了 on_plot 回调函数。
if on_plot:
# 执行回调函数,并将文件名 fname 作为参数传递。这允许在图像保存后执行额外的操作,例如通知用户、记录文件路径或执行其他自定义逻辑。
on_plot(fname)
# 这段代码的目的是提供灵活性,允许用户选择是否保存图像,并在保存后执行自定义操作。
# 这个函数是一个复杂的例子,展示了如何将多个图像和相关的元数据绘制到一个马赛克图像中,并提供了灵活的参数来控制输出。
7.def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
# 这段代码定义了一个名为 plot_results 的函数,它用于绘制存储在 CSV 文件中的结果数据。这个函数使用了 Matplotlib 库来创建图表,并根据提供的参数决定绘制哪种类型的图表。
# 这是另一个装饰器,用于设置 matplotlib 的绘图参数,以确保图表的美观和一致性。
# def plt_settings(rcparams=None, backend="Agg"): -> 它用于在函数执行期间临时应用 matplotlib 的配置参数(rc参数)和后端设置,并在函数执行完毕后恢复原来的设置。这个装饰器可以在任何需要临时改变 matplotlib 设置的函数上使用。
@plt_settings()
# 参数解释。
# 1.file : 指定一个 CSV 文件的路径,其中包含了要绘制的数据。
# 2.dir : 指定一个目录路径,如果 file 参数为空,则使用此目录来查找结果文件。
# 3.segment : 布尔值,指示是否绘制分割(segmentation)结果。
# 4.pose : 布尔值,指示是否绘制姿态(pose)估计结果。
# 5.classify : 布尔值,指示是否绘制分类(classification)结果。
# 6.on_plot : 一个可选的回调函数,用于在图表保存后执行。
def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
# 从结果 CSV 文件绘制训练结果。该函数支持各种类型的数据,包括分割、姿势估计和分类。绘图将以“results.png”的形式保存在 CSV 所在的目录中。
"""
Plot training results from a results CSV file. The function supports various types of data including segmentation,
pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
Args:
file (str, optional): Path to the CSV file containing the training results. Defaults to 'path/to/results.csv'.
dir (str, optional): Directory where the CSV file is located if 'file' is not provided. Defaults to ''.
segment (bool, optional): Flag to indicate if the data is for segmentation. Defaults to False.
pose (bool, optional): Flag to indicate if the data is for pose estimation. Defaults to False.
classify (bool, optional): Flag to indicate if the data is for classification. Defaults to False.
on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
Defaults to None.
Example:
```python
from ultralytics.utils.plotting import plot_results
plot_results("path/to/results.csv", segment=True)
```
"""
# 这段代码是 plot_results 函数的一部分,它负责设置绘图环境和读取数据。
# 导入库。pandas 用于数据处理和读取CSV文件。 scipy.ndimage.gaussian_filter1d 用于对数据进行高斯滤波平滑处理。
import pandas as pd # scope for faster 'import ultralytics'
from scipy.ndimage import gaussian_filter1d
# 设置保存目录。
# 确定保存图表的目录。如果提供了 file 参数,则使用该文件的父目录;如果 file 参数为空,则使用 dir 参数指定的目录。
save_dir = Path(file).parent if file else Path(dir)
# 创建图表和轴。
# 根据 classify 、 segment 和 pose 参数的值,决定创建多少个子图( fig )和轴( ax )。这些参数控制绘制不同类型的结果(分类、分割、姿态估计)。
# figsize :设置图表的大小。 tight_layout :确保子图之间的间距合适,避免标签重叠。
if classify:
fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
# 设置索引列表。 index :根据不同的绘图类型,设置一个索引列表,这个列表指定了哪些列的数据将被绘制在图表上。
index = [1, 4, 2, 3]
elif segment:
fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
index = [1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]
elif pose:
fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)
index = [1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 18, 8, 9, 12, 13]
else:
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]
# np.ndarray.ravel()
# ravel() 函数是 NumPy 库中的一个方法,它将多维数组(ndarray)展平成一维数组。这个方法不会复制数组的数据,而是返回一个新的视图(view),这个视图与原始数组共享相同的数据。
# 参数 :
# order :可选参数,指定展平的顺序。默认是 'C',表示按行主序(C-style),即按列优先顺序展平。如果设置为 'F',则表示按列主序(Fortran-style),即按行优先顺序展平。如果设置为 'A',则保持数组的原始顺序。
# 返回值 :
# 返回一个展平后的一维数组。
# 展平轴数组。将轴数组展平成一维,方便后续遍历。
ax = ax.ravel()
# glob.glob(path, *, recursive=False, root_dir=None)
# glob.glob() 是 Python 标准库 glob 模块中的一个函数,它用于从文件系统中查找匹配特定模式的文件路径。
# 1. 参数定义 :
# path :一个字符串,表示要匹配的文件路径模式,可以包含通配符 * 、 ? 、 [...] 等。
# recursive :一个布尔值,默认为 False ,表示是否递归搜索子目录。
# root_dir :一个字符串或 Path 对象,表示搜索的根目录,默认为当前工作目录。
# 2. 返回值 :
# 返回一个列表,包含所有匹配模式的文件路径。
# 3. 通配符支持 :
# * :匹配任意数量的字符(包括零个)。
# ? :匹配任意单个字符。
# [...] :匹配方括号内的任意单个字符。
# ** :当 recursive=True 时,匹配任意数量的目录。
# 读取数据文件。在保存目录下查找所有以 "results" 开头的CSV文件。
files = list(save_dir.glob("results*.csv"))
# 断言检查。确保至少找到一个CSV文件,否则抛出错误,提示没有找到任何文件来绘制图表。
assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot." # 在 {save_dir.resolve()} 中未找到 results.csv 文件,没有可绘制的内容。
# 这段代码的主要作用是设置绘图环境,包括确定图表的布局和读取数据文件。
# 这段代码是 plot_results 函数的后半部分,它负责处理每个CSV文件中的数据,绘制图表,并保存图表。
# 遍历文件。循环遍历之前找到的所有CSV文件。
for f in files:
# 读取数据。
try:
# 使用 pandas 读取当前文件的数据。
data = pd.read_csv(f)
# 获取数据的列名,并去除可能的前后空格。
s = [x.strip() for x in data.columns]
# 提取数据列。
# 提取第一列数据作为x轴的值。
x = data.values[:, 0]
# 遍历 index 列表,这个列表指定了要绘制的数据列。
for i, j in enumerate(index):
# 数据处理和绘图。
# 提取指定列的数据,并转换为浮点数。
y = data.values[:, j].astype("float")
# y[y == 0] = np.nan # don't show zero values
# 在对应的子图上绘制原始数据点。
ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
# 在同一子图上绘制经过高斯滤波平滑后的数据线。
ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
# 设置子图的标题,标题为对应列的列名。
ax[i].set_title(s[j], fontsize=12)
# if j in {8, 9, 10}: # share train and val loss y axes
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
# 异常处理。块捕获在读取或绘制过程中可能出现的任何异常,并记录警告。
except Exception as e:
LOGGER.warning(f"WARNING: Plotting error for {f}: {e}") # 警告:{f} 的绘图错误:{e}。
# 设置图例。在第二个子图上设置图例,这里假设第二个子图是用于展示图例的。
ax[1].legend()
# 保存图表。
# 确定保存图表的文件路径。
fname = save_dir / "results.png"
# 保存图表为PNG文件,设置DPI为200以获得较高的图像质量。
fig.savefig(fname, dpi=200)
# 关闭图表,释放内存。
plt.close()
# 执行回调函数。
# 检查是否提供了回调函数。
if on_plot:
# 如果提供了回调函数,则执行它,并传递保存的文件路径。
on_plot(fname)
# 这段代码的目的是将CSV文件中的数据可视化,并提供平滑处理后的曲线,以便更好地分析结果。通过回调函数,用户可以在图表保存后执行自定义操作,例如进一步处理图表或通知用户图表已准备好。
# 这个函数是一个通用的结果绘制工具,可以根据需要绘制不同类型的结果,并允许用户通过回调函数进行进一步的自定义处理。
8.def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
# 这段代码定义了一个名为 plt_color_scatter 的函数,它用于创建一个散点图,其中每个点的颜色基于两个变量 v 和 f 的二维直方图。
# 参数解释 :
# 1.v :第一个变量的值,用于散点图的x轴。
# 2.f :第二个变量的值,用于散点图的y轴。
# 3.bins :用于计算直方图的箱子(bin)数量,默认为20。
# 4.cmap :颜色映射表,用于为直方图的值分配颜色,默认为 "viridis"。
# 5.alpha :散点图的透明度,默认为0.8。
# 6.edgecolors :散点图边缘的颜色,默认为 "none"。
def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
# 绘制基于二维直方图的点着色散点图。
"""
Plots a scatter plot with points colored based on a 2D histogram.
Args:
v (array-like): Values for the x-axis.
f (array-like): Values for the y-axis.
bins (int, optional): Number of bins for the histogram. Defaults to 20.
cmap (str, optional): Colormap for the scatter plot. Defaults to 'viridis'.
alpha (float, optional): Alpha for the scatter plot. Defaults to 0.8.
edgecolors (str, optional): Edge colors for the scatter plot. Defaults to 'none'.
Examples:
>>> v = np.random.rand(100)
>>> f = np.random.rand(100)
>>> plt_color_scatter(v, f)
"""
# Calculate 2D histogram and corresponding colors
# 计算二维直方图。使用 numpy 的 histogram2d 函数计算 v 和 f 的二维直方图。这会返回直方图的值 hist ,以及直方图在x轴和y轴上的边界 xedges 和 yedges 。
hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
# 生成颜色数组。
# 这是一个列表推导式,用于为每个点 v[i] 和 f[i] 计算对应的直方图值,从而确定其颜色。
colors = [
hist[
# np.digitize(v[i], xedges, right=True) - 1 :找到 v[i] 在 xedges 中的位置,并减去1,得到直方图的行索引。
min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
# np.digitize(f[i], yedges, right=True) - 1 :找到 f[i] 在 yedges 中的位置,并减去1,得到直方图的列索引。
# min(..., hist.shape[0] - 1) 和 min(..., hist.shape[1] - 1) :确保索引不会超出直方图的边界。
min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1),
]
for i in range(len(v))
]
# Scatter plot
# 绘制散点图。使用 matplotlib 的 scatter 函数绘制散点图。点的颜色 c 由 colors 数组指定,颜色映射表为 cmap ,透明度为 alpha ,边缘颜色为 edgecolors 。
plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
# 这个函数的目的是创建一个颜色编码的散点图,其中每个点的颜色反映了其在 v - f 平面上的密度。这种类型的图表可以直观地显示两个变量之间的关系,以及数据的分布情况。通过颜色的变化,可以观察到数据的聚集区域和稀疏区域。
9.def plot_tune_results(csv_file="tune_results.csv"):
# 这段代码定义了一个名为 plot_tune_results 的函数,它用于从CSV文件中读取调优结果,并绘制两种类型的图表:散点图和 fitness 与迭代次数的图表。
# 参数解释。
# 1.csv_file :CSV文件的路径,其中包含了调优结果数据,默认为 "tune_results.csv"。
def plot_tune_results(csv_file="tune_results.csv"):
# 绘制存储在“tune_results.csv”文件中的演化结果。该函数为 CSV 中的每个键生成一个散点图,并根据适应度得分进行颜色编码。性能最佳的配置会在图中突出显示。
"""
Plot the evolution results stored in an 'tune_results.csv' file. The function generates a scatter plot for each key
in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
Args:
csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'.
Examples:
>>> plot_tune_results("path/to/tune_results.csv")
"""
# 导入库。 pandas :用于数据处理的库。 scipy.ndimage.gaussian_filter1d :用于数据平滑处理的库。
import pandas as pd # scope for faster 'import ultralytics'
from scipy.ndimage import gaussian_filter1d
# 这段代码定义了一个名为 _save_one_file 的函数,其目的是将当前 Matplotlib 图表保存到指定的文件路径,并记录保存操作。
# 定义了一个名为 _save_one_file 的函数,它接受一个参数。
# 1.file :这个参数预期是一个字符串,表示要保存图表的文件路径。
def _save_one_file(file):
# 将一个 matplotlib 图保存到“文件”。
"""Save one matplotlib plot to 'file'."""
# 保存图表。这行代码使用 Matplotlib 的 savefig 函数将当前图表保存为文件。 file 参数是文件路径, dpi=200 设置了图像的分辨率为200 DPI,确保图像清晰。
plt.savefig(file, dpi=200)
# plt.close(fig=None, **kwargs)
# plt.close() 是 Matplotlib 库中的一个函数,用于关闭当前的图形窗口或者删除当前的图形对象。
# 参数说明 :
# fig :这是一个可选参数,可以是以下几种类型之一 :
# Figure 实例 :如果你传递一个 Figure 对象,那么这个图形对象会被关闭。
# 字符串 :如果你传递一个字符串,那么会关闭标签(label)与该字符串匹配的图形。
# 整数 :如果你传递一个整数,那么会关闭编号(number)与该整数匹配的图形。
# 功能描述 :
# plt.close() 函数用于关闭一个或多个图形窗口,或者删除一个或多个图形对象,以便释放内存。这对于在绘制大量图形或者在循环中绘制图形时管理资源非常有用。
# 如果没有指定 fig 参数, plt.close() 默认关闭所有当前的图形窗口。
# 关闭图表。这行代码使用 Matplotlib 的 close 函数关闭当前图表。这是一个好习惯,特别是在循环中绘制多个图表时,可以帮助释放内存。
plt.close()
# 记录保存操作。这行代码使用一个名为 LOGGER 的日志记录器来记录一条信息,表明图表已经被保存,并且文件路径被记录在日志中。 f"Saved {file}" 是一个格式化字符串,它将文件路径插入到日志消息中。
LOGGER.info(f"Saved {file}")
# 这个函数通常被用作一个辅助函数,可以在主函数中绘制完图表后调用,以确保图表被保存并记录。这种模块化的设计使得代码更加清晰和易于维护。
# Scatter plots for each hyperparameter
# 这段代码是 plot_tune_results 函数的一部分,它负责创建每个超参数的散点图,并保存这些图表。
# 读取CSV文件。
# 将 csv_file 参数转换为 Path 对象。
csv_file = Path(csv_file)
# 使用 pandas 读取 CSV 文件中的数据。
data = pd.read_csv(csv_file)
# 提取关键指标。
# 假设 CSV 文件的第一列是指标列(例如, fitness )。
num_metrics_columns = 1
# 提取除了指标列以外的其他列名(超参数)。
keys = [x.strip() for x in data.columns][num_metrics_columns:]
# 提取数据。
# 将数据转换为 NumPy 数组。
x = data.values
# 提取 fitness 列的数据。
fitness = x[:, 0] # fitness
# 找到 fitness 最大值的索引。
j = np.argmax(fitness) # max fitness index
# 计算图表的行列数,以便均匀分布散点图。
n = math.ceil(len(keys) ** 0.5) # columns and rows in plot
# 创建散点图。
# 创建一个新的图形,并设置图形大小和布局。
plt.figure(figsize=(10, 10), tight_layout=True)
# 遍历每个超参数。
for i, k in enumerate(keys):
# 提取当前超参数列的数据。
v = x[:, i + num_metrics_columns]
# 获取最佳单次结果的值。
mu = v[j] # best single result
# 创建子图。
plt.subplot(n, n, i + 1)
# 使用 plt_color_scatter 函数绘制颜色编码的散点图。
plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none")
# 在散点图上标记最佳单次结果的位置。
plt.plot(mu, fitness.max(), "k+", markersize=15)
# 设置子图的标题,显示超参数名称和最佳值。
plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9}) # limit to 40 characters
# 设置轴标签的大小。
plt.tick_params(axis="both", labelsize=8) # Set axis label size to 8
# 除了最后一行外,其他子图的y轴不显示刻度。
if i % n != 0:
plt.yticks([])
# 保存散点图。使用 _save_one_file 函数保存散点图为PNG文件。
_save_one_file(csv_file.with_name("tune_scatter_plots.png"))
# 这段代码的目的是将调优过程中收集的超参数数据可视化,以便分析不同超参数对模型性能的影响。通过颜色编码的散点图,可以直观地观察到数据的分布情况,以及每个超参数的最佳值。
# Fitness vs iteration
# 这段代码负责绘 fitness 与迭代次数(iteration)之间的关系图,并保存这个图表。
# 创建迭代次数序列。创建一个从1开始到 fitness 数组长度的整数序列,代表迭代次数。
x = range(1, len(fitness) + 1)
# 创建新图表。创建一个新的图表,设置图表的大小为宽10英寸、高6英寸,并启用紧凑布局。
plt.figure(figsize=(10, 6), tight_layout=True)
# 绘制 fitness 曲线。绘制 fitness 与迭代次数的关系图,使用圆圈标记每个数据点,不连接线段,并为这组数据设置图例标签“fitness”。
plt.plot(x, fitness, marker="o", linestyle="none", label="fitness")
# 绘制平滑曲线。使用高斯滤波对 fitness 数据进行平滑处理, sigma=3 指定平滑程度,然后绘制平滑后的曲线,使用冒号样式的线条,并设置图例标签“smoothed”。
plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2) # smoothing line
# 设置图表标题和轴标签。
# 设置图表的标题为“Fitness vs Iteration”。
plt.title("Fitness vs Iteration")
# 设置x轴的标签为“Iteration”。
plt.xlabel("Iteration")
# 设置y轴的标签为“Fitness”。
plt.ylabel("Fitness")
# 添加网格和图例。
# 在图表中添加网格。
plt.grid(True)
# 显示图例。
plt.legend()
# 保存图表。调用 _save_one_file 函数,将当前图表保存为PNG文件,文件名为“tune_fitness.png”,这个文件保存在与CSV文件相同的目录下。
_save_one_file(csv_file.with_name("tune_fitness.png"))
# 这段代码的目的是可视化调优过程中健身度的变化趋势,通过原始数据和平滑曲线的对比,可以更清晰地观察到健身度随迭代次数的变化情况。这对于分析调优过程的效率和效果非常有用。
# 这个函数的目的是将调优过程中收集的数据可视化,以便分析不同超参数对模型性能的影响,以及观察模型性能随迭代次数的变化。通过这两种图表,用户可以直观地理解调优过程和结果。
10.def output_to_target(output, max_det=300):
# 这段代码定义了一个名为 output_to_target 的函数,它将模型的输出转换为绘图所需的目标格式。
# 参数解释。
# 1.output :模型的输出,预期是一个包含多个检测结果的列表或张量。
# 2.max_det :每个批次中考虑的最大检测数量,默认为300。
def output_to_target(output, max_det=300):
# 将模型输出转换为目标格式 [batch_id, class_id, x, y, w, h, conf] 以进行绘图。
"""Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
# 初始化目标列表。初始化一个空列表,用于存储转换后的目标数据。
targets = []
# 遍历输出。遍历模型输出, i 是索引, o 是每个批次的输出。
for i, o in enumerate(output):
# 提取检测框和相关信息。
# 从每个批次的输出中提取前 max_det 个检测结果,并将其分割成三个部分:边界框( box )、置信度( conf )和类别( cls )。这里假设输出的前6列分别对应边界框的4个坐标、置信度和类别。
box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
# 创建批次索引。创建一个与置信度张量相同长度的批次索引张量,用于标识每个检测结果属于哪个批次。
j = torch.full((conf.shape[0], 1), i)
# 转换边界框格式。ops.xyxy2xywh(box) :调用 ops 模块中的 xyxy2xywh 函数,将边界框的坐标从 (x1, y1, x2, y2) 格式转换为 (x, y, w, h) 格式。
# 合并数据。将 批次索引 、 类别 、 转换后的边界框 和 置信度 合并为一个张量,并添加到 targets 列表中。
targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1))
# 合并所有批次的目标数据。将 targets 列表中的所有张量合并为一个张量,并转换为 NumPy 数组。
targets = torch.cat(targets, 0).numpy()
# 返回目标格式数据。返回转换后的目标数据,分别对应 批次索引 、 类别 、 边界框坐标 和 置信度 。
return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
# 这个函数的目的是将模型的输出转换为一个统一的格式,以便后续的绘图和分析。这种格式通常包括批次索引、类别ID、边界框的坐标(x, y, w, h)、置信度等信息。通过这种方式,可以方便地将检测结果绘制到图像上。
11.def output_to_rotated_target(output, max_det=300):
# 这段代码定义了一个名为 output_to_rotated_target 的函数,它将模型输出转换为绘图所需的目标格式,特别是用于处理旋转边界框(即边界框包含角度信息)。
# 参数解释。
# 1.output :模型的输出,预期是一个包含多个检测结果的列表或张量。
# 2.max_det :每个批次中考虑的最大检测数量,默认为300。
def output_to_rotated_target(output, max_det=300):
# # 将模型输出转换为目标格式 [batch_id, class_id, x, y, w, h, conf] 以进行绘图。
"""Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
# 初始化目标列表。初始化一个空列表,用于存储转换后的目标数据。
targets = []
# 遍历输出。 遍历模型输出, i 是索引, o 是每个批次的输出。
for i, o in enumerate(output):
# 提取检测框和相关信息。
# 从每个批次的输出中提取前 max_det 个检测结果,并将其分割成四个部分:边界框( box )、置信度( conf )、类别( cls )和角度( angle )。这里假设输出的前几列分别对应边界框的4个坐标、置信度、类别和角度。
box, conf, cls, angle = o[:max_det].cpu().split((4, 1, 1, 1), 1)
# 创建批次索引。创建一个与置信度张量相同长度的批次索引张量,用于标识每个检测结果属于哪个批次。
j = torch.full((conf.shape[0], 1), i)
# 合并数据。将批次索引、类别、边界框、角度和置信度合并为一个张量,并添加到 targets 列表中。
targets.append(torch.cat((j, cls, box, angle, conf), 1))
# 合并所有批次的目标数据。将 targets 列表中的所有张量合并为一个张量,并转换为 NumPy 数组。
targets = torch.cat(targets, 0).numpy()
# 返回目标格式数据。返回转换后的目标数据,分别对应批次索引、类别、边界框坐标和角度、置信度。
return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
# 这个函数的目的是将模型的输出转换为一个统一的格式,以便后续的绘图和分析。这种格式通常包括批次索引、类别ID、边界框的坐标(可能包括中心点坐标、宽度、高度和旋转角度)、置信度等信息。通过这种方式,可以方便地将检测结果绘制到图像上,特别是对于需要表示旋转边界框的场景。
12.def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
# 这段代码定义了一个名为 feature_visualization 的函数,其目的是将深度学习模型中的中间特征(通常是卷积层的输出)进行可视化,并保存为图像文件。
# 它接受以下参数.
# 1.x :要可视化的特征张量。
# 2.module_type :模块类型,用于确定是否应该跳过可视化。
# 3.stage :当前的阶段或层级。
# 4.n=32 :要可视化的特征图的数量,默认为32。
# 5.save_dir :保存可视化结果的目录,默认为 Path("runs/detect/exp") 。
def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
# 在推理过程中可视化给定模型模块的特征图。
"""
Visualize feature maps of a given model module during inference.
Args:
x (torch.Tensor): Features to be visualized.
module_type (str): Module type.
stage (int): Module stage within the model.
n (int, optional): Maximum number of feature maps to plot. Defaults to 32.
save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp').
"""
# 这行代码遍历一个集合,集合中包含了所有模型头的名称。
for m in {"Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"}: # all model heads
# 这行代码检查 module_type 是否包含集合中的任何一个模型头名称。
if m in module_type:
# 如果 module_type 包含集合中的名称,则函数返回,不执行任何可视化操作。
return
# 这行代码检查 x 是否是一个 PyTorch 张量。
if isinstance(x, torch.Tensor):
# 如果 x 是张量,这行代码获取其形状,包括 批次大小 、 通道数 、 高度 和 宽度 。
_, channels, height, width = x.shape # batch, channels, height, width
# 这行代码检查特征图是否有超过1的高度和宽度,即是否有足够的空间信息进行可视化。
if height > 1 and width > 1:
# 如果特征图足够大,这行代码构建保存文件的路径和文件名。
f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
# torch.chunk(input, chunks, dim=0)
# torch.chunk 是 PyTorch 中的一个函数,它将张量(tensor)分割成指定数量的块(chunks)。每个块在指定的维度上具有相等的大小。如果张量不能被均匀分割,则最后一个块可能会比其他块小。
# 参数 :
# input :要被分割的输入张量。
# chunks :一个整数,表示要将输入张量分割成多少块。
# dim :一个整数,指定沿着哪个维度进行分割。默认是0,即第一个维度。
# 返回值 :
# 返回一个包含分割后块的元组,每个块都是一个张量。
# 这行代码将批次中的第一个特征图( x[0] )按照通道数分割成块。
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
# 这行代码确保 n 不超过通道数。
n = min(n, channels) # number of plots
# math.ceil(x)
# math.ceil 是 Python 标准库 math 模块中的一个函数,用于返回大于或等于给定数字的最小整数。这个函数通常用于将浮点数向上舍入到最接近的整数。
# 参数 :
# x :要向上舍入的数字,可以是整数或浮点数。
# 返回值 :
# 返回大于或等于 x 的最小整数。
# math.ceil 函数在处理需要整数结果的数值计算时非常有用,例如在确定数组大小、分配内存或进行其他需要整数运算的场景中。
# 这行代码创建一个子图,最多8列,根据 n 的值自动计算行数。
_, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
# np.ndarray.ravel()
# ravel() 函数是 NumPy 库中的一个方法,它将多维数组(ndarray)展平成一维数组。这个方法不会复制数组的数据,而是返回一个新的视图(view),这个视图与原始数组共享相同的数据。
# 参数 :
# order :可选参数,指定展平的顺序。默认是 'C',表示按行主序(C-style),即按列优先顺序展平。如果设置为 'F',则表示按列主序(Fortran-style),即按行优先顺序展平。如果设置为 'A',则保持数组的原始顺序。
# 返回值 :
# 返回一个展平后的一维数组。
# 这行代码将子图的轴(ax)展平,以便于迭代。
ax = ax.ravel()
# 这行代码调整子图之间的间距。
plt.subplots_adjust(wspace=0.05, hspace=0.05)
# 这行代码开始一个循环,遍历每个要可视化的特征图块。
for i in range(n):
# 这行代码在子图上显示每个特征图块。
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
# 这行代码关闭子图的坐标轴。
ax[i].axis("off")
# 这行代码记录日志,显示正在保存的文件和可视化的特征图数量。
LOGGER.info(f"Saving {f}... ({n}/{channels})")
# 这行代码保存子图为PNG图像文件。
plt.savefig(f, dpi=300, bbox_inches="tight")
# 这行代码关闭当前的绘图窗口,释放资源。
plt.close()
# 这行代码将第一个特征图保存为NumPy文件(.npy)。
np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy()) # npy save
# 总结来说, feature_visualization 函数用于可视化模型中间层的特征图,并将可视化结果保存为图像和NumPy文件。这个函数可以帮助研究人员和开发人员理解模型在不同阶段是如何响应输入数据的。