Yolov8-源码解析-四十三-

news/2024/9/24 13:43:30

Yolov8 源码解析(四十三)

.\yolov8\ultralytics\utils\patches.py

# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Monkey patches to update/extend functionality of existing functions."""import time
from pathlib import Pathimport cv2  # 导入OpenCV库
import numpy as np  # 导入NumPy库
import torch  # 导入PyTorch库# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------
_imshow = cv2.imshow  # 将cv2.imshow赋值给_imshow变量,避免递归错误def imread(filename: str, flags: int = cv2.IMREAD_COLOR):"""Read an image from a file.Args:filename (str): Path to the file to read.flags (int, optional): Flag that can take values of cv2.IMREAD_*. Defaults to cv2.IMREAD_COLOR.Returns:(np.ndarray): The read image."""return cv2.imdecode(np.fromfile(filename, np.uint8), flags)  # 使用cv2.imdecode函数读取文件并返回图像数据def imwrite(filename: str, img: np.ndarray, params=None):"""Write an image to a file.Args:filename (str): Path to the file to write.img (np.ndarray): Image to write.params (list of ints, optional): Additional parameters. See OpenCV documentation.Returns:(bool): True if the file was written, False otherwise."""try:cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename)  # 使用cv2.imencode将图像编码并写入文件return Trueexcept Exception:return Falsedef imshow(winname: str, mat: np.ndarray):"""Displays an image in the specified window.Args:winname (str): Name of the window.mat (np.ndarray): Image to be shown."""_imshow(winname.encode("unicode_escape").decode(), mat)  # 使用_imshow显示指定名称的窗口中的图像# PyTorch functions ----------------------------------------------------------------------------------------------------
_torch_load = torch.load  # 将torch.load赋值给_torch_load变量,避免递归错误
_torch_save = torch.savedef torch_load(*args, **kwargs):"""Load a PyTorch model with updated arguments to avoid warnings.This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.Args:*args (Any): Variable length argument list to pass to torch.load.**kwargs (Any): Arbitrary keyword arguments to pass to torch.load.Returns:(Any): The loaded PyTorch object.Note:For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False'if the argument is not provided, to avoid deprecation warnings."""from ultralytics.utils.torch_utils import TORCH_1_13  # 导入TORCH_1_13变量,用于检测PyTorch版本if TORCH_1_13 and "weights_only" not in kwargs:kwargs["weights_only"] = False  # 如果使用的是PyTorch 1.13及以上版本且没有指定'weights_only'参数,则设置为Falsereturn _torch_load(*args, **kwargs)  # 调用torch.load加载模型def torch_save(*args, use_dill=True, **kwargs):"""Optionally use dill to serialize lambda functions where pickle does not, adding robustness with 3 retries andexponential standoff in case of save failure.```py# 此处代码块是省略部分,不需要注释```"""pass  # torch_save函数暂时没有实现内容,直接返回"""Args:*args (tuple): Positional arguments to pass to torch.save.use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.**kwargs (Any): Keyword arguments to pass to torch.save."""# 尝试使用 dill 序列化库(如果可用),否则使用 pickletry:assert use_dillimport dill as pickleexcept (AssertionError, ImportError):import pickle# 如果 kwargs 中没有指定 pickle_module,则默认使用 pickle 库if "pickle_module" not in kwargs:kwargs["pickle_module"] = pickle# 最多尝试保存 4 次(包括初始尝试),以处理可能的运行时错误for i in range(4):  # 3 retriestry:# 调用 _torch_save 函数尝试保存数据return _torch_save(*args, **kwargs)except RuntimeError as e:  # unable to save, possibly waiting for device to flush or antivirus scan# 如果是最后一次尝试保存,则抛出原始的 RuntimeErrorif i == 3:raise e# 等待指数增长的时间,用于避免设备刷新或者反病毒扫描等问题time.sleep((2**i) / 2)  # exponential standoff: 0.5s, 1.0s, 2.0s

.\yolov8\ultralytics\utils\plotting.py

# 导入需要的库
import contextlib  # 上下文管理模块,用于创建上下文管理器
import math  # 数学函数模块,提供数学函数的实现
import warnings  # 警告模块,用于处理警告信息
from pathlib import Path  # 路径操作模块,用于处理文件和目录路径
from typing import Callable, Dict, List, Optional, Union  # 类型提示模块,用于类型注解import cv2  # OpenCV图像处理库
import matplotlib.pyplot as plt  # 绘图库matplotlib的pyplot模块
import numpy as np  # 数值计算库numpy
import torch  # 深度学习框架PyTorch
from PIL import Image, ImageDraw, ImageFont  # Python Imaging Library,用于图像处理from PIL import __version__ as pil_version  # PIL版本信息from ultralytics.utils import LOGGER, TryExcept, ops, plt_settings, threaded  # 导入自定义工具函数和变量
from ultralytics.utils.checks import check_font, check_version, is_ascii  # 导入自定义检查函数
from ultralytics.utils.files import increment_path  # 导入路径处理函数# 颜色类,包含Ultralytics默认色彩方案和转换函数
class Colors:"""Ultralytics default color palette https://ultralytics.com/.This class provides methods to work with the Ultralytics color palette, including converting hex color codes toRGB values.Attributes:palette (list of tuple): List of RGB color values.n (int): The number of colors in the palette.pose_palette (np.ndarray): A specific color palette array with dtype np.uint8."""def __init__(self):"""Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""hexs = ("042AFF","0BDBEB","F3F3F3","00DFB7","111F68","FF6FDD","FF444F","CCED00","00F344","BD00FF","00B4FF","DD00BA","00FFFF","26C000","01FFB3","7D24FF","7B0068","FF1B6C","FC6D2F","A2FF0B",)# 初始化颜色调色板,将16进制颜色代码转换为RGB元组self.palette = [self.hex2rgb(f"#{c}") for c in hexs]self.n = len(self.palette)# 预定义特定颜色调色板,用于特定应用场景self.pose_palette = np.array([[255, 128, 0],[255, 153, 51],[255, 178, 102],[230, 230, 0],[255, 153, 255],[153, 204, 255],[255, 102, 255],[255, 51, 255],[102, 178, 255],[51, 153, 255],[255, 153, 153],[255, 102, 102],[255, 51, 51],[153, 255, 153],[102, 255, 102],[51, 255, 51],[0, 255, 0],[0, 0, 255],[255, 0, 0],[255, 255, 255],],dtype=np.uint8,)def __call__(self, i, bgr=False):"""Converts hex color codes to RGB values."""# 返回调色板中第i个颜色的RGB值,支持BGR格式c = self.palette[int(i) % self.n]return (c[2], c[1], c[0]) if bgr else c@staticmethoddef hex2rgb(h):"""Converts hex color codes to RGB values (i.e. default PIL order)."""# 将16进制颜色代码转换为RGB元组(PIL默认顺序)return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))colors = Colors()  # 创建颜色对象实例,用于绘图颜色选择class Annotator:"""Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations."""# 定义类属性,用于图像注释Attributes:# im是要注释的图像,可以是PIL图像(Image.Image)或者numpy数组im (Image.Image or numpy array): The image to annotate.# pil标志指示是否使用PIL库进行注释,而不是cv2pil (bool): Whether to use PIL or cv2 for drawing annotations.# font用于文本注释的字体,可以是ImageFont.truetype或ImageFont.load_defaultfont (ImageFont.truetype or ImageFont.load_default): Font used for text annotations.# lw是用于绘制注释的线条宽度lw (float): Line width for drawing.# skeleton是关键点的骨架结构的列表,其中每个元素是一个列表,表示连接的两个关键点的索引skeleton (List[List[int]]): Skeleton structure for keypoints.# limb_color是绘制骨架连接的颜色调色板,以RGB整数列表形式表示limb_color (List[int]): Color palette for limbs.# kpt_color是绘制关键点的颜色调色板,以RGB整数列表形式表示kpt_color (List[int]): Color palette for keypoints."""# 初始化 Annotator 类,接受图像 im、线宽 line_width、字体大小 font_size、字体名称 font、是否使用 PIL 的标志 pil、示例 exampledef __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""# 检查示例是否包含非 ASCII 字符,用于确定是否使用 PILnon_ascii = not is_ascii(example)  # non-latin labels, i.e. asian, arabic, cyrillic# 检查输入的图像是否为 PIL Image 对象input_is_pil = isinstance(im, Image.Image)# 根据条件判断是否使用 PILself.pil = pil or non_ascii or input_is_pil# 计算线宽,默认为图像尺寸或形状的一半乘以 0.003,取整后至少为 2self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)if self.pil:  # 如果使用 PIL# 如果输入的是 PIL Image,则直接使用;否则将其转换为 PIL Imageself.im = im if input_is_pil else Image.fromarray(im)# 创建一个用于绘制的 ImageDraw 对象self.draw = ImageDraw.Draw(self.im)try:# 根据示例中是否包含非 ASCII 字符,选择适当的字体文件(Unicode 或 Latin)font = check_font("Arial.Unicode.ttf" if non_ascii else font)# 计算字体大小,默认为图像尺寸的一半乘以 0.035,取整后至少为 12size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)# 加载选择的字体文件并设置字体大小self.font = ImageFont.truetype(str(font), size)except Exception:# 如果加载字体文件出错,则使用默认字体self.font = ImageFont.load_default()# 如果 PIL 版本高于等于 9.2.0,则修复 getsize 方法的用法为 getbbox 方法的结果中的宽度和高度if check_version(pil_version, "9.2.0"):self.font.getsize = lambda x: self.font.getbbox(x)[2:4]  # text width, heightelse:  # 如果使用 cv2# 断言输入的图像数据是连续的,否则提出警告assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images."# 如果图像数据不可写,则创建其副本self.im = im if im.flags.writeable else im.copy()# 计算字体粗细,默认为线宽减 1,至少为 1self.tf = max(self.lw - 1, 1)  # font thickness# 计算字体缩放比例,默认为线宽的三分之一self.sf = self.lw / 3  # font scale# 姿态关键点的连接关系self.skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12],[7, 13], [6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3],[1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7],]# 姿态关键点连接线的颜色self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]# 姿态关键点的颜色self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]# 深色调色板,用于姿态显示self.dark_colors = {(235, 219, 11), (243, 243, 243), (183, 223, 0), (221, 111, 255),(0, 237, 204), (68, 243, 0), (255, 255, 0), (179, 255, 1),(11, 255, 162),}# 浅色调色板,用于姿态显示self.light_colors = {(255, 42, 4), (79, 68, 255), (255, 0, 189), (255, 180, 0),(186, 0, 221), (0, 192, 38), (255, 36, 125), (104, 0, 123),(108, 27, 255), (47, 109, 252), (104, 31, 17),}def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):"""Assign text color based on background color."""# 检查给定的背景颜色是否为暗色if color in self.dark_colors:# 如果是暗色,则返回预定义的深色文本颜色return 104, 31, 17elif color in self.light_colors:# 如果是亮色,则返回白色作为文本颜色return 255, 255, 255else:# 如果背景颜色既不是暗色也不是亮色,则返回默认的文本颜色return txt_colordef circle_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=2):"""Draws a label with a background rectangle centered within a given bounding box.Args:box (tuple): The bounding box coordinates (x1, y1, x2, y2).label (str): The text label to be displayed.color (tuple, optional): The background color of the rectangle (R, G, B).txt_color (tuple, optional): The color of the text (R, G, B).margin (int, optional): The margin between the text and the rectangle border."""# 如果标签超过3个字符,打印警告信息,并仅使用前三个字符作为圆形标注的文本if len(label) > 3:print(f"Length of label is {len(label)}, initial 3 label characters will be considered for circle annotation!")label = label[:3]# 计算框的中心点坐标x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)# 获取文本的大小text_size = cv2.getTextSize(str(label), cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0]# 计算需要的半径,以适应文本和边距required_radius = int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin# 在图像上绘制圆形标注cv2.circle(self.im, (x_center, y_center), required_radius, color, -1)# 计算文本位置text_x = x_center - text_size[0] // 2text_y = y_center + text_size[1] // 2# 绘制文本cv2.putText(self.im,str(label),(text_x, text_y),cv2.FONT_HERSHEY_SIMPLEX,self.sf - 0.15,# 获取文本颜色,根据背景颜色自动选择self.get_txt_color(color, txt_color),self.tf,lineType=cv2.LINE_AA,)def text_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=5):"""Draws a label with a background rectangle centered within a given bounding box.Args:box (tuple): The bounding box coordinates (x1, y1, x2, y2).label (str): The text label to be displayed.color (tuple, optional): The background color of the rectangle (R, G, B).txt_color (tuple, optional): The color of the text (R, G, B).margin (int, optional): The margin between the text and the rectangle border."""# Calculate the center of the bounding boxx_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)# Get the size of the texttext_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.1, self.tf)[0]# Calculate the top-left corner of the text (to center it)text_x = x_center - text_size[0] // 2text_y = y_center + text_size[1] // 2# Calculate the coordinates of the background rectanglerect_x1 = text_x - marginrect_y1 = text_y - text_size[1] - marginrect_x2 = text_x + text_size[0] + marginrect_y2 = text_y + margin# Draw the background rectanglecv2.rectangle(self.im, (rect_x1, rect_y1), (rect_x2, rect_y2), color, -1)# Draw the text on top of the rectanglecv2.putText(self.im,  # 目标图像,在其上绘制label,  # 要绘制的文本(text_x, text_y),  # 文本的起始坐标(左下角位置)cv2.FONT_HERSHEY_SIMPLEX,  # 字体类型self.sf - 0.1,  # 字体比例因子self.get_txt_color(color, txt_color),  # 文本颜色self.tf,  # 文本线宽lineType=cv2.LINE_AA,  # 线型)def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):"""Plot masks on image.Args:masks (tensor): Predicted masks on cuda, shape: [n, h, w]colors (List[List[Int]]): Colors for predicted masks, [[r, g, b] * n]im_gpu (tensor): Image is in cuda, shape: [3, h, w], range: [0, 1]alpha (float): Mask transparency: 0.0 fully transparent, 1.0 opaqueretina_masks (bool): Whether to use high resolution masks or not. Defaults to False."""# 如果使用 PIL,先转换为 numpy 数组if self.pil:self.im = np.asarray(self.im).copy()# 如果没有预测到任何 mask,则直接将原始图像拷贝到 self.imif len(masks) == 0:self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255# 如果图像和 masks 不在同一个设备上,则将 im_gpu 移动到 masks 所在的设备上if im_gpu.device != masks.device:im_gpu = im_gpu.to(masks.device)# 将 colors 转换为 torch.tensor,并归一化到 [0, 1] 的范围colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0  # shape(n,3)# 扩展维度以便进行广播操作,将 colors 变为 shape(n,1,1,3)colors = colors[:, None, None]  # shape(n,1,1,3)# 增加一个维度到 masks 上,使其变为 shape(n,h,w,1)masks = masks.unsqueeze(3)  # shape(n,h,w,1)# 将 masks 与颜色相乘,乘以 alpha 控制透明度,得到彩色的 masks,shape(n,h,w,3)masks_color = masks * (colors * alpha)# 计算反向透明度 masks,用于混合原始图像和 masks_color,shape(n,h,w,1)inv_alpha_masks = (1 - masks * alpha).cumprod(0)# 计算最大通道值,用于融合图像和 masks_color,shape(n,h,w,3)mcs = masks_color.max(dim=0).values  # shape(n,h,w,3)# 翻转图像的通道顺序,从 RGB 转为 BGRim_gpu = im_gpu.flip(dims=[0])# 调整张量的维度顺序,从 (3,h,w) 转为 (h,w,3)im_gpu = im_gpu.permute(1, 2, 0).contiguous()# 使用 inv_alpha_masks[-1] 和 mcs 进行图像的混合im_gpu = im_gpu * inv_alpha_masks[-1] + mcs# 将混合后的图像乘以 255,并转为 numpy 数组im_mask = im_gpu * 255im_mask_np = im_mask.byte().cpu().numpy()# 根据 retina_masks 参数选择是否缩放图像self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)# 如果使用 PIL,将处理后的 numpy 数组转回 PIL 格式,并更新 drawif self.pil:self.fromarray(self.im)def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True, conf_thres=0.25):"""Plot keypoints on the image.Args:kpts (tensor): Predicted keypoints with shape [17, 3]. Each keypoint has (x, y, confidence).shape (tuple): Image shape as a tuple (h, w), where h is the height and w is the width.radius (int, optional): Radius of the drawn keypoints. Default is 5.kpt_line (bool, optional): If True, the function will draw lines connecting keypointsfor human pose. Default is True.Note:`kpt_line=True` currently only supports human pose plotting."""if self.pil:# If working with PIL image, convert to numpy array for processingself.im = np.asarray(self.im).copy()  # Convert PIL image to numpy array# Get the number of keypoints and dimensions from the input tensornkpt, ndim = kpts.shape# Check if the keypoints represent a human pose (17 keypoints with 2 or 3 dimensions)is_pose = nkpt == 17 and ndim in {2, 3}# Adjust kpt_line based on whether it's a valid human pose and the argument valuekpt_line &= is_pose  # `kpt_line=True` for now only supports human pose plotting# Loop through each keypoint and plot a circle on the imagefor i, k in enumerate(kpts):# Determine color for the keypoint based on whether it's a pose or notcolor_k = [int(x) for x in self.kpt_color[i]] if is_pose else colors(i)x_coord, y_coord = k[0], k[1]# Check if the keypoint coordinates are within image boundariesif x_coord % shape[1] != 0 and y_coord % shape[0] != 0:# If confidence score is provided (3 dimensions), skip keypoints below thresholdif len(k) == 3:conf = k[2]if conf < conf_thres:continue# Draw a circle on the image at the keypoint locationcv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA)# If kpt_line is True, draw lines connecting keypoints (for human pose)if kpt_line:ndim = kpts.shape[-1]# Iterate over predefined skeleton connections and draw lines between keypointsfor i, sk in enumerate(self.skeleton):pos1 = (int(kpts[(sk[0] - 1), 0]), int(kpts[(sk[0] - 1), 1]))pos2 = (int(kpts[(sk[1] - 1), 0]), int(kpts[(sk[1] - 1), 1]))# If confidence scores are provided, skip lines for keypoints below thresholdif ndim == 3:conf1 = kpts[(sk[0] - 1), 2]conf2 = kpts[(sk[1] - 1), 2]if conf1 < conf_thres or conf2 < conf_thres:continue# Check if keypoints' positions are within image boundaries before drawing linesif pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0:continueif pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0:continue# Draw a line connecting two keypoints on the imagecv2.line(self.im, pos1, pos2, [int(x) for x in self.limb_color[i]], thickness=2, lineType=cv2.LINE_AA)if self.pil:# Convert numpy array (image) back to PIL image format and update self.imself.fromarray(self.im)  # Convert numpy array back to PIL imagedef rectangle(self, xy, fill=None, outline=None, width=1):"""Add rectangle to image (PIL-only)."""self.draw.rectangle(xy, fill, outline, width)def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False):"""Adds text to an image using PIL or cv2."""# 如果锚点是"bottom",从字体底部开始计算y坐标if anchor == "bottom":  # start y from font bottomw, h = self.font.getsize(text)  # 获取文本的宽度和高度xy[1] += 1 - hif self.pil:# 如果需要使用方框样式if box_style:w, h = self.font.getsize(text)  # 获取文本的宽度和高度# 在图像上绘制一个矩形框作为背景,并使用txt_color填充self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)# 将txt_color作为背景颜色,将文本以白色填充前景绘制txt_color = (255, 255, 255)# 如果文本中包含换行符if "\n" in text:lines = text.split("\n")  # 拆分成多行文本_, h = self.font.getsize(text)  # 获取单行文本的高度for line in lines:self.draw.text(xy, line, fill=txt_color, font=self.font)  # 绘制每一行文本xy[1] += h  # 更新y坐标以绘制下一行文本else:self.draw.text(xy, text, fill=txt_color, font=self.font)  # 绘制单行文本else:# 如果需要使用方框样式if box_style:w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]  # 获取文本的宽度和高度h += 3  # 增加一些像素以填充文本outside = xy[1] >= h  # 判断标签是否适合在框外p2 = xy[0] + w, xy[1] - h if outside else xy[1] + hcv2.rectangle(self.im, xy, p2, txt_color, -1, cv2.LINE_AA)  # 填充矩形框# 将txt_color作为背景颜色,将文本以白色填充前景绘制txt_color = (255, 255, 255)cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)  # 使用cv2绘制文本def fromarray(self, im):"""Update self.im from a numpy array."""self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)  # 将numpy数组或PIL图像赋值给self.imself.draw = ImageDraw.Draw(self.im)  # 使用PIL的ImageDraw创建绘图对象def result(self):"""Return annotated image as array."""return np.asarray(self.im)  # 将PIL图像转换为numpy数组并返回def show(self, title=None):"""Show the annotated image."""Image.fromarray(np.asarray(self.im)[..., ::-1]).show(title)  # 将numpy数组转换为RGB模式的PIL图像并显示def save(self, filename="image.jpg"):"""Save the annotated image to 'filename'."""cv2.imwrite(filename, np.asarray(self.im))  # 将numpy数组保存为图像文件def get_bbox_dimension(self, bbox=None):"""Calculate the area of a bounding box.Args:bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).Returns:angle (degree): Degree value of angle between three points"""x_min, y_min, x_max, y_max = bbox  # 解构包围框坐标width = x_max - x_min  # 计算包围框宽度height = y_max - y_min  # 计算包围框高度return width, height, width * height  # 返回宽度、高度和面积的元组def draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5):"""Draw region line.Args:reg_pts (list): Region Points (for line 2 points, for region 4 points)color (tuple): Region Color valuethickness (int): Region area thickness value"""# 使用 cv2.polylines 方法在图像上绘制多边形线段,reg_pts 是多边形的顶点坐标cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2):"""Draw centroid point and track trails.Args:track (list): object tracking points for trails displaycolor (tuple): tracks line colortrack_thickness (int): track line thickness value"""# 将轨迹点连接成连续的线段,并绘制到图像上points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness)# 在轨迹的最后一个点处画一个实心圆圈,表示物体的当前位置cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)def queue_counts_display(self, label, points=None, region_color=(255, 255, 255), txt_color=(0, 0, 0)):"""Displays queue counts on an image centered at the points with customizable font size and colors.Args:label (str): queue counts labelpoints (tuple): region points for center point calculation to display textregion_color (RGB): queue region colortxt_color (RGB): text display color"""# 计算区域中心点的坐标x_values = [point[0] for point in points]y_values = [point[1] for point in points]center_x = sum(x_values) // len(points)center_y = sum(y_values) // len(points)# 计算显示文本的大小和位置text_size = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0]text_width = text_size[0]text_height = text_size[1]rect_width = text_width + 20rect_height = text_height + 20rect_top_left = (center_x - rect_width // 2, center_y - rect_height // 2)rect_bottom_right = (center_x + rect_width // 2, center_y + rect_height // 2)# 在图像上绘制一个填充的矩形框作为背景cv2.rectangle(self.im, rect_top_left, rect_bottom_right, region_color, -1)text_x = center_x - text_width // 2text_y = center_y + text_height // 2# 在指定位置绘制文本cv2.putText(self.im,label,(text_x, text_y),0,fontScale=self.sf,color=txt_color,thickness=self.tf,lineType=cv2.LINE_AA,)def display_objects_labels(self, im0, text, txt_color, bg_color, x_center, y_center, margin):"""Display the bounding boxes labels in parking management app.Args:im0 (ndarray): inference imagetext (str): object/class nametxt_color (bgr color): display color for text foregroundbg_color (bgr color): display color for text backgroundx_center (float): x position center point for bounding boxy_center (float): y position center point for bounding boxmargin (int): gap between text and rectangle for better display"""# Calculate the size of the text to be displayedtext_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]# Calculate the x and y coordinates for placing the text centered at (x_center, y_center)text_x = x_center - text_size[0] // 2text_y = y_center + text_size[1] // 2# Calculate the coordinates of the rectangle surrounding the textrect_x1 = text_x - marginrect_y1 = text_y - text_size[1] - marginrect_x2 = text_x + text_size[0] + marginrect_y2 = text_y + margin# Draw a filled rectangle with specified background color around the textcv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)# Draw the text on the image at (text_x, text_y)cv2.putText(im0, text, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)def display_analytics(self, im0, text, txt_color, bg_color, margin):"""Display the overall statistics for parking lots.Args:im0 (ndarray): inference imagetext (dict): labels dictionarytxt_color (bgr color): display color for text foregroundbg_color (bgr color): display color for text backgroundmargin (int): gap between text and rectangle for better display"""# Calculate horizontal and vertical gaps based on image dimensionshorizontal_gap = int(im0.shape[1] * 0.02)vertical_gap = int(im0.shape[0] * 0.01)text_y_offset = 0  # Initialize offset for vertical placement of text# Iterate through each label and value pair in the provided dictionaryfor label, value in text.items():txt = f"{label}: {value}"  # Format the label and value into a string# Calculate the size of the text to be displayedtext_size = cv2.getTextSize(txt, 0, self.sf, self.tf)[0]# Ensure minimum size for text dimensions to avoid errorsif text_size[0] < 5 or text_size[1] < 5:text_size = (5, 5)# Calculate the x and y coordinates for placing the text on the imagetext_x = im0.shape[1] - text_size[0] - margin * 2 - horizontal_gaptext_y = text_y_offset + text_size[1] + margin * 2 + vertical_gap# Calculate the coordinates of the rectangle surrounding the textrect_x1 = text_x - margin * 2rect_y1 = text_y - text_size[1] - margin * 2rect_x2 = text_x + text_size[0] + margin * 2rect_y2 = text_y + margin * 2# Draw a filled rectangle with specified background color around the textcv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)# Draw the text on the image at (text_x, text_y)cv2.putText(im0, txt, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)# Update the vertical offset for placing the next text blocktext_y_offset = rect_y2@staticmethoddef estimate_pose_angle(a, b, c):"""Calculate the pose angle between three points.Args:a (float) : The coordinates of pose point ab (float): The coordinates of pose point bc (float): The coordinates of pose point cReturns:angle (degree): Degree value of the angle between the points"""# Convert input points to numpy arrays for calculationsa, b, c = np.array(a), np.array(b), np.array(c)# Calculate the angle using arctangent and convert from radians to degreesradians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])angle = np.abs(radians * 180.0 / np.pi)# Normalize angle to be within [0, 180] degreesif angle > 180.0:angle = 360 - anglereturn angledef draw_specific_points(self, keypoints, indices=None, shape=(640, 640), radius=2, conf_thres=0.25):"""Draw specific keypoints on an image.Args:keypoints (list): List of keypoints to be plottedindices (list): Indices of keypoints to be plottedshape (tuple): Size of the image (width, height)radius (int): Radius of the keypointsconf_thres (float): Confidence threshold for keypoints"""# If indices are not provided, default to drawing keypoints 2, 5, and 7if indices is None:indices = [2, 5, 7]# Iterate through keypoints and draw circles for specific indicesfor i, k in enumerate(keypoints):if i in indices:x_coord, y_coord = k[0], k[1]# Check if the keypoints are within image boundsif x_coord % shape[1] != 0 and y_coord % shape[0] != 0:if len(k) == 3:conf = k[2]# Skip drawing if confidence is below the thresholdif conf < conf_thres:continue# Draw a circle on the image at the keypoint coordinatescv2.circle(self.im, (int(x_coord), int(y_coord)), radius, (0, 255, 0), -1, lineType=cv2.LINE_AA)# Return the image with keypoints drawnreturn self.imdef plot_angle_and_count_and_stage(self, angle_text, count_text, stage_text, center_kpt, color=(104, 31, 17), txt_color=(255, 255, 255)):"""Plot angle, count, and stage information on the image.Args:angle_text (str): Text to display for the anglecount_text (str): Text to display for the countstage_text (str): Text to display for the stagecenter_kpt (tuple): Center keypoint coordinatescolor (tuple): Color of the plotted elementstxt_color (tuple): Color of the text"""# Implementation details are missing in the provided snippet.# The function definition is incomplete.passdef seg_bbox(self, mask, mask_color=(255, 0, 255), label=None, txt_color=(255, 255, 255)):"""Draw a segmented object with a bounding box on the image.Args:mask (list): List of mask data points for the segmented objectmask_color (RGB): Color for the masklabel (str): Text label for the detectiontxt_color (RGB): Text color"""# Draw the polygonal lines around the mask regioncv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)# Calculate text size for label and draw a rectangle around the labeltext_size, _ = cv2.getTextSize(label, 0, self.sf, self.tf)cv2.rectangle(self.im,(int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),(int(mask[0][0]) + text_size[0] // 2 + 10, int(mask[0][1] + 10)),mask_color,-1,)# Draw the label text on the imageif label:cv2.putText(self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1])), 0, self.sf, txt_color, self.tf)def plot_distance_and_line(self, distance_m, distance_mm, centroids, line_color, centroid_color):"""Plot the distance and line on frame.Args:distance_m (float): Distance between two bbox centroids in meters.distance_mm (float): Distance between two bbox centroids in millimeters.centroids (list): Bounding box centroids data.line_color (RGB): Distance line color.centroid_color (RGB): Bounding box centroid color."""# 计算 "Distance M" 文本的宽度和高度(text_width_m, text_height_m), _ = cv2.getTextSize(f"Distance M: {distance_m:.2f}m", 0, self.sf, self.tf)# 绘制包围 "Distance M" 文本的矩形框cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 10, 25 + text_height_m + 20), line_color, -1)# 在图像中绘制 "Distance M" 文本cv2.putText(self.im,f"Distance M: {distance_m:.2f}m",(20, 50),0,self.sf,centroid_color,self.tf,cv2.LINE_AA,)# 计算 "Distance MM" 文本的宽度和高度(text_width_mm, text_height_mm), _ = cv2.getTextSize(f"Distance MM: {distance_mm:.2f}mm", 0, self.sf, self.tf)# 绘制包围 "Distance MM" 文本的矩形框cv2.rectangle(self.im, (15, 75), (15 + text_width_mm + 10, 75 + text_height_mm + 20), line_color, -1)# 在图像中绘制 "Distance MM" 文本cv2.putText(self.im,f"Distance MM: {distance_mm:.2f}mm",(20, 100),0,self.sf,centroid_color,self.tf,cv2.LINE_AA,)# 在图像中绘制两个中心点之间的直线cv2.line(self.im, centroids[0], centroids[1], line_color, 3)# 在图像中绘制第一个中心点cv2.circle(self.im, centroids[0], 6, centroid_color, -1)# 在图像中绘制第二个中心点cv2.circle(self.im, centroids[1], 6, centroid_color, -1)def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255)):"""Function for pinpoint human-vision eye mapping and plotting.Args:box (list): Bounding box coordinatescenter_point (tuple): center point for vision eye viewcolor (tuple): object centroid and line color valuepin_color (tuple): visioneye point color value"""# 计算 bounding box 的中心点坐标center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)# 在图像中绘制 visioneye 点的中心点cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1)# 在图像中绘制 bounding box 的中心点cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)# 在图像中绘制 visioneye 点与 bounding box 中心点之间的连线cv2.line(self.im, center_point, center_bbox, color, self.tf)
@TryExcept()  # 使用 TryExcept 装饰器,处理已知问题 https://github.com/ultralytics/yolov5/issues/5395
@plt_settings()  # 使用 plt_settings 函数进行绘图设置
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):"""Plot training labels including class histograms and box statistics."""import pandas  # 导入 pandas 库,用于数据处理import seaborn  # 导入 seaborn 库,用于统计图表绘制# 过滤掉 matplotlib>=3.7.2 的警告和 Seaborn 的 use_inf 和 is_categorical 的 FutureWarningswarnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")warnings.filterwarnings("ignore", category=FutureWarning)# 绘制数据集标签LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")nc = int(cls.max() + 1)  # 计算类别数量boxes = boxes[:1000000]  # 限制最多处理 100 万个框x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])  # 创建包含框坐标的 DataFrame# 绘制 Seaborn 相关性图seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)  # 保存相关性图plt.close()# 绘制 Matplotlib 标签ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)  # 绘制类别直方图for i in range(nc):y[2].patches[i].set_color([x / 255 for x in colors(i)])  # 设置直方图颜色ax[0].set_ylabel("instances")  # 设置 y 轴标签if 0 < len(names) < 30:ax[0].set_xticks(range(len(names)))ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)  # 设置 x 轴标签else:ax[0].set_xlabel("classes")  # 设置 x 轴标签为类别seaborn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)  # 绘制 x、y 分布图seaborn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)  # 绘制宽度、高度分布图# 绘制矩形框boxes[:, 0:2] = 0.5  # 将框坐标调整为中心点boxes = ops.xywh2xyxy(boxes) * 1000  # 转换为绝对坐标并放大img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)  # 创建空白图像for cls, box in zip(cls[:500], boxes[:500]):ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls))  # 绘制矩形框ax[1].imshow(img)  # 显示图像ax[1].axis("off")  # 关闭坐标轴显示for a in [0, 1, 2, 3]:for s in ["top", "right", "left", "bottom"]:ax[a].spines[s].set_visible(False)  # 隐藏图表边框fname = save_dir / "labels.jpg"plt.savefig(fname, dpi=200)  # 保存最终标签图像plt.close()  # 关闭绘图窗口if on_plot:on_plot(fname)  # 如果指定了回调函数,则调用回调函数# 根据传入的边界框信息 xyxy,裁剪输入图像 im,并返回裁剪后的图像。def save_one_box(xyxy, im, file='im.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):"""Args:xyxy (torch.Tensor or list): 表示边界框的张量或列表,格式为 xyxy。im (numpy.ndarray): 输入图像。file (Path, optional): 裁剪后的图像保存路径。默认为 'im.jpg'。gain (float, optional): 边界框尺寸增益因子。默认为 1.02。pad (int, optional): 边界框宽度和高度增加的像素数。默认为 10。square (bool, optional): 如果为 True,则将边界框转换为正方形。默认为 False。BGR (bool, optional): 如果为 True,则保存图像为 BGR 格式;否则保存为 RGB 格式。默认为 False。save (bool, optional): 如果为 True,则保存裁剪后的图像到磁盘。默认为 True。Returns:(numpy.ndarray): 裁剪后的图像。Example:```pythonfrom ultralytics.utils.plotting import save_one_boxxyxy = [50, 50, 150, 150]im = cv2.imread('image.jpg')cropped_im = save_one_box(xyxy, im, file='cropped.jpg', square=True)```py"""if not isinstance(xyxy, torch.Tensor):  # 如果 xyxy 不是 torch.Tensor 类型,可能是列表xyxy = torch.stack(xyxy)  # 转换为 torch.Tensorb = ops.xyxy2xywh(xyxy.view(-1, 4))  # 将 xyxy 格式的边界框转换为 xywh 格式if square:b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # 尝试将矩形边界框转换为正方形b[:, 2:] = b[:, 2:] * gain + pad  # 计算边界框宽高乘以增益因子后加上 pad 像素xyxy = ops.xywh2xyxy(b).long()  # 将 xywh 格式的边界框转换回 xyxy 格式,并转换为整型坐标xyxy = ops.clip_boxes(xyxy, im.shape)  # 将边界框坐标限制在图像范围内crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]  # 根据边界框坐标裁剪图像if save:file.parent.mkdir(parents=True, exist_ok=True)  # 创建保存图像的文件夹f = str(increment_path(file).with_suffix(".jpg"))  # 生成带有递增数字的文件名,并设置为 jpg 后缀# cv2.imwrite(f, crop)  # 保存为 BGR 格式图像(存在色度抽样问题)Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0)  # 保存为 RGB 格式图像return crop  # 返回裁剪后的图像
# 使用装饰器标记该函数为可多线程执行的函数
@threaded
# 定义函数用于绘制带有标签、边界框、掩码和关键点的图像网格
def plot_images(# 图像数据,可以是 torch.Tensor 或 np.ndarray 类型,形状为 (batch_size, channels, height, width)images: Union[torch.Tensor, np.ndarray],# 每个检测的批次索引,形状为 (num_detections,)batch_idx: Union[torch.Tensor, np.ndarray],# 每个检测的类别标签,形状为 (num_detections,)cls: Union[torch.Tensor, np.ndarray],# 每个检测的边界框,形状为 (num_detections, 4) 或 (num_detections, 5)(用于旋转边界框)bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32),# 每个检测的置信度分数,形状为 (num_detections,)confs: Optional[Union[torch.Tensor, np.ndarray]] = None,# 实例分割掩码,形状为 (num_detections, height, width) 或 (1, height, width)masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8),# 每个检测的关键点,形状为 (num_detections, 51)kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32),# 图像文件路径列表,与批次中每个图像对应paths: Optional[List[str]] = None,# 输出图像网格的文件名fname: str = "images.jpg",# 类别索引到类别名称的映射字典names: Optional[Dict[int, str]] = None,# 绘图完成后的回调函数,可选on_plot: Optional[Callable] = None,# 输出图像网格的最大尺寸max_size: int = 1920,# 图像网格中最大子图数目max_subplots: int = 16,# 是否保存绘制的图像网格到文件save: bool = True,# 显示检测结果所需的置信度阈值conf_thres: float = 0.25,
) -> Optional[np.ndarray]:"""Plot image grid with labels, bounding boxes, masks, and keypoints.Args:images: Batch of images to plot. Shape: (batch_size, channels, height, width).batch_idx: Batch indices for each detection. Shape: (num_detections,).cls: Class labels for each detection. Shape: (num_detections,).bboxes: Bounding boxes for each detection. Shape: (num_detections, 4) or (num_detections, 5) for rotated boxes.confs: Confidence scores for each detection. Shape: (num_detections,).masks: Instance segmentation masks. Shape: (num_detections, height, width) or (1, height, width).kpts: Keypoints for each detection. Shape: (num_detections, 51).paths: List of file paths for each image in the batch.fname: Output filename for the plotted image grid.names: Dictionary mapping class indices to class names.on_plot: Optional callback function to be called after saving the plot.max_size: Maximum size of the output image grid.max_subplots: Maximum number of subplots in the image grid.save: Whether to save the plotted image grid to a file.conf_thres: Confidence threshold for displaying detections.Returns:np.ndarray: Plotted image grid as a numpy array if save is False, None otherwise.Note:This function supports both tensor and numpy array inputs. It will automaticallyconvert tensor inputs to numpy arrays for processing."""# 如果 images 是 torch.Tensor 类型,则转换为 numpy 数组类型if isinstance(images, torch.Tensor):images = images.cpu().float().numpy()# 如果 cls 是 torch.Tensor 类型,则转换为 numpy 数组类型if isinstance(cls, torch.Tensor):cls = cls.cpu().numpy()# 如果 bboxes 是 torch.Tensor 类型,则转换为 numpy 数组类型if isinstance(bboxes, torch.Tensor):bboxes = bboxes.cpu().numpy()# 如果 masks 是 torch.Tensor 类型,则转换为 numpy 数组类型,并将类型转换为 intif isinstance(masks, torch.Tensor):masks = masks.cpu().numpy().astype(int)# 如果 kpts 是 torch.Tensor 类型,则转换为 numpy 数组类型if isinstance(kpts, torch.Tensor):kpts = kpts.cpu().numpy()# 如果 batch_idx 是 torch.Tensor 类型,则转换为 numpy 数组类型if isinstance(batch_idx, torch.Tensor):batch_idx = batch_idx.cpu().numpy()# 获取图像的批次大小、通道数、高度和宽度bs, _, h, w = images.shape  # batch size, _, height, width# 限制要绘制的图像数量,最多为 max_subplotsbs = min(bs, max_subplots)# 计算图像网格中子图的行数和列数(向上取整)ns = np.ceil(bs**0.5)# 如果图像的最大像素值小于等于1,则将其转换为 0-255 范围的值(去除标准化)if np.max(images[0]) <= 1:images *= 255  # de-normalise (optional)# 构建图像拼接mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # 初始化一个白色背景的图像数组# 遍历每个图像块,将其放置在合适的位置for i in range(bs):x, y = int(w * (i // ns)), int(h * (i % ns))  # 计算当前块的起始位置mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)  # 将图像块放置到拼接图像上# 可选的调整大小操作scale = max_size / ns / max(h, w)if scale < 1:h = math.ceil(scale * h)  # 计算调整后的高度w = math.ceil(scale * w)  # 计算调整后的宽度mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))  # 调整拼接后的图像大小# 添加注释fs = int((h + w) * ns * 0.01)  # 计算字体大小annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)  # 创建一个注释器对象if not save:return np.asarray(annotator.im)  # 如果不需要保存,返回注释后的图像数组annotator.im.save(fname)  # 否则保存注释后的图像if on_plot:on_plot(fname)  # 如果有指定的绘图函数,调用它并传入保存的文件名
@plt_settings()
def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):"""Plot training results from a results CSV file. The function supports various types of data including segmentation,pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.Args:file (str, optional): Path to the CSV file containing the training results. Defaults to 'path/to/results.csv'.dir (str, optional): Directory where the CSV file is located if 'file' is not provided. Defaults to ''.segment (bool, optional): Flag to indicate if the data is for segmentation. Defaults to False.pose (bool, optional): Flag to indicate if the data is for pose estimation. Defaults to False.classify (bool, optional): Flag to indicate if the data is for classification. Defaults to False.on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.Defaults to None.Example:```pythonfrom ultralytics.utils.plotting import plot_resultsplot_results('path/to/results.csv', segment=True)```py"""import pandas as pd  # 导入 pandas 库,用于处理 CSV 文件from scipy.ndimage import gaussian_filter1d  # 导入 scipy 库中的高斯滤波函数# 确定保存图片的目录save_dir = Path(file).parent if file else Path(dir)# 根据不同的数据类型和设置,选择合适的子图布局和指数索引if classify:fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)  # 分类数据的布局index = [1, 4, 2, 3]  # 对应子图的索引顺序elif segment:fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)  # 分割数据的布局index = [1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]  # 对应子图的索引顺序elif pose:fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)  # 姿态估计数据的布局index = [1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 18, 8, 9, 12, 13]  # 对应子图的索引顺序else:fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)  # 默认数据的布局index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]  # 对应子图的索引顺序ax = ax.ravel()  # 将子图数组展平,便于迭代处理files = list(save_dir.glob("results*.csv"))  # 查找保存结果的 CSV 文件列表assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."  # 断言确保找到了结果文件,否则报错for f in files:try:data = pd.read_csv(f)  # 读取 CSV 文件中的数据s = [x.strip() for x in data.columns]  # 清理列名,去除空格x = data.values[:, 0]  # 获取 X 轴数据,通常是第一列数据# 遍历子图索引,绘制每个子图的数据曲线和平滑曲线for i, j in enumerate(index):y = data.values[:, j].astype("float")  # 获取 Y 轴数据,并转换为浮点数类型# y[y == 0] = np.nan  # 不显示值为零的点,可选功能# 绘制实际结果曲线ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8)# 绘制平滑后的曲线ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2)ax[i].set_title(s[j], fontsize=12)  # 设置子图标题# 如果是指定的子图索引,共享训练和验证损失的 Y 轴# if j in {8, 9, 10}:#     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])except Exception as e:LOGGER.warning(f"WARNING: Plotting error for {f}: {e}")  # 捕获并记录绘图过程中的异常信息ax[1].legend()  # 在第二个子图上添加图例# 指定文件名为 save_dir 下的 "results.png"fname = save_dir / "results.png"# 将当前图形保存为 PNG 文件,设置 DPI 为 200fig.savefig(fname, dpi=200)# 关闭当前图形,释放资源plt.close()# 如果定义了 on_plot 回调函数,则调用该函数,传递保存的文件名作为参数if on_plot:on_plot(fname)
def plot_tune_results(csv_file="tune_results.csv"):"""Plot the evolution results stored in an 'tune_results.csv' file. The function generates a scatter plot for each keyin the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.Args:csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'.Examples:>>> plot_tune_results('path/to/tune_results.csv')"""import pandas as pd  # 导入 pandas 库,用于处理数据from scipy.ndimage import gaussian_filter1d  # 导入 scipy 库中的高斯滤波函数def _save_one_file(file):"""Save one matplotlib plot to 'file'."""plt.savefig(file, dpi=200)  # 保存当前 matplotlib 图形为指定文件,设置分辨率为200dpiplt.close()  # 关闭当前 matplotlib 图形LOGGER.info(f"Saved {file}")  # 记录日志信息,显示保存成功的文件名# Scatter plots for each hyperparametercsv_file = Path(csv_file)  # 将传入的 CSV 文件路径转换为 Path 对象data = pd.read_csv(csv_file)  # 使用 pandas 读取 CSV 文件中的数据num_metrics_columns = 1  # 指定要跳过的列数(这里是第一列的列数)keys = [x.strip() for x in data.columns][num_metrics_columns:]  # 获取 CSV 文件中的列名,并去除首尾空白字符x = data.values  # 获取 CSV 文件中的所有数据值fitness = x[:, 0]  # 从数据中提取 fitness(适应度)列数据j = np.argmax(fitness)  # 找到 fitness 列中最大值的索引n = math.ceil(len(keys) ** 0.5)  # 计算绘图的行数和列数,向上取整以确保足够的子图空间plt.figure(figsize=(10, 10), tight_layout=True)  # 创建一个 10x10 英寸大小的图形,并启用紧凑布局for i, k in enumerate(keys):v = x[:, i + num_metrics_columns]  # 获取当前列(除 fitness 外的其他列)的数据mu = v[j]  # 获取当前列中 fitness 最大值对应的数据点plt.subplot(n, n, i + 1)  # 在 n x n 的子图中,选择第 i+1 个子图plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none")  # 调用 plt_color_scatter 函数绘制散点图plt.plot(mu, fitness.max(), "k+", markersize=15)  # 在散点图上绘制 fitness 最大值对应的点plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9})  # 设置子图标题,显示参数名和对应的最佳单个结果plt.tick_params(axis="both", labelsize=8)  # 设置坐标轴标签的大小为 8if i % n != 0:plt.yticks([])  # 如果不是每行的第一个子图,则不显示 y 轴刻度_save_one_file(csv_file.with_name("tune_scatter_plots.png"))  # 调用保存函数,将绘制好的图形保存为 PNG 文件# Fitness vs iteration# 生成 x 轴的数值范围,从1到fitness列表长度加1x = range(1, len(fitness) + 1)# 创建一个图形对象,设置图形大小为10x6,启用紧凑布局plt.figure(figsize=(10, 6), tight_layout=True)# 绘制 fitness 列表的数据点,使用圆形标记,折线样式为无,设置标签为"fitness"plt.plot(x, fitness, marker="o", linestyle="none", label="fitness")# 绘制 fitness 列表数据点的高斯平滑曲线,设置折线样式为冒号,设置标签为"smoothed",设置线宽为2,说明是平滑线plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2)  # smoothing line# 设置图形的标题为"Fitness vs Iteration"plt.title("Fitness vs Iteration")# 设置 x 轴标签为"Iteration"plt.xlabel("Iteration")# 设置 y 轴标签为"Fitness"plt.ylabel("Fitness")# 启用网格线plt.grid(True)# 显示图例plt.legend()# 调用保存图形的函数,保存文件名为csv_file的名称加上"tune_fitness.png"作为后缀_save_one_file(csv_file.with_name("tune_fitness.png"))
def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):"""Visualize feature maps of a given model module during inference.Args:x (torch.Tensor): Features to be visualized.module_type (str): Module type.stage (int): Module stage within the model.n (int, optional): Maximum number of feature maps to plot. Defaults to 32.save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp')."""# 检查模块类型是否属于需要可视化的类型,如果不属于则直接返回for m in {"Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"}:  # all model headsif m in module_type:return# 检查输入特征是否为Tensor类型if isinstance(x, torch.Tensor):_, channels, height, width = x.shape  # 获取特征张量的形状信息:batch, channels, height, widthif height > 1 and width > 1:f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png"  # 构建保存文件路径和名称# 按照通道数拆分特征图块blocks = torch.chunk(x[0].cpu(), channels, dim=0)  # 选择批次索引为0的数据,并按通道拆分n = min(n, channels)  # 确定要绘制的特征图块数量,不超过通道数_, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)  # 创建绘图布局,8行 n/8 列ax = ax.ravel()plt.subplots_adjust(wspace=0.05, hspace=0.05)for i in range(n):ax[i].imshow(blocks[i].squeeze())  # 显示特征图块,去除单维度条目ax[i].axis("off")  # 关闭坐标轴显示LOGGER.info(f"Saving {f}... ({n}/{channels})")  # 记录保存文件信息plt.savefig(f, dpi=300, bbox_inches="tight")  # 保存绘制结果为PNG文件,300dpi,紧凑边界plt.close()  # 关闭绘图窗口np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy())  # 保存特征数据为.npy文件

.\yolov8\ultralytics\utils\tal.py

# 导入 PyTorch 库中的相关模块
import torch
import torch.nn as nn# 从本地模块中导入必要的函数和类
from .checks import check_version
from .metrics import bbox_iou, probiou
from .ops import xywhr2xyxyxyxy# 检查当前使用的 PyTorch 版本是否符合最低要求
TORCH_1_10 = check_version(torch.__version__, "1.10.0")# 定义一个任务对齐分配器的类,用于目标检测
class TaskAlignedAssigner(nn.Module):"""A task-aligned assigner for object detection.This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines bothclassification and localization information.Attributes:topk (int): The number of top candidates to consider.num_classes (int): The number of object classes.alpha (float): The alpha parameter for the classification component of the task-aligned metric.beta (float): The beta parameter for the localization component of the task-aligned metric.eps (float): A small value to prevent division by zero."""def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):"""Initialize a TaskAlignedAssigner object with customizable hyperparameters."""# 调用父类构造函数初始化模块super().__init__()# 设置对象属性,用于指定任务对齐分配器的超参数self.topk = topk  # 设置前k个候选框的数量self.num_classes = num_classes  # 设置目标类别的数量self.bg_idx = num_classes  # 设置背景类别的索引,默认为num_classesself.alpha = alpha  # 设置任务对齐度量中分类组件的参数alphaself.beta = beta  # 设置任务对齐度量中定位组件的参数betaself.eps = eps  # 设置一个极小值,用于避免除以零的情况@torch.no_grad()def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):"""Compute the task-aligned assignment. Reference code is available athttps://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.Args:pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)预测得分张量,形状为(bs, num_total_anchors, num_classes)pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)预测边界框张量,形状为(bs, num_total_anchors, 4)anc_points (Tensor): shape(num_total_anchors, 2)锚点坐标张量,形状为(num_total_anchors, 2)gt_labels (Tensor): shape(bs, n_max_boxes, 1)真实标签张量,形状为(bs, n_max_boxes, 1)gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)真实边界框张量,形状为(bs, n_max_boxes, 4)mask_gt (Tensor): shape(bs, n_max_boxes, 1)真实边界框掩码张量,形状为(bs, n_max_boxes, 1)Returns:target_labels (Tensor): shape(bs, num_total_anchors)目标标签张量,形状为(bs, num_total_anchors)target_bboxes (Tensor): shape(bs, num_total_anchors, 4)目标边界框张量,形状为(bs, num_total_anchors, 4)target_scores (Tensor): shape(bs, num_total_anchors, num_classes)目标得分张量,形状为(bs, num_total_anchors, num_classes)fg_mask (Tensor): shape(bs, num_total_anchors)前景掩码张量,形状为(bs, num_total_anchors)target_gt_idx (Tensor): shape(bs, num_total_anchors)目标真实边界框索引张量,形状为(bs, num_total_anchors)"""self.bs = pd_scores.shape[0]  # 记录批次大小self.n_max_boxes = gt_bboxes.shape[1]  # 记录每个样本最大边界框数if self.n_max_boxes == 0:  # 如果没有真实边界框device = gt_bboxes.devicereturn (torch.full_like(pd_scores[..., 0], self.bg_idx).to(device),  # 返回背景类索引torch.zeros_like(pd_bboxes).to(device),  # 返回零张量形状与预测边界框一致torch.zeros_like(pd_scores).to(device),  # 返回零张量形状与预测得分一致torch.zeros_like(pd_scores[..., 0]).to(device),  # 返回零张量形状与预测得分一致torch.zeros_like(pd_scores[..., 0]).to(device),  # 返回零张量形状与预测得分一致)mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt)  # 获取正样本掩码、对齐度量、重叠度量target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)# 选择最高重叠度的真实边界框索引、前景掩码、正样本掩码# Assigned targettarget_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)# 获取分配的目标标签、目标边界框、目标得分# Normalizealign_metric *= mask_pos  # 对齐度量乘以正样本掩码pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)  # 计算每个样本的最大对齐度量pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)  # 计算每个样本的最大重叠度norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)# 计算归一化后的对齐度量target_scores = target_scores * norm_align_metric  # 目标得分乘以归一化后的对齐度量return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idxdef get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):"""Get in_gts mask, (b, max_num_obj, h*w)."""# Select candidates within ground truth bounding boxesmask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)# Compute alignment metric and overlaps between predicted and ground truth boxesalign_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)# Select top-k candidates based on alignment metricmask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())# Merge masks to get the final positive maskmask_pos = mask_topk * mask_in_gts * mask_gt# Return the final positive mask, alignment metric, and overlapsreturn mask_pos, align_metric, overlapsdef get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):"""Compute alignment metric given predicted and ground truth bounding boxes."""na = pd_bboxes.shape[-2]mask_gt = mask_gt.bool()  # b, max_num_obj, h*w# Initialize tensors for overlaps and bbox scoresoverlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)# Create indices tensor for accessing scores based on ground truth labelsind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_objind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_objind[1] = gt_labels.squeeze(-1)  # b, max_num_obj# Assign predicted scores to corresponding locations in bbox_scoresbbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]  # b, max_num_obj, h*w# Extract predicted and ground truth bounding boxes where mask_gt is Truepd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]  # (b, max_num_obj, 1, 4)gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]  # (b, 1, h*w, 4)# Compute IoU between selected boxesoverlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)# Calculate alignment metric using bbox_scores and overlapsalign_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)return align_metric, overlapsdef iou_calculation(self, gt_bboxes, pd_bboxes):"""IoU calculation for horizontal bounding boxes."""# Calculate IoU using bbox_iou function with specified parametersreturn bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)# 根据给定的 metrics 张量,选择每个位置的前 self.topk 个候选项的指标值和索引topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)# 如果 topk_mask 未提供,则根据 metrics 张量中的最大值确定 top-k 值,并扩展为布尔张量if topk_mask is None:topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)# 根据 topk_mask,将 topk_idxs 中未选中的位置填充为 0topk_idxs.masked_fill_(~topk_mask, 0)# 创建一个与 metrics 张量形状相同的计数张量,用于统计每个位置被选择的次数count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)# 遍历 topk 值,对每个 topk 索引位置添加计数值for k in range(self.topk):count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)# 将计数张量中大于 1 的值(即超过一次被选择的位置),置为 0,用于过滤无效的候选项count_tensor.masked_fill_(count_tensor > 1, 0)# 将计数张量转换为与 metrics 相同类型的张量,并返回结果return count_tensor.to(metrics.dtype)def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):"""Compute target labels, target bounding boxes, and target scores for the positive anchor points.Args:gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is thebatch size and max_num_obj is the maximum number of objects.gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).target_gt_idx (Tensor): Indices of the assigned ground truth objects for positiveanchor points, with shape (b, h*w), where h*w is the totalnumber of anchor points.fg_mask (Tensor): A boolean tensor of shape (b, h*w) indicating the positive(foreground) anchor points.Returns:(Tuple[Tensor, Tensor, Tensor]): A tuple containing the following tensors:- target_labels (Tensor): Shape (b, h*w), containing the target labels forpositive anchor points.- target_bboxes (Tensor): Shape (b, h*w, 4), containing the target bounding boxesfor positive anchor points.- target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scoresfor positive anchor points, where num_classes is the numberof object classes."""# Assigned target labels, (b, 1)# Create batch indices for indexing into gt_labelsbatch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]# Adjust target_gt_idx to point to the correct location in the flattened gt_labelstarget_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)# Extract target labels from gt_labels using flattened indicestarget_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)# Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)# Reshape gt_bboxes to (b * max_num_obj, 4) and then index using target_gt_idxtarget_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]# Assigned target scorestarget_labels.clamp_(0)  # Clamp target_labels to ensure non-negative values# 10x faster than F.one_hot()# Initialize target_scores tensor with zeros and then scatter ones at target_labels indicestarget_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),dtype=torch.int64,device=target_labels.device,)  # (b, h*w, 80)target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)# Mask target_scores based on fg_mask to only keep scores for foreground anchor pointsfg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)return target_labels, target_bboxes, target_scoresdef select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):"""Select the positive anchor center in gt.Args:xy_centers (Tensor): shape(h*w, 2) - 存储锚点的中心坐标gt_bboxes (Tensor): shape(b, n_boxes, 4) - 存储每个图像中各个边界框的坐标信息Returns:(Tensor): shape(b, n_boxes, h*w) - 返回一个布尔值张量,指示哪些锚点与边界框有显著重叠"""n_anchors = xy_centers.shape[0]  # 获取锚点的数量bs, n_boxes, _ = gt_bboxes.shape  # 获取边界框的数量和维度信息lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottombbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)# 计算每个锚点与其对应边界框之间的距离,形成四个坐标差值并存储在bbox_deltas张量中return bbox_deltas.amin(3).gt_(eps)  # 判断距离是否大于阈值eps,并返回布尔值结果@staticmethoddef select_highest_overlaps(mask_pos, overlaps, n_max_boxes):"""If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.Args:mask_pos (Tensor): shape(b, n_max_boxes, h*w) - 存储布尔值指示哪些锚点与边界框有重叠overlaps (Tensor): shape(b, n_max_boxes, h*w) - 存储每个锚点与所有边界框之间的IoU值Returns:target_gt_idx (Tensor): shape(b, h*w) - 返回每个锚点与其最佳匹配边界框的索引fg_mask (Tensor): shape(b, h*w) - 返回一个布尔值张量,指示哪些锚点被分配给了边界框mask_pos (Tensor): shape(b, n_max_boxes, h*w) - 返回更新后的锚点分配信息"""# (b, n_max_boxes, h*w) -> (b, h*w)fg_mask = mask_pos.sum(-2)  # 计算每个锚点分配给边界框的数量if fg_mask.max() > 1:  # 如果一个锚点被分配给多个边界框mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)fg_mask = mask_pos.sum(-2)  # 更新后的锚点分配数量# 找到每个网格服务的哪个gt(索引)target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)return target_gt_idx, fg_mask, mask_pos
class RotatedTaskAlignedAssigner(TaskAlignedAssigner):"""Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric."""def iou_calculation(self, gt_bboxes, pd_bboxes):"""IoU calculation for rotated bounding boxes."""return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)@staticmethoddef select_candidates_in_gts(xy_centers, gt_bboxes):"""Select the positive anchor center in gt for rotated bounding boxes.Args:xy_centers (Tensor): shape(h*w, 2) - Anchor centers to consider.gt_bboxes (Tensor): shape(b, n_boxes, 5) - Ground-truth rotated bounding boxes.Returns:(Tensor): shape(b, n_boxes, h*w) - Boolean mask indicating positive anchor centers."""# (b, n_boxes, 5) --> (b, n_boxes, 4, 2) - Rearrange bounding box coordinates.corners = xywhr2xyxyxyxy(gt_bboxes)# (b, n_boxes, 1, 2) - Extract corner points a, b, and d from corners.a, b, _, d = corners.split(1, dim=-2)ab = b - a  # Compute vectors ab and ad from corner points.ad = d - a# (b, n_boxes, h*w, 2) - Calculate vector ap from anchor centers to point a.ap = xy_centers - anorm_ab = (ab * ab).sum(dim=-1)  # Calculate norms and dot products for IoU calculation.norm_ad = (ad * ad).sum(dim=-1)ap_dot_ab = (ap * ab).sum(dim=-1)ap_dot_ad = (ap * ad).sum(dim=-1)return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad)  # is_in_boxdef make_anchors(feats, strides, grid_cell_offset=0.5):"""Generate anchors from features."""anchor_points, stride_tensor = [], []assert feats is not None  # Ensure features are not None.dtype, device = feats[0].dtype, feats[0].device  # Determine data type and device from the first feature.for i, stride in enumerate(strides):_, _, h, w = feats[i].shape  # Retrieve height and width of feature map.sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # Generate x offsets.sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # Generate y offsets.sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)  # Create grid points.anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))  # Stack grid points into anchor points.stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))  # Create stride tensor.return torch.cat(anchor_points), torch.cat(stride_tensor)  # Concatenate anchor points and strides.def dist2bbox(distance, anchor_points, xywh=True, dim=-1):"""Transform distance(ltrb) to box(xywh or xyxy)."""lt, rb = distance.chunk(2, dim)  # Split distance tensor into left-top and right-bottom.x1y1 = anchor_points - lt  # Compute top-left corner coordinates.x2y2 = anchor_points + rb  # Compute bottom-right corner coordinates.if xywh:c_xy = (x1y1 + x2y2) / 2  # Compute center coordinates.wh = x2y2 - x1y1  # Compute width and height.return torch.cat((c_xy, wh), dim)  # xywh bbox - Concatenate center and size.return torch.cat((x1y1, x2y2), dim)  # xyxy bbox - Concatenate top-left and bottom-right.def bbox2dist(anchor_points, bbox, reg_max):"""Transform bbox(xyxy) to dist(ltrb)."""x1y1, x2y2 = bbox.chunk(2, -1)  # Split bbox tensor into x1y1 and x2y2.return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01)  # dist (lt, rb)def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):"""Decode predicted object bounding box coordinates from anchor points and distribution."""# Function not completed in provided snippet, further implementation required.# 将预测的旋转距离张量按照指定维度分割为左上角和右下角坐标偏移量lt, rb = pred_dist.split(2, dim=dim)# 计算预测的角度的余弦和正弦值cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)# 计算中心点偏移量的 x 和 y 分量xf, yf = ((rb - lt) / 2).split(1, dim=dim)# 根据旋转角度对中心点偏移量进行调整,得到旋转后的中心点坐标x, y = xf * cos - yf * sin, xf * sin + yf * cos# 将旋转后的中心点坐标与锚点相加,得到最终的旋转后的坐标xy = torch.cat([x, y], dim=dim) + anchor_points# 将左上角和右下角坐标偏移量相加,得到最终的旋转后的边界框坐标return torch.cat([xy, lt + rb], dim=dim)

.\yolov8\ultralytics\utils\torch_utils.py

# Ultralytics YOLO 🚀, AGPL-3.0 licenseimport gc  # 导入垃圾回收模块
import math  # 导入数学模块
import os  # 导入操作系统模块
import random  # 导入随机数模块
import time  # 导入时间模块
from contextlib import contextmanager  # 导入上下文管理器模块
from copy import deepcopy  # 导入深拷贝函数
from datetime import datetime  # 导入日期时间模块
from pathlib import Path  # 导入路径模块
from typing import Union  # 导入类型注解import numpy as np  # 导入NumPy库
import torch  # 导入PyTorch库
import torch.distributed as dist  # 导入PyTorch分布式训练模块
import torch.nn as nn  # 导入PyTorch神经网络模块
import torch.nn.functional as F  # 导入PyTorch函数模块from ultralytics.utils import (  # 导入Ultralytics工具函数DEFAULT_CFG_DICT,  # 默认配置字典DEFAULT_CFG_KEYS,  # 默认配置键列表LOGGER,  # 日志记录器NUM_THREADS,  # 线程数PYTHON_VERSION,  # Python版本TORCHVISION_VERSION,  # TorchVision版本__version__,  # Ultralytics版本colorstr,  # 字符串颜色化函数
)
from ultralytics.utils.checks import check_version  # 导入版本检查函数try:import thop  # 尝试导入thop库
except ImportError:thop = None  # 如果导入失败,设为None# Version checks (all default to version>=min_version)
TORCH_1_9 = check_version(torch.__version__, "1.9.0")  # 检查PyTorch版本是否>=1.9.0
TORCH_1_13 = check_version(torch.__version__, "1.13.0")  # 检查PyTorch版本是否>=1.13.0
TORCH_2_0 = check_version(torch.__version__, "2.0.0")  # 检查PyTorch版本是否>=2.0.0
TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")  # 检查TorchVision版本是否>=0.10.0
TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")  # 检查TorchVision版本是否>=0.11.0
TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")  # 检查TorchVision版本是否>=0.13.0
TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0")  # 检查TorchVision版本是否>=0.18.0@contextmanager
def torch_distributed_zero_first(local_rank: int):"""Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first."""initialized = dist.is_available() and dist.is_initialized()  # 检查是否启用了分布式训练且是否已初始化if initialized and local_rank not in {-1, 0}:  # 如果初始化且当前进程不是主进程(rank 0)dist.barrier(device_ids=[local_rank])  # 等待本地主节点(rank 0)完成任务yield  # 执行上下文管理器的主体部分if initialized and local_rank == 0:  # 如果初始化且当前进程是主进程(rank 0)dist.barrier(device_ids=[0])  # 确保所有进程在继续之前都等待主进程完成def smart_inference_mode():"""Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""def decorate(fn):"""Applies appropriate torch decorator for inference mode based on torch version."""if TORCH_1_9 and torch.is_inference_mode_enabled():return fn  # 如果已启用推断模式,直接返回函数else:return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)  # 根据版本选择合适的推断模式装饰器return decoratedef autocast(enabled: bool, device: str = "cuda"):"""Get the appropriate autocast context manager based on PyTorch version and AMP setting.This function returns a context manager for automatic mixed precision (AMP) training that is compatible with botholder and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.Args:enabled (bool): Whether to enable automatic mixed precision.device (str, optional): The device to use for autocast. Defaults to 'cuda'.Returns:(torch.amp.autocast): The appropriate autocast context manager.Note:- For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.- For older versions, it uses `torch.cuda.autocast`.Example:```pywith autocast(amp=True):# Your mixed precision operations herepass```"""# 如果 TORCH_1_13 变量为真,使用 torch.amp.autocast 方法开启自动混合精度模式if TORCH_1_13:return torch.amp.autocast(device, enabled=enabled)# 如果 TORCH_1_13 变量为假,使用 torch.cuda.amp.autocast 方法开启自动混合精度模式else:return torch.cuda.amp.autocast(enabled)
def get_cpu_info():"""Return a string with system CPU information, i.e. 'Apple M2'."""import cpuinfo  # 导入cpuinfo库,用于获取CPU信息,需使用pip安装py-cpuinfok = "brand_raw", "hardware_raw", "arch_string_raw"  # 按优先顺序列出信息键(并非所有键始终可用)info = cpuinfo.get_cpu_info()  # 获取CPU信息的字典string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")  # 提取CPU信息字符串return string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")  # 处理特殊字符后返回CPU信息字符串def select_device(device="", batch=0, newline=False, verbose=True):"""Selects the appropriate PyTorch device based on the provided arguments.The function takes a string specifying the device or a torch.device object and returns a torch.device objectrepresenting the selected device. The function also validates the number of available devices and raises anexception if the requested device(s) are not available.Args:device (str | torch.device, optional): Device string or torch.device object.Options are 'None', 'cpu', or 'cuda', or '0' or '0,1,2,3'. Defaults to an empty string, which auto-selectsthe first available GPU, or CPU if no GPU is available.batch (int, optional): Batch size being used in your model. Defaults to 0.newline (bool, optional): If True, adds a newline at the end of the log string. Defaults to False.verbose (bool, optional):elif device:  # 非 CPU 设备请求时执行以下操作if device == "cuda":device = "0"visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)os.environ["CUDA_VISIBLE_DEVICES"] = device  # 设置环境变量,必须在检查可用性之前设置if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):LOGGER.info(s)  # 记录信息到日志install = ("See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no ""CUDA devices are seen by torch.\n"if torch.cuda.device_count() == 0else "")raise ValueError(f"Invalid CUDA 'device={device}' requested."f" Use 'device=cpu' or pass valid CUDA device(s) if available,"f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"f"{install}")if not cpu and not mps and torch.cuda.is_available():  # 如果可用且未请求 CPU 或 MPSdevices = device.split(",") if device else "0"  # 定义设备列表,默认为 "0"n = len(devices)  # 设备数量if n > 1:  # 多 GPU 情况if batch < 1:raise ValueError("AutoBatch with batch<1 not supported for Multi-GPU training, ""please specify a valid batch size, i.e. batch=16.")if batch >= 0 and batch % n != 0:  # 检查 batch_size 是否可以被设备数量整除raise ValueError(f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}.")space = " " * (len(s) + 1)  # 创建空格串for i, d in enumerate(devices):p = torch.cuda.get_device_properties(i)s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # 字符串拼接 GPU 信息arg = "cuda:0"  # 设置 CUDA 设备为默认值 "cuda:0"elif mps and TORCH_2_0 and torch.backends.mps.is_available():# 如果支持 MPS 并且满足条件,则优先选择 MPSs += f"MPS ({get_cpu_info()})\n"  # 添加 MPS 信息到字符串arg = "mps"  # 设置设备类型为 "mps"else:  # 否则,默认使用 CPUs += f"CPU ({get_cpu_info()})\n"  # 添加 CPU 信息到字符串arg = "cpu"  # 设置设备类型为 "cpu"if arg in {"cpu", "mps"}:torch.set_num_threads(NUM_THREADS)  # 设置 CPU 训练的线程数if verbose:LOGGER.info(s if newline else s.rstrip())  # 如果需要详细输出,则记录详细信息到日志return torch.device(arg)  # 返回对应的 Torch 设备对象
# 返回当前系统时间,确保在使用 PyTorch 时精确同步时间
def time_sync():"""PyTorch-accurate time."""# 如果 CUDA 可用,同步 CUDA 计算的时间if torch.cuda.is_available():torch.cuda.synchronize()# 返回当前时间戳return time.time()# 将 Conv2d() 和 BatchNorm2d() 层融合,实现优化 https://tehnokv.com/posts/fusing-batchnorm-and-conv/
def fuse_conv_and_bn(conv, bn):"""Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""# 创建融合后的卷积层对象fusedconv = (nn.Conv2d(conv.in_channels,conv.out_channels,kernel_size=conv.kernel_size,stride=conv.stride,padding=conv.padding,dilation=conv.dilation,groups=conv.groups,bias=True,).requires_grad_(False)  # 禁用梯度追踪,不需要反向传播训练.to(conv.weight.device)  # 将融合后的卷积层移到与输入卷积层相同的设备上)# 准备卷积层的权重w_conv = conv.weight.clone().view(conv.out_channels, -1)# 计算融合后的权重w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))# 准备空间偏置项b_conv = torch.zeros(conv.weight.shape[0], device=conv.weight.device) if conv.bias is None else conv.bias# 计算融合后的偏置项b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)return fusedconv# 将 ConvTranspose2d() 和 BatchNorm2d() 层融合
def fuse_deconv_and_bn(deconv, bn):"""Fuse ConvTranspose2d() and BatchNorm2d() layers."""# 创建融合后的反卷积层对象fuseddconv = (nn.ConvTranspose2d(deconv.in_channels,deconv.out_channels,kernel_size=deconv.kernel_size,stride=deconv.stride,padding=deconv.padding,output_padding=deconv.output_padding,dilation=deconv.dilation,groups=deconv.groups,bias=True,).requires_grad_(False)  # 禁用梯度追踪,不需要反向传播训练.to(deconv.weight.device)  # 将融合后的反卷积层移到与输入反卷积层相同的设备上)# 准备反卷积层的权重w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)# 计算融合后的权重w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))# 准备空间偏置项b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias# 计算融合后的偏置项b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)return fuseddconv# 输出模型的信息,包括参数数量、梯度数量和层的数量
def model_info(model, detailed=False, verbose=True, imgsz=640):"""Model information.imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]."""# 如果不需要详细信息,则直接返回if not verbose:return# 获取模型的参数数量n_p = get_num_params(model)  # number of parameters# 获取模型的梯度数量n_g = get_num_gradients(model)  # number of gradients# 获取模型的层数量n_l = len(list(model.modules()))  # number of layers# 如果 detailed 参数为 True,则输出详细的模型参数信息if detailed:# 使用 LOGGER 记录模型参数的详细信息表头,包括层编号、名称、梯度是否计算、参数数量、形状、平均值、标准差和数据类型LOGGER.info(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")# 遍历模型的所有命名参数,并给每个参数分配一个序号 ifor i, (name, p) in enumerate(model.named_parameters()):# 去除参数名中的 "module_list." 字符串name = name.replace("module_list.", "")# 使用 LOGGER 记录每个参数的详细信息,包括序号、名称、是否需要梯度、参数数量、形状、平均值、标准差和数据类型LOGGER.info("%5g %40s %9s %12g %20s %10.3g %10.3g %10s"% (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype))# 计算模型的浮点运算量(FLOPs)flops = get_flops(model, imgsz)# 检查模型是否支持融合计算,如果支持,则添加 " (fused)" 到输出中fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""# 如果计算得到的 FLOPs 不为空,则添加到输出中fs = f", {flops:.1f} GFLOPs" if flops else ""# 获取模型的 YAML 文件路径或者直接从模型属性中获取 YAML 文件路径,并去除路径中的 "yolo" 替换为 "YOLO",或默认为 "Model"yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"# 使用 LOGGER 记录模型的总结信息,包括模型名称、层数量、参数数量、梯度数量和计算量信息LOGGER.info(f"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}")# 返回模型的层数量、参数数量、梯度数量和计算量return n_l, n_p, n_g, flops
# 返回 YOLO 模型中的总参数数量
def get_num_params(model):return sum(x.numel() for x in model.parameters())# 返回 YOLO 模型中具有梯度的参数总数
def get_num_gradients(model):return sum(x.numel() for x in model.parameters() if x.requires_grad)# 为日志记录器返回包含有用模型信息的字典
def model_info_for_loggers(trainer):if trainer.args.profile:  # 如果需要进行 ONNX 和 TensorRT 的性能分析from ultralytics.utils.benchmarks import ProfileModels# 使用 ProfileModels 进行模型性能分析,获取结果results = ProfileModels([trainer.last], device=trainer.device).profile()[0]results.pop("model/name")  # 移除结果中的模型名称else:  # 否则仅返回最近验证的 PyTorch 时间信息results = {"model/parameters": get_num_params(trainer.model),  # 计算模型参数数量"model/GFLOPs": round(get_flops(trainer.model), 3),  # 计算模型的 GFLOPs}results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3)  # 记录 PyTorch 推理速度return results# 返回 YOLO 模型的 FLOPs(浮点运算数)
def get_flops(model, imgsz=640):if not thop:return 0.0  # 如果 thop 包未安装,返回 0.0 GFLOPstry:model = de_parallel(model)  # 取消模型的并行化p = next(model.parameters())if not isinstance(imgsz, list):imgsz = [imgsz, imgsz]  # 如果 imgsz 是 int 或 float,扩展为列表try:stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32  # 获取输入张量的步幅大小im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # 创建输入图像张量flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # 使用 thop 计算 GFLOPsreturn flops * imgsz[0] / stride * imgsz[1] / stride  # 计算基于图像尺寸的 GFLOPsexcept Exception:im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # 创建输入图像张量return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # 计算基于图像尺寸的 GFLOPsexcept Exception:return 0.0  # 发生异常时返回 0.0 GFLOPs# 使用 Torch 分析器计算模型的 FLOPs(thop 包的替代方案,但速度通常较慢 2-10 倍)
def get_flops_with_torch_profiler(model, imgsz=640):if not TORCH_2_0:  # 如果 Torch 版本低于 2.0,返回 0.0return 0.0model = de_parallel(model)  # 取消模型的并行化p = next(model.parameters())if not isinstance(imgsz, list):imgsz = [imgsz, imgsz]  # 如果 imgsz 是 int 或 float,扩展为列表try:# 使用模型的步幅大小来确定输入张量的步幅stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2  # 最大步幅# 创建一个空的张量作为输入图像,格式为BCHWim = torch.empty((1, p.shape[1], stride, stride), device=p.device)with torch.profiler.profile(with_flops=True) as prof:# 对模型进行推理,记录性能指标model(im)# 计算模型的浮点运算量(FLOPs)flops = sum(x.flops for x in prof.key_averages()) / 1e9# 根据输入图像大小调整计算的FLOPs,例如 640x640 GFLOPsflops = flops * imgsz[0] / stride * imgsz[1] / strideexcept Exception:# 对于RTDETR模型,使用实际图像大小作为输入张量的大小im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # 输入图像为BCHW格式with torch.profiler.profile(with_flops=True) as prof:# 对模型进行推理,记录性能指标model(im)# 计算模型的浮点运算量(FLOPs)flops = sum(x.flops for x in prof.key_averages()) / 1e9# 返回计算得到的FLOPsreturn flops
def initialize_weights(model):"""Initialize model weights to random values."""# Iterate over all modules in the modelfor m in model.modules():t = type(m)# Check if the module is a 2D convolutional layerif t is nn.Conv2d:pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')# Check if the module is a 2D batch normalization layerelif t is nn.BatchNorm2d:# Set epsilon (eps) and momentum parametersm.eps = 1e-3m.momentum = 0.03# Check if the module is one of the specified activation functionselif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}:# Enable inplace operation for the activation functionm.inplace = Truedef scale_img(img, ratio=1.0, same_shape=False, gs=32):"""Scales and pads an image tensor of shape img(bs,3,y,x) based on given ratio and grid size gs, optionallyretaining the original shape."""# If ratio is 1.0, return the original image tensorif ratio == 1.0:return img# Retrieve height and width from the image tensor shapeh, w = img.shape[2:]# Compute the new scaled size based on the given ratios = (int(h * ratio), int(w * ratio))  # new size# Resize the image tensor using bilinear interpolationimg = F.interpolate(img, size=s, mode="bilinear", align_corners=False)  # resize# If not retaining the original shape, pad or crop the image tensorif not same_shape:# Calculate the padded height and width based on the ratio and grid sizeh, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))# Pad the image tensor to match the calculated dimensionsreturn F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet meandef copy_attr(a, b, include=(), exclude=()):"""Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""# Iterate through attributes in object 'b'for k, v in b.__dict__.items():# Skip attributes based on conditions: not in include list, starts with '_', or in exclude listif (len(include) and k not in include) or k.startswith("_") or k in exclude:continueelse:# Set attribute 'k' in object 'a' to the value 'v' from object 'b'setattr(a, k, v)def get_latest_opset():"""Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity."""# Check if using PyTorch version 1.13 or newerif TORCH_1_13:# Dynamically compute the second-most recent ONNX opset version supportedreturn max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1# For PyTorch versions <= 1.12, return predefined opset versionsversion = torch.onnx.producer_version.rsplit(".", 1)[0]  # i.e. '2.3'return {"1.12": 15, "1.11": 14, "1.10": 13, "1.9": 12, "1.8": 12}.get(version, 12)def intersect_dicts(da, db, exclude=()):"""Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values."""# Create a dictionary comprehension to filter keys based on conditionsreturn {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}def is_parallel(model):"""Returns True if model is of type DP or DDP."""# Check if the model is an instance of DataParallel or DistributedDataParallelreturn isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))def de_parallel(model):"""De-parallelize a model: returns single-GPU model if model is of type DP or DDP."""# Return the underlying module of a DataParallel or DistributedDataParallel modelreturn model.module if is_parallel(model) else modeldef one_cycle(y1=0.0, y2=1.0, steps=100):"""Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf."""# Generate a lambda function that implements a sinusoidal rampreturn lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1def init_seeds(seed=0, deterministic=False):"""Initialize random number generator seeds."""# This function initializes seeds for random number generators# It is intended to be implemented further, but the current snippet does not contain the complete implementation.pass# 初始化随机数生成器(RNG)种子,以确保实验的可复现性 https://pytorch.org/docs/stable/notes/randomness.html.random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)  # 用于多GPU情况下的种子设置,确保异常安全性# torch.backends.cudnn.benchmark = True  # AutoBatch问题 https://github.com/ultralytics/yolov5/issues/9287# 如果需要确定性行为,则执行以下操作if deterministic:if TORCH_2_0:# 使用确定性算法,并在不可确定时发出警告torch.use_deterministic_algorithms(True, warn_only=True)torch.backends.cudnn.deterministic = True# 设置CUBLAS工作空间大小的配置os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"os.environ["PYTHONHASHSEED"] = str(seed)else:# 提示升级到torch>=2.0.0以实现确定性训练LOGGER.warning("WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.")else:# 关闭确定性算法,允许非确定性行为torch.use_deterministic_algorithms(False)torch.backends.cudnn.deterministic = False
class ModelEMA:"""Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models. Keeps a movingaverage of everything in the model state_dict (parameters and buffers)For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverageTo disable EMA set the `enabled` attribute to `False`."""def __init__(self, model, decay=0.9999, tau=2000, updates=0):"""Initialize EMA for 'model' with given arguments."""self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMAself.updates = updates  # number of EMA updatesself.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)for p in self.ema.parameters():p.requires_grad_(False)self.enabled = Truedef update(self, model):"""Update EMA parameters."""if self.enabled:self.updates += 1d = self.decay(self.updates)msd = de_parallel(model).state_dict()  # model state_dictfor k, v in self.ema.state_dict().items():if v.dtype.is_floating_point:  # true for FP16 and FP32v *= dv += (1 - d) * msd[k].detach()# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype},  model {msd[k].dtype}'def update_attr(self, model, include=(), exclude=("process_group", "reducer")):"""Updates attributes and saves stripped model with optimizer removed."""if self.enabled:copy_attr(self.ema, model, include, exclude)def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:"""Strip optimizer from 'f' to finalize training, optionally save as 's'.Args:f (str): file path to model to strip the optimizer from. Default is 'best.pt'.s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.Returns:NoneExample:```pyfrom pathlib import Pathfrom ultralytics.utils.torch_utils import strip_optimizerfor f in Path('path/to/model/checkpoints').rglob('*.pt'):strip_optimizer(f)```Note:Use `ultralytics.nn.torch_safe_load` for missing modules with `x = torch_safe_load(f)[0]`"""try:x = torch.load(f, map_location=torch.device("cpu"))assert isinstance(x, dict), "checkpoint is not a Python dictionary"assert "model" in x, "'model' missing from checkpoint"except Exception as e:LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")returnupdates = {"date": datetime.now().isoformat(),"version": __version__,"license": "AGPL-3.0 License (https://ultralytics.com/license)","docs": "https://docs.ultralytics.com",}# Update model# 如果字典 x 中有 "ema" 键,则将 "model" 键的值设为 "ema" 的值,替换模型为 EMA 模型if x.get("ema"):x["model"] = x["ema"]  # replace model with EMA# 如果 "model" 对象具有 "args" 属性,将其转换为字典类型,从 IterableSimpleNamespace 转换为 dictif hasattr(x["model"], "args"):x["model"].args = dict(x["model"].args)  # convert from IterableSimpleNamespace to dict# 如果 "model" 对象具有 "criterion" 属性,将其设置为 None,去除损失函数的标准if hasattr(x["model"], "criterion"):x["model"].criterion = None  # strip loss criterion# 将模型转换为半精度浮点数表示,即 FP16x["model"].half()  # to FP16# 将模型的所有参数设置为不需要梯度计算for p in x["model"].parameters():p.requires_grad = False# 更新字典中的其他键args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})}  # 将 DEFAULT_CFG_DICT 和 x 中的 "train_args" 合并为一个字典for k in "optimizer", "best_fitness", "ema", "updates":  # 遍历指定的键x[k] = None  # 将字典 x 中指定键的值设为 Nonex["epoch"] = -1  # 将 epoch 键的值设为 -1# 创建一个新字典,其中仅包含 DEFAULT_CFG_KEYS 中存在的键值对,并将其赋给 "train_args"x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys# x['model'].args = x['train_args']  # 此行代码被注释掉了,不再使用# 将 updates 和 x 中的内容合并为一个字典,并保存到文件中,不使用 dill 序列化torch.save({**updates, **x}, s or f, use_dill=False)  # combine dicts (prefer to the right)# 获取文件的大小,并将其转换为兆字节(MB)mb = os.path.getsize(s or f) / 1e6  # file size# 记录日志,显示优化器已从文件中剥离,同时显示文件名和文件大小LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
# 将给定优化器的状态字典转换为FP16格式,重点在于转换'torch.Tensor'类型的数据
def convert_optimizer_state_dict_to_fp16(state_dict):# 遍历优化器状态字典中的'state'键对应的所有状态for state in state_dict["state"].values():# 遍历每个状态的键值对for k, v in state.items():# 排除键为"step"且值为'torch.Tensor'类型且数据类型为torch.float32的情况if k != "step" and isinstance(v, torch.Tensor) and v.dtype is torch.float32:# 将符合条件的Tensor类型数据转换为半精度(FP16)state[k] = v.half()# 返回转换后的状态字典return state_dict# Ultralytics速度、内存和FLOPs(浮点运算数)分析器
def profile(input, ops, n=10, device=None):# 结果存储列表results = []# 如果设备参数不是torch.device类型,则选择设备if not isinstance(device, torch.device):device = select_device(device)# 打印日志信息,包括各项参数LOGGER.info(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"f"{'input':>24s}{'output':>24s}")for x in input if isinstance(input, list) else [input]:# 如果输入是列表,则遍历列表中的每个元素;否则将输入放入列表中并遍历x = x.to(device)# 将当前元素移动到指定的设备上(如GPU)x.requires_grad = True# 设置当前元素的梯度跟踪为Truefor m in ops if isinstance(ops, list) else [ops]:# 如果操作是列表,则遍历列表中的每个操作;否则将操作放入列表中并遍历m = m.to(device) if hasattr(m, "to") else m# 如果操作具有"to"方法,则将其移动到指定的设备上;否则保持不变m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m# 如果操作具有"half"方法,并且输入是torch.Tensor类型且数据类型为torch.float16,则将操作转换为半精度(float16);否则保持不变tf, tb, t = 0, 0, [0, 0, 0]# 初始化时间记录变量:前向传播时间,反向传播时间,时间记录列表try:flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0# 使用thop库对操作进行浮点操作计算(FLOPs),并将结果转换为GFLOPs(十亿次浮点操作每秒)except Exception:flops = 0# 如果计算FLOPs出现异常,则将FLOPs设置为0try:for _ in range(n):t[0] = time_sync()# 记录前向传播开始时间y = m(x)# 执行操作的前向传播t[1] = time_sync()# 记录前向传播结束时间try:(sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()# 计算输出y的总和,并对总和进行反向传播t[2] = time_sync()# 记录反向传播结束时间except Exception:  # no backward methodt[2] = float("nan")# 如果没有反向传播方法,则将反向传播时间设置为NaNtf += (t[1] - t[0]) * 1000 / n# 计算每个操作的平均前向传播时间(毫秒)tb += (t[2] - t[1]) * 1000 / n# 计算每个操作的平均反向传播时间(毫秒)mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0# 如果CUDA可用,则计算当前GPU上的内存使用量(单位:GB)s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y))# 获取输入x和输出y的形状信息p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0# 计算操作m中的参数数量LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")# 将结果记录到日志中,包括参数数量、FLOPs、内存占用、时间等信息results.append([p, flops, mem, tf, tb, s_in, s_out])# 将结果添加到结果列表中except Exception as e:LOGGER.info(e)# 记录异常信息到日志中results.append(None)# 将空结果添加到结果列表中gc.collect()# 尝试释放未使用的内存torch.cuda.empty_cache()# 清空CUDA缓存return results# 返回所有操作的结果列表
class EarlyStopping:"""Early stopping class that stops training when a specified number of epochs have passed without improvement."""def __init__(self, patience=50):"""Initialize early stopping object.Args:patience (int, optional): Number of epochs to wait after fitness stops improving before stopping."""self.best_fitness = 0.0  # 初始化最佳适应度为0.0,即最佳平均精度(mAP)self.best_epoch = 0  # 初始化最佳轮次为0self.patience = patience or float("inf")  # 设置等待适应度停止提高的轮次数,若未提供则设为无穷大self.possible_stop = False  # 是否可能在下一个轮次停止训练的标志def __call__(self, epoch, fitness):"""Check whether to stop training.Args:epoch (int): Current epoch of trainingfitness (float): Fitness value of current epochReturns:(bool): True if training should stop, False otherwise"""if fitness is None:  # 检查适应度是否为None(当val=False时会发生)return Falseif fitness >= self.best_fitness:  # 如果当前适应度大于或等于最佳适应度self.best_epoch = epoch  # 更新最佳轮次为当前轮次self.best_fitness = fitness  # 更新最佳适应度为当前适应度delta = epoch - self.best_epoch  # 计算未改善的轮次数self.possible_stop = delta >= (self.patience - 1)  # 更新可能在下一个轮次停止训练的标志stop = delta >= self.patience  # 若未改善的轮次数超过设定的等待轮次数,则停止训练if stop:prefix = colorstr("EarlyStopping: ")  # 设置输出前缀LOGGER.info(f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. "f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping.")  # 输出停止训练信息return stop  # 返回是否停止训练的标志

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ryyt.cn/news/56845.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈,一经查实,立即删除!

相关文章

Yolov8-源码解析-四十二-

Yolov8 源码解析(四十二) .\yolov8\ultralytics\utils\loss.py # 导入PyTorch库中需要的模块 import torch import torch.nn as nn import torch.nn.functional as F# 从Ultralytics工具包中导入一些特定的功能 from ultralytics.utils.metrics import OKS_SIGMA from ultral…

Yolov8-源码解析-二十六-

Yolov8 源码解析(二十六) .\yolov8\tests\test_engine.py # 导入所需的模块和库 import sys # 系统模块 from unittest import mock # 导入 mock 模块# 导入自定义模块和类 from tests import MODEL # 导入 tests 模块中的 MODEL 对象 from ultralytics import YOLO # 导…

Yolov8-源码解析-二十八-

Yolov8 源码解析(二十八) .\yolov8\ultralytics\data\base.py # Ultralytics YOLO 🚀, AGPL-3.0 licenseimport glob # 导入用于获取文件路径的模块 import math # 导入数学函数模块 import os # 导入操作系统功能模块 import random # 导入生成随机数的模块 from copy…

Yolov8-源码解析-八-

Yolov8 源码解析(八)comments: true description: Learn how to manage and optimize queues using Ultralytics YOLOv8 to reduce wait times and increase efficiency in various real-world applications. keywords: queue management, YOLOv8, Ultralytics, reduce wait …

FLUX 源码解析(全)

.\flux\demo_gr.py # 导入操作系统相关模块 import os # 导入时间相关模块 import time # 从 io 模块导入 BytesIO 类 from io import BytesIO # 导入 UUID 生成模块 import uuid# 导入 PyTorch 库 import torch # 导入 Gradio 库 import gradio as gr # 导入 NumPy 库 import …

【优技教育】Oracle 19c OCP 082题库(第13题)- 2024年修正版

【优技教育】Oracle 19c OCP 082题库(Q 13题)- 2024年修正版 考试科目:1Z0-082 考试题量:90 通过分数:60% 考试时间:150min 本文为(CUUG 原创)整理并解析,转发请注明出处,禁止抄袭及未经注明出处的转载。 原文地址:http://www.cuug.com.cn/ocp/082kaoshitiku/3817564823…

最快捷查看电脑启动项内容

很多人好奇很多电脑的默认启动项从哪里的看,其实就在运行窗口开两个命令就行了。 第一个,看先用户端设置的启动项: shell:Startup 这个是针对当前登录用户的。 第二个,查看电脑最高权限的通用启动项shell:Common Startup 这个是针对所有用户的。 操作的方式很简单 就是把要…

react中使用echarts关系图

一,工作需求,展示几类数据关系,可缩放大小,可拖拽位置,在节点之间的连线上展示相关日期,每个节点展示本身信息,并且要求每个关系节点能点击。 实现情况如图所示:二,实现过程中遇到的问题: 关系图完美呈现,但关系节点点击后,整个关系图会杂乱无章的浮动,导致不知道…