Yolov8-源码解析-二十六-

news/2024/9/24 13:43:47

Yolov8 源码解析(二十六)

.\yolov8\tests\test_engine.py

# 导入所需的模块和库
import sys  # 系统模块
from unittest import mock  # 导入 mock 模块# 导入自定义模块和类
from tests import MODEL  # 导入 tests 模块中的 MODEL 对象
from ultralytics import YOLO  # 导入 ultralytics 库中的 YOLO 类
from ultralytics.cfg import get_cfg  # 导入 ultralytics 库中的 get_cfg 函数
from ultralytics.engine.exporter import Exporter  # 导入 ultralytics 库中的 Exporter 类
from ultralytics.models.yolo import classify, detect, segment  # 导入 ultralytics 库中的 classify, detect, segment 函数
from ultralytics.utils import ASSETS, DEFAULT_CFG, WEIGHTS_DIR  # 导入 ultralytics 库中的 ASSETS, DEFAULT_CFG, WEIGHTS_DIR 变量def test_func(*args):  # 定义测试函数,用于评估 YOLO 模型性能指标"""Test function callback for evaluating YOLO model performance metrics."""print("callback test passed")  # 打印测试通过消息def test_export():"""Tests the model exporting function by adding a callback and asserting its execution."""exporter = Exporter()  # 创建 Exporter 对象exporter.add_callback("on_export_start", test_func)  # 添加回调函数到导出开始事件assert test_func in exporter.callbacks["on_export_start"], "callback test failed"  # 断言回调函数已成功添加f = exporter(model=YOLO("yolov8n.yaml").model)  # 导出模型YOLO(f)(ASSETS)  # 使用导出后的模型进行推理def test_detect():"""Test YOLO object detection training, validation, and prediction functionality."""overrides = {"data": "coco8.yaml", "model": "yolov8n.yaml", "imgsz": 32, "epochs": 1, "save": False}  # 定义参数覆盖字典cfg = get_cfg(DEFAULT_CFG)  # 获取默认配置cfg.data = "coco8.yaml"  # 设置配置数据文件cfg.imgsz = 32  # 设置配置图像尺寸# Trainertrainer = detect.DetectionTrainer(overrides=overrides)  # 创建检测训练器对象trainer.add_callback("on_train_start", test_func)  # 添加回调函数到训练开始事件assert test_func in trainer.callbacks["on_train_start"], "callback test failed"  # 断言回调函数已成功添加trainer.train()  # 执行训练# Validatorval = detect.DetectionValidator(args=cfg)  # 创建检测验证器对象val.add_callback("on_val_start", test_func)  # 添加回调函数到验证开始事件assert test_func in val.callbacks["on_val_start"], "callback test failed"  # 断言回调函数已成功添加val(model=trainer.best)  # 使用最佳模型进行验证# Predictorpred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]})  # 创建检测预测器对象pred.add_callback("on_predict_start", test_func)  # 添加回调函数到预测开始事件assert test_func in pred.callbacks["on_predict_start"], "callback test failed"  # 断言回调函数已成功添加# 确认 sys.argv 为空没有问题with mock.patch.object(sys, "argv", []):result = pred(source=ASSETS, model=MODEL)  # 执行预测assert len(result), "predictor test failed"  # 断言预测结果不为空overrides["resume"] = trainer.last  # 设置训练器的恢复模型trainer = detect.DetectionTrainer(overrides=overrides)  # 创建新的检测训练器对象try:trainer.train()  # 执行训练except Exception as e:print(f"Expected exception caught: {e}")  # 捕获并打印预期的异常returnException("Resume test failed!")  # 报告恢复测试失败def test_segment():"""Tests image segmentation training, validation, and prediction pipelines using YOLO models."""overrides = {"data": "coco8-seg.yaml", "model": "yolov8n-seg.yaml", "imgsz": 32, "epochs": 1, "save": False}  # 定义参数覆盖字典cfg = get_cfg(DEFAULT_CFG)  # 获取默认配置cfg.data = "coco8-seg.yaml"  # 设置配置数据文件cfg.imgsz = 32  # 设置配置图像尺寸# YOLO(CFG_SEG).train(**overrides)  # works# Trainertrainer = segment.SegmentationTrainer(overrides=overrides)  # 创建分割训练器对象trainer.add_callback("on_train_start", test_func)  # 添加回调函数到训练开始事件assert test_func in trainer.callbacks["on_train_start"], "callback test failed"  # 断言回调函数已成功添加trainer.train()  # 执行训练# Validatorval = segment.SegmentationValidator(args=cfg)  # 创建分割验证器对象# 添加回调函数到“on_val_start”事件,使其在val对象开始时调用test_func函数val.add_callback("on_val_start", test_func)# 断言确认test_func确实添加到val对象的“on_val_start”事件回调列表中assert test_func in val.callbacks["on_val_start"], "callback test failed"# 使用trainer.best模型对val对象进行验证,验证best.pt模型的性能val(model=trainer.best)  # validate best.pt# 创建SegmentationPredictor对象pred,覆盖参数imgsz为[64, 64]pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]})# 添加回调函数到“on_predict_start”事件,使其在pred对象开始预测时调用test_func函数pred.add_callback("on_predict_start", test_func)# 断言确认test_func确实添加到pred对象的“on_predict_start”事件回调列表中assert test_func in pred.callbacks["on_predict_start"], "callback test failed"# 使用指定的模型进行预测,源数据为ASSETS,模型为WEIGHTS_DIR / "yolov8n-seg.pt"result = pred(source=ASSETS, model=WEIGHTS_DIR / "yolov8n-seg.pt")# 断言确保结果非空,验证预测器的功能assert len(result), "predictor test failed"# 测试恢复功能overrides["resume"] = trainer.last  # 设置恢复参数为trainer的最后状态trainer = segment.SegmentationTrainer(overrides=overrides)  # 使用指定参数创建SegmentationTrainer对象try:trainer.train()  # 尝试训练模型except Exception as e:# 捕获异常并输出异常信息print(f"Expected exception caught: {e}")return# 如果发生异常未被捕获,则抛出异常信息“Resume test failed!”Exception("Resume test failed!")
def test_classify():"""Test image classification including training, validation, and prediction phases."""# 定义需要覆盖的配置项overrides = {"data": "imagenet10", "model": "yolov8n-cls.yaml", "imgsz": 32, "epochs": 1, "save": False# 根据默认配置获取配置对象cfg = get_cfg(DEFAULT_CFG)# 调整配置项中的数据集为 imagenet10cfg.data = "imagenet10"# 调整配置项中的图片尺寸为 32cfg.imgsz = 32# YOLO(CFG_SEG).train(**overrides)  # works# 创建分类训练器对象,应用 overrides 中的配置项trainer = classify.ClassificationTrainer(overrides=overrides)# 添加在训练开始时执行的回调函数 test_functrainer.add_callback("on_train_start", test_func)# 断言 test_func 是否成功添加到训练器的 on_train_start 回调中assert test_func in trainer.callbacks["on_train_start"], "callback test failed"# 开始训练trainer.train()# 创建分类验证器对象,使用 cfg 中的配置项val = classify.ClassificationValidator(args=cfg)# 添加在验证开始时执行的回调函数 test_funcval.add_callback("on_val_start", test_func)# 断言 test_func 是否成功添加到验证器的 on_val_start 回调中assert test_func in val.callbacks["on_val_start"], "callback test failed"# 执行验证,使用训练器中的最佳模型val(model=trainer.best)# 创建分类预测器对象,应用 imgsz 为 [64, 64] 的配置项pred = classify.ClassificationPredictor(overrides={"imgsz": [64, 64]})# 添加在预测开始时执行的回调函数 test_funcpred.add_callback("on_predict_start", test_func)# 断言 test_func 是否成功添加到预测器的 on_predict_start 回调中assert test_func in pred.callbacks["on_predict_start"], "callback test failed"# 使用 ASSETS 中的数据源和训练器中的最佳模型进行预测result = pred(source=ASSETS, model=trainer.best)# 断言预测结果不为空,表示预测器测试通过assert len(result), "predictor test failed"

.\yolov8\tests\test_explorer.py

# 导入必要的库和模块:PIL 图像处理库和 pytest 测试框架
import PIL
import pytest# 从 ultralytics 包中导入 Explorer 类和 ASSETS 资源
from ultralytics import Explorer
from ultralytics.utils import ASSETS# 使用 pytest 的标记 @pytest.mark.slow 标记此函数为慢速测试
@pytest.mark.slow
def test_similarity():"""测试 Explorer 中相似性计算和 SQL 查询的正确性和返回长度。"""# 创建 Explorer 对象,使用配置文件 'coco8.yaml'exp = Explorer(data="coco8.yaml")# 创建嵌入表格exp.create_embeddings_table()# 获取索引为 1 的相似项similar = exp.get_similar(idx=1)# 断言相似项的长度为 4assert len(similar) == 4# 使用图像文件 'bus.jpg' 获取相似项similar = exp.get_similar(img=ASSETS / "bus.jpg")# 断言相似项的长度为 4assert len(similar) == 4# 获取索引为 [1, 2] 的相似项,限制返回结果为 2 个similar = exp.get_similar(idx=[1, 2], limit=2)# 断言相似项的长度为 2assert len(similar) == 2# 获取相似性索引sim_idx = exp.similarity_index()# 断言相似性索引的长度为 4assert len(sim_idx) == 4# 执行 SQL 查询,查询条件为 'labels LIKE '%zebra%''sql = exp.sql_query("WHERE labels LIKE '%zebra%'")# 断言 SQL 查询结果的长度为 1assert len(sql) == 1@pytest.mark.slow
def test_det():"""测试检测功能,并验证嵌入表格是否包含边界框。"""# 创建 Explorer 对象,使用配置文件 'coco8.yaml' 和模型 'yolov8n.pt'exp = Explorer(data="coco8.yaml", model="yolov8n.pt")# 强制创建嵌入表格exp.create_embeddings_table(force=True)# 断言表格中的边界框列的长度大于 0assert len(exp.table.head()["bboxes"]) > 0# 获取索引为 [1, 2] 的相似项,限制返回结果为 10 个similar = exp.get_similar(idx=[1, 2], limit=10)# 断言相似项的长度大于 0assert len(similar) > 0# 执行绘制相似项的操作,返回值应为 PIL 图像对象similar = exp.plot_similar(idx=[1, 2], limit=10)# 断言返回值是 PIL 图像对象assert isinstance(similar, PIL.Image.Image)@pytest.mark.slow
def test_seg():"""测试分割功能,并确保嵌入表格包含分割掩码。"""# 创建 Explorer 对象,使用配置文件 'coco8-seg.yaml' 和模型 'yolov8n-seg.pt'exp = Explorer(data="coco8-seg.yaml", model="yolov8n-seg.pt")# 强制创建嵌入表格exp.create_embeddings_table(force=True)# 断言表格中的分割掩码列的长度大于 0assert len(exp.table.head()["masks"]) > 0# 获取索引为 [1, 2] 的相似项,限制返回结果为 10 个similar = exp.get_similar(idx=[1, 2], limit=10)# 断言相似项的长度大于 0assert len(similar) > 0# 执行绘制相似项的操作,返回值应为 PIL 图像对象similar = exp.plot_similar(idx=[1, 2], limit=10)# 断言返回值是 PIL 图像对象assert isinstance(similar, PIL.Image.Image)@pytest.mark.slow
def test_pose():"""测试姿势估计功能,并验证嵌入表格是否包含关键点。"""# 创建 Explorer 对象,使用配置文件 'coco8-pose.yaml' 和模型 'yolov8n-pose.pt'exp = Explorer(data="coco8-pose.yaml", model="yolov8n-pose.pt")# 强制创建嵌入表格exp.create_embeddings_table(force=True)# 断言表格中的关键点列的长度大于 0assert len(exp.table.head()["keypoints"]) > 0# 获取索引为 [1, 2] 的相似项,限制返回结果为 10 个similar = exp.get_similar(idx=[1, 2], limit=10)# 断言相似项的长度大于 0assert len(similar) > 0# 执行绘制相似项的操作,返回值应为 PIL 图像对象similar = exp.plot_similar(idx=[1, 2], limit=10)# 断言返回值是 PIL 图像对象assert isinstance(similar, PIL.Image.Image)

.\yolov8\tests\test_exports.py

# 导入所需的库和模块
import shutil  # 文件操作工具,用于复制、移动和删除文件和目录
import uuid  # 用于生成唯一的UUID
from itertools import product  # 用于生成迭代器的笛卡尔积
from pathlib import Path  # 用于处理文件和目录路径的类import pytest  # 测试框架# 导入测试所需的模块和函数
from tests import MODEL, SOURCE
from ultralytics import YOLO  # 导入YOLO模型
from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS  # 导入配置信息
from ultralytics.utils import (IS_RASPBERRYPI,  # 检查是否在树莓派上运行LINUX,  # 检查是否在Linux系统上运行MACOS,  # 检查是否在macOS系统上运行WINDOWS,  # 检查是否在Windows系统上运行checks,  # 各种系统和Python版本的检查工具集合
)
from ultralytics.utils.torch_utils import TORCH_1_9, TORCH_1_13  # Torch相关的工具函数和版本检查
# 测试导出 YOLO 模型到 ONNX 格式,使用不同的配置和参数进行测试
def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify):# 调用 YOLO 类,根据任务选择相应的模型,然后导出为 ONNX 格式的文件file = YOLO(TASK2MODEL[task]).export(format="onnx",imgsz=32,dynamic=dynamic,int8=int8,half=half,batch=batch,simplify=simplify,)# 使用导出的模型进行推理,传入相同的源数据多次以达到批处理要求YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32)  # exported model inference# 清理生成的文件Path(file).unlink()  # cleanup@pytest.mark.slow
@pytest.mark.parametrize("task, dynamic, int8, half, batch", product(TASKS, [False], [False], [False], [1, 2]))
# 测试导出 YOLO 模型到 TorchScript 格式,考虑不同的配置和参数组合
def test_export_torchscript_matrix(task, dynamic, int8, half, batch):# 调用 YOLO 类,根据任务选择相应的模型,然后导出为 TorchScript 格式的文件file = YOLO(TASK2MODEL[task]).export(format="torchscript",imgsz=32,dynamic=dynamic,int8=int8,half=half,batch=batch,)# 使用导出的模型进行推理,传入特定的源数据以达到批处理要求YOLO(file)([SOURCE] * 3, imgsz=64 if dynamic else 32)  # exported model inference at batch=3# 清理生成的文件Path(file).unlink()  # cleanup@pytest.mark.slow
# 在 macOS 上测试导出 YOLO 模型到 CoreML 格式,使用各种参数配置
@pytest.mark.skipif(not MACOS, reason="CoreML inference only supported on macOS")
@pytest.mark.skipif(not TORCH_1_9, reason="CoreML>=7.2 not supported with PyTorch<=1.8")
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="CoreML not supported in Python 3.12")
@pytest.mark.parametrize("task, dynamic, int8, half, batch",[  # 生成所有组合,但排除 int8 和 half 都为 True 的情况(task, dynamic, int8, half, batch)for task, dynamic, int8, half, batch in product(TASKS, [False], [True, False], [True, False], [1])if not (int8 and half)  # 排除 int8 和 half 都为 True 的情况],
)
def test_export_coreml_matrix(task, dynamic, int8, half, batch):# 调用 YOLO 类,根据任务选择相应的模型,然后导出为 CoreML 格式的文件file = YOLO(TASK2MODEL[task]).export(format="coreml",imgsz=32,dynamic=dynamic,int8=int8,half=half,batch=batch,)# 使用导出的模型进行推理,传入特定的源数据以达到批处理要求YOLO(file)([SOURCE] * batch, imgsz=32)  # exported model inference at batch=3# 清理生成的文件夹shutil.rmtree(file)  # cleanup@pytest.mark.slow
# 在 Python 版本大于等于 3.10 时,在 Linux 上测试导出 YOLO 模型到 TFLite 格式
@pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10, reason="TFLite export requires Python>=3.10")
@pytest.mark.skipif(not LINUX, reason="Test disabled as TF suffers from install conflicts on Windows and macOS")
@pytest.mark.parametrize("task, dynamic, int8, half, batch",[  # 生成所有组合,但排除 int8 和 half 都为 True 的情况(task, dynamic, int8, half, batch)for task, dynamic, int8, half, batch in product(TASKS, [False], [True, False], [True, False], [1])if not (int8 and half)  # 排除 int8 和 half 都为 True 的情况],
)
# 测试导出 YOLO 模型到 TFLite 格式,考虑各种导出配置
def test_export_tflite_matrix(task, dynamic, int8, half, batch):# 调用 YOLO 类,根据任务选择相应的模型,然后导出为 TFLite 格式的文件file = YOLO(TASK2MODEL[task]).export(format="tflite",imgsz=32,dynamic=dynamic,int8=int8,half=half,batch=batch,)# 使用导出的模型进行推理,传入特定的源数据以达到批处理要求YOLO(file)([SOURCE] * batch, imgsz=32)  # exported model inference at batch=3# 清理生成的文件夹shutil.rmtree(file)  # cleanup# 使用指定任务的模型从YOLO导出模型,并以tflite格式输出到文件file = YOLO(TASK2MODEL[task]).export(format="tflite",imgsz=32,dynamic=dynamic,int8=int8,half=half,batch=batch,)# 使用导出的模型进行推理,输入为[SOURCE]的重复项,批量大小为3,图像尺寸为32YOLO(file)([SOURCE] * batch, imgsz=32)  # 批量大小为3时导出模型的推理# 删除导出的模型文件,进行清理工作Path(file).unlink()  # 清理
# 根据条件跳过测试,若 TORCH_1_9 为假则跳过,提示 PyTorch<=1.8 不支持 CoreML>=7.2
@pytest.mark.skipif(not TORCH_1_9, reason="CoreML>=7.2 not supported with PyTorch<=1.8")
# 若在 Windows 系统上则跳过,提示 CoreML 在 Windows 上不受支持
@pytest.mark.skipif(WINDOWS, reason="CoreML not supported on Windows")  # RuntimeError: BlobWriter not loaded
# 若在树莓派上则跳过,提示 CoreML 在树莓派上不受支持
@pytest.mark.skipif(IS_RASPBERRYPI, reason="CoreML not supported on Raspberry Pi")
# 若 Python 版本为 3.12 则跳过,提示 CoreML 不支持 Python 3.12
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="CoreML not supported in Python 3.12")
def test_export_coreml():"""Test YOLO exports to CoreML format, optimized for macOS only."""if MACOS:# 在 macOS 上导出 YOLO 模型到 CoreML 格式,并优化为指定的 imgsz 大小file = YOLO(MODEL).export(format="coreml", imgsz=32)# 使用导出的 CoreML 模型进行预测,仅支持在 macOS 上进行,对于 nms=False 的模型YOLO(file)(SOURCE, imgsz=32)  # model prediction only supported on macOS for nms=False modelselse:# 在非 macOS 系统上导出 YOLO 模型到 CoreML 格式,使用默认的 nms=True 和指定的 imgsz 大小YOLO(MODEL).export(format="coreml", nms=True, imgsz=32)# 若 Python 版本小于 3.10 则跳过,提示 TFLite 导出要求 Python>=3.10
@pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10, reason="TFLite export requires Python>=3.10")
# 若不在 Linux 系统上则跳过,提示在 Windows 和 macOS 上 TensorFlow 安装可能会冲突
@pytest.mark.skipif(not LINUX, reason="Test disabled as TF suffers from install conflicts on Windows and macOS")
def test_export_tflite():"""Test YOLO exports to TFLite format under specific OS and Python version conditions."""# 创建 YOLO 模型对象model = YOLO(MODEL)# 导出 YOLO 模型到 TFLite 格式,使用指定的 imgsz 大小file = model.export(format="tflite", imgsz=32)# 使用导出的 TFLite 模型进行预测YOLO(file)(SOURCE, imgsz=32)# 直接跳过此测试,无特定原因说明
@pytest.mark.skipif(True, reason="Test disabled")
# 若不在 Linux 系统上则跳过,提示 TensorFlow 在 Windows 和 macOS 上安装可能会冲突
@pytest.mark.skipif(not LINUX, reason="TF suffers from install conflicts on Windows and macOS")
def test_export_pb():"""Test YOLO exports to TensorFlow's Protobuf (*.pb) format."""# 创建 YOLO 模型对象model = YOLO(MODEL)# 导出 YOLO 模型到 TensorFlow 的 Protobuf 格式,使用指定的 imgsz 大小file = model.export(format="pb", imgsz=32)# 使用导出的 Protobuf 模型进行预测YOLO(file)(SOURCE, imgsz=32)# 直接跳过此测试,无特定原因说明
@pytest.mark.skipif(True, reason="Test disabled as Paddle protobuf and ONNX protobuf requirementsk conflict.")
def test_export_paddle():"""Test YOLO exports to Paddle format, noting protobuf conflicts with ONNX."""# 导出 YOLO 模型到 Paddle 格式,使用指定的 imgsz 大小YOLO(MODEL).export(format="paddle", imgsz=32)# 标记为慢速测试
@pytest.mark.slow
def test_export_ncnn():"""Test YOLO exports to NCNN format."""# 导出 YOLO 模型到 NCNN 格式,使用指定的 imgsz 大小file = YOLO(MODEL).export(format="ncnn", imgsz=32)# 使用导出的 NCNN 模型进行预测YOLO(file)(SOURCE, imgsz=32)  # exported model inference

.\yolov8\tests\test_integrations.py

# Ultralytics YOLO 🚀, AGPL-3.0 license# 引入必要的库和模块
import contextlib
import os
import subprocess
import time
from pathlib import Pathimport pytest# 从自定义的模块导入常量和函数
from tests import MODEL, SOURCE, TMP
from ultralytics import YOLO, download
from ultralytics.utils import DATASETS_DIR, SETTINGS
from ultralytics.utils.checks import check_requirements# 使用 pytest 标记,当条件不满足时跳过测试
@pytest.mark.skipif(not check_requirements("ray", install=False), reason="ray[tune] not installed")
def test_model_ray_tune():"""Tune YOLO model using Ray for hyperparameter optimization."""# 调用 YOLO 类来进行模型调参YOLO("yolov8n-cls.yaml").tune(use_ray=True, data="imagenet10", grace_period=1, iterations=1, imgsz=32, epochs=1, plots=False, device="cpu")# 使用 pytest 标记,当条件不满足时跳过测试
@pytest.mark.skipif(not check_requirements("mlflow", install=False), reason="mlflow not installed")
def test_mlflow():"""Test training with MLflow tracking enabled (see https://mlflow.org/ for details)."""# 设置 MLflow 跟踪开启SETTINGS["mlflow"] = True# 调用 YOLO 类来进行模型训练YOLO("yolov8n-cls.yaml").train(data="imagenet10", imgsz=32, epochs=3, plots=False, device="cpu")# 使用 pytest 标记,当条件不满足时跳过测试
@pytest.mark.skipif(True, reason="Test failing in scheduled CI https://github.com/ultralytics/ultralytics/pull/8868")
@pytest.mark.skipif(not check_requirements("mlflow", install=False), reason="mlflow not installed")
def test_mlflow_keep_run_active():"""Ensure MLflow run status matches MLFLOW_KEEP_RUN_ACTIVE environment variable settings."""import mlflow# 设置 MLflow 跟踪开启SETTINGS["mlflow"] = Truerun_name = "Test Run"os.environ["MLFLOW_RUN"] = run_name# 测试 MLFLOW_KEEP_RUN_ACTIVE=True 的情况os.environ["MLFLOW_KEEP_RUN_ACTIVE"] = "True"YOLO("yolov8n-cls.yaml").train(data="imagenet10", imgsz=32, epochs=1, plots=False, device="cpu")# 获取当前 MLflow 运行的状态status = mlflow.active_run().info.statusassert status == "RUNNING", "MLflow run should be active when MLFLOW_KEEP_RUN_ACTIVE=True"run_id = mlflow.active_run().info.run_id# 测试 MLFLOW_KEEP_RUN_ACTIVE=False 的情况os.environ["MLFLOW_KEEP_RUN_ACTIVE"] = "False"YOLO("yolov8n-cls.yaml").train(data="imagenet10", imgsz=32, epochs=1, plots=False, device="cpu")# 获取指定运行 ID 的 MLflow 运行状态status = mlflow.get_run(run_id=run_id).info.statusassert status == "FINISHED", "MLflow run should be ended when MLFLOW_KEEP_RUN_ACTIVE=False"# 测试 MLFLOW_KEEP_RUN_ACTIVE 未设置的情况os.environ.pop("MLFLOW_KEEP_RUN_ACTIVE", None)YOLO("yolov8n-cls.yaml").train(data="imagenet10", imgsz=32, epochs=1, plots=False, device="cpu")# 获取指定运行 ID 的 MLflow 运行状态status = mlflow.get_run(run_id=run_id).info.statusassert status == "FINISHED", "MLflow run should be ended by default when MLFLOW_KEEP_RUN_ACTIVE is not set"# 使用 pytest 标记,当条件不满足时跳过测试
@pytest.mark.skipif(not check_requirements("tritonclient", install=False), reason="tritonclient[all] not installed")
def test_triton():"""Test NVIDIA Triton Server functionalities with YOLO model.See https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver."""# 检查 tritonclient 是否安装check_requirements("tritonclient[all]")# 导入 Triton 的推理服务器客户端模块from tritonclient.http import InferenceServerClient  # noqa# Create variablesmodel_name = "yolo"  # 设置模型名称为 "yolo"triton_repo = TMP / "triton_repo"  # Triton仓库路径设为临时文件目录下的 triton_repo 文件夹triton_model = triton_repo / model_name  # Triton模型路径为 Triton仓库路径下的模型名称文件夹路径# Export model to ONNXf = YOLO(MODEL).export(format="onnx", dynamic=True)  # 将模型导出为ONNX格式文件,并保存路径到变量f# Prepare Triton repo(triton_model / "1").mkdir(parents=True, exist_ok=True)  # 在 Triton模型路径下创建版本号为1的子文件夹,若存在则忽略Path(f).rename(triton_model / "1" / "model.onnx")  # 将导出的ONNX模型文件移动到 Triton模型路径下的版本1文件夹中命名为model.onnx(triton_model / "config.pbtxt").touch()  # 在 Triton模型路径下创建一个名为config.pbtxt的空文件# Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonservertag = "nvcr.io/nvidia/tritonserver:23.09-py3"  # 定义Docker镜像标签为nvcr.io/nvidia/tritonserver:23.09-py3,大小为6.4 GB# Pull the imagesubprocess.call(f"docker pull {tag}", shell=True)  # 使用Docker命令拉取指定标签的镜像# Run the Triton server and capture the container IDcontainer_id = (subprocess.check_output(f"docker run -d --rm -v {triton_repo}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models",shell=True,).decode("utf-8").strip())  # 启动 Triton 服务器,并获取容器的ID# Wait for the Triton server to starttriton_client = InferenceServerClient(url="localhost:8000", verbose=False, ssl=False)  # 创建 Triton 客户端实例连接到本地的 Triton 服务器,端口为8000,关闭详细信息输出,不使用SSL# Wait until model is readyfor _ in range(10):  # 循环10次with contextlib.suppress(Exception):  # 忽略异常assert triton_client.is_model_ready(model_name)  # 断言检查模型是否准备就绪break  # 如果模型就绪,跳出循环time.sleep(1)  # 等待1秒钟# Check Triton inferenceYOLO(f"http://localhost:8000/{model_name}", "detect")(SOURCE)  # 使用导出的模型进行 Triton 推理,传入参数SOURCE作为输入# Kill and remove the container at the end of the testsubprocess.call(f"docker kill {container_id}", shell=True)  # 使用Docker命令终止指定ID的容器并删除
@pytest.mark.skipif(not check_requirements("pycocotools", install=False), reason="pycocotools not installed")
def test_pycocotools():"""Validate YOLO model predictions on COCO dataset using pycocotools."""from ultralytics.models.yolo.detect import DetectionValidatorfrom ultralytics.models.yolo.pose import PoseValidatorfrom ultralytics.models.yolo.segment import SegmentationValidator# Download annotations after each dataset downloads firsturl = "https://github.com/ultralytics/assets/releases/download/v8.2.0/"# 设置检测模型的参数和初始化检测器args = {"model": "yolov8n.pt", "data": "coco8.yaml", "save_json": True, "imgsz": 64}validator = DetectionValidator(args=args)# 运行检测器,执行评估validator()# 标记为COCO数据集validator.is_coco = True# 下载实例注释文件download(f"{url}instances_val2017.json", dir=DATASETS_DIR / "coco8/annotations")# 对评估的JSON文件进行评估_ = validator.eval_json(validator.stats)# 设置分割模型的参数和初始化分割器args = {"model": "yolov8n-seg.pt", "data": "coco8-seg.yaml", "save_json": True, "imgsz": 64}validator = SegmentationValidator(args=args)# 运行分割器,执行评估validator()# 标记为COCO数据集validator.is_coco = True# 下载实例注释文件download(f"{url}instances_val2017.json", dir=DATASETS_DIR / "coco8-seg/annotations")# 对评估的JSON文件进行评估_ = validator.eval_json(validator.stats)# 设置姿势估计模型的参数和初始化姿势估计器args = {"model": "yolov8n-pose.pt", "data": "coco8-pose.yaml", "save_json": True, "imgsz": 64}validator = PoseValidator(args=args)# 运行姿势估计器,执行评估validator()# 标记为COCO数据集validator.is_coco = True# 下载人体关键点注释文件download(f"{url}person_keypoints_val2017.json", dir=DATASETS_DIR / "coco8-pose/annotations")# 对评估的JSON文件进行评估_ = validator.eval_json(validator.stats)

.\yolov8\tests\test_python.py

# Ultralytics YOLO 🚀, AGPL-3.0 licenseimport contextlib  # 上下文管理工具
import urllib  # URL 处理模块
from copy import copy  # 复制对象的浅拷贝
from pathlib import Path  # 处理路径的对象import cv2  # OpenCV 库
import numpy as np  # 数组操作库
import pytest  # 测试框架
import torch  # PyTorch 深度学习库
import yaml  # YAML 格式处理库
from PIL import Image  # Python 图像库from tests import CFG, IS_TMP_WRITEABLE, MODEL, SOURCE, TMP  # 导入测试模块
from ultralytics import RTDETR, YOLO  # 导入 YOLO 和 RTDETR 模型类
from ultralytics.cfg import MODELS, TASK2DATA, TASKS  # 导入配置相关模块
from ultralytics.data.build import load_inference_source  # 导入数据构建函数
from ultralytics.utils import (  # 导入工具函数和变量ASSETS,DEFAULT_CFG,DEFAULT_CFG_PATH,LOGGER,ONLINE,ROOT,WEIGHTS_DIR,WINDOWS,checks,
)
from ultralytics.utils.downloads import download  # 导入下载函数
from ultralytics.utils.torch_utils import TORCH_1_9  # 导入 PyTorch 工具函数def test_model_forward():"""Test the forward pass of the YOLO model."""model = YOLO(CFG)  # 使用给定配置创建 YOLO 模型对象model(source=None, imgsz=32, augment=True)  # 测试不同参数的模型前向传播def test_model_methods():"""Test various methods and properties of the YOLO model to ensure correct functionality."""model = YOLO(MODEL)  # 使用给定模型路径创建 YOLO 模型对象# Model methodsmodel.info(verbose=True, detailed=True)  # 调用模型的信息打印方法,详细展示model = model.reset_weights()  # 重置模型的权重model = model.load(MODEL)  # 加载指定模型model.to("cpu")  # 将模型转移到 CPU 设备model.fuse()  # 融合模型model.clear_callback("on_train_start")  # 清除指定的回调函数model.reset_callbacks()  # 重置所有回调函数# Model properties_ = model.names  # 获取模型的类别名称_ = model.device  # 获取模型当前设备_ = model.transforms  # 获取模型的数据转换_ = model.task_map  # 获取模型的任务映射def test_model_profile():"""Test profiling of the YOLO model with `profile=True` to assess performance and resource usage."""from ultralytics.nn.tasks import DetectionModel  # 导入检测模型类model = DetectionModel()  # 创建检测模型对象im = torch.randn(1, 3, 64, 64)  # 创建输入张量_ = model.predict(im, profile=True)  # 使用性能分析模式进行模型预测@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable")
def test_predict_txt():"""Tests YOLO predictions with file, directory, and pattern sources listed in a text file."""txt_file = TMP / "sources.txt"  # 创建临时文件路径with open(txt_file, "w") as f:for x in [ASSETS / "bus.jpg", ASSETS, ASSETS / "*", ASSETS / "**/*.jpg"]:f.write(f"{x}\n")  # 将多种数据源写入文本文件_ = YOLO(MODEL)(source=txt_file, imgsz=32)  # 使用文本文件中的数据源进行 YOLO 模型预测@pytest.mark.parametrize("model_name", MODELS)
def test_predict_img(model_name):"""Test YOLO model predictions on various image input types and sources, including online images."""model = YOLO(WEIGHTS_DIR / model_name)  # 使用给定模型名称加载 YOLO 模型im = cv2.imread(str(SOURCE))  # 读取输入图像为 numpy 数组assert len(model(source=Image.open(SOURCE), save=True, verbose=True, imgsz=32)) == 1  # 使用 PIL 图像进行模型预测assert len(model(source=im, save=True, save_txt=True, imgsz=32)) == 1  # 使用 numpy 数组进行模型预测assert len(model(torch.rand((2, 3, 32, 32)), imgsz=32)) == 2  # 使用 Tensor 数据进行批处理预测assert len(model(source=[im, im], save=True, save_txt=True, imgsz=32)) == 2  # 使用多个输入进行批处理预测assert len(list(model(source=[im, im], save=True, stream=True, imgsz=32))) == 2  # 使用流式数据进行预测assert len(model(torch.zeros(320, 640, 3).numpy().astype(np.uint8), imgsz=32)) == 1  # 使用 Tensor 转换为 numpy 数组进行预测batch = [str(SOURCE),  # 将 SOURCE 转换为字符串并存储在列表中,表示文件名Path(SOURCE),  # 使用 SOURCE 创建一个 Path 对象,并存储在列表中,表示路径"https://github.com/ultralytics/assets/releases/download/v0.0.0/zidane.jpg" if ONLINE else SOURCE,  # 如果 ONLINE 变量为真,则使用 GitHub 上的 URL,否则使用 SOURCE 变量,表示统一资源标识符(URI)cv2.imread(str(SOURCE)),  # 使用 OpenCV 读取 SOURCE 变量指定的图像,并将其存储在列表中Image.open(SOURCE),  # 使用 PIL 库打开 SOURCE 变量指定的图像,并将其存储在列表中np.zeros((320, 640, 3), dtype=np.uint8),  # 创建一个 320x640 大小,数据类型为 uint8 的全零数组,并存储在列表中,表示使用 numpy 库]assert len(model(batch, imgsz=32)) == len(batch)  # 断言模型处理批量数据的输出长度与输入列表 batch 的长度相同
@pytest.mark.parametrize("model", MODELS)
def test_predict_visualize(model):"""Test model prediction methods with 'visualize=True' to generate and display prediction visualizations."""# 使用不同的模型参数化测试模型的预测方法,设置 visualize=True 以生成和显示预测的可视化结果YOLO(WEIGHTS_DIR / model)(SOURCE, imgsz=32, visualize=True)def test_predict_grey_and_4ch():"""Test YOLO prediction on SOURCE converted to greyscale and 4-channel images with various filenames."""# 测试 YOLO 模型在将 SOURCE 转换为灰度图和四通道图像,并使用不同的文件名进行测试im = Image.open(SOURCE)directory = TMP / "im4"directory.mkdir(parents=True, exist_ok=True)source_greyscale = directory / "greyscale.jpg"source_rgba = directory / "4ch.png"source_non_utf = directory / "non_UTF_测试文件_tést_image.jpg"source_spaces = directory / "image with spaces.jpg"im.convert("L").save(source_greyscale)  # 将图像转换为灰度图并保存im.convert("RGBA").save(source_rgba)  # 将图像转换为四通道 PNG 并保存im.save(source_non_utf)  # 使用包含非 UTF 字符的文件名保存图像im.save(source_spaces)  # 使用包含空格的文件名保存图像# 推断过程model = YOLO(MODEL)for f in source_rgba, source_greyscale, source_non_utf, source_spaces:for source in Image.open(f), cv2.imread(str(f)), f:# 对每个文件进行模型预测,设置 save=True 和 verbose=True,imgsz=32results = model(source, save=True, verbose=True, imgsz=32)assert len(results) == 1  # 验证是否运行了一次图像预测f.unlink()  # 清理生成的临时文件@pytest.mark.slow
@pytest.mark.skipif(not ONLINE, reason="environment is offline")
def test_youtube():"""Test YOLO model on a YouTube video stream, handling potential network-related errors."""# 在 YouTube 视频流上测试 YOLO 模型,处理可能出现的网络相关错误model = YOLO(MODEL)try:model.predict("https://youtu.be/G17sBkb38XQ", imgsz=96, save=True)# 处理因网络连接问题引起的错误,例如 'urllib.error.HTTPError: HTTP Error 429: Too Many Requests'except (urllib.error.HTTPError, ConnectionError) as e:LOGGER.warning(f"WARNING: YouTube Test Error: {e}")@pytest.mark.skipif(not ONLINE, reason="environment is offline")
@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable")
def test_track_stream():"""Tests streaming tracking on a short 10 frame video using ByteTrack tracker and different GMC methods.Note imgsz=160 required for tracking for higher confidence and better matches."""# 测试在短10帧视频上使用 ByteTrack 跟踪器和不同的全局运动补偿(GMC)方法进行实时跟踪video_url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/decelera_portrait_min.mov"model = YOLO(MODEL)model.track(video_url, imgsz=160, tracker="bytetrack.yaml")  # 使用 ByteTrack 跟踪器进行跟踪model.track(video_url, imgsz=160, tracker="botsort.yaml", save_frames=True)  # 测试帧保存功能# 测试不同的全局运动补偿(GMC)方法for gmc in "orb", "sift", "ecc":with open(ROOT / "cfg/trackers/botsort.yaml", encoding="utf-8") as f:data = yaml.safe_load(f)tracker = TMP / f"botsort-{gmc}.yaml"data["gmc_method"] = gmcwith open(tracker, "w", encoding="utf-8") as f:yaml.safe_dump(data, f)model.track(video_url, imgsz=160, tracker=tracker)def test_val():# 这是一个空测试函数,没有任何代码内容# 使用 YOLO 模型的验证模式进行测试# 实例化 YOLO 类,并调用其 val 方法,传入以下参数:#   - data="coco8.yaml": 指定配置文件为 "coco8.yaml"#   - imgsz=32: 指定图像尺寸为 32#   - save_hybrid=True: 设置保存混合结果为 TrueYOLO(MODEL).val(data="coco8.yaml", imgsz=32, save_hybrid=True)
def test_train_scratch():"""Test training the YOLO model from scratch using the provided configuration."""# 创建一个 YOLO 模型对象,使用给定的配置 CFGmodel = YOLO(CFG)# 使用指定参数训练模型:数据为 coco8.yaml,训练周期为 2,图像大小为 32 像素,缓存方式为磁盘,批量大小为 -1,关闭马赛克效果,命名为 "model"model.train(data="coco8.yaml", epochs=2, imgsz=32, cache="disk", batch=-1, close_mosaic=1, name="model")# 使用模型处理 SOURCE 数据model(SOURCE)def test_train_pretrained():"""Test training of the YOLO model starting from a pre-trained checkpoint."""# 创建一个 YOLO 模型对象,从预训练的检查点 WEIGHTS_DIR / "yolov8n-seg.pt" 开始model = YOLO(WEIGHTS_DIR / "yolov8n-seg.pt")# 使用指定参数训练模型:数据为 coco8-seg.yaml,训练周期为 1,图像大小为 32 像素,缓存方式为 RAM,复制粘贴概率为 0.5,混合比例为 0.5,命名为 0model.train(data="coco8-seg.yaml", epochs=1, imgsz=32, cache="ram", copy_paste=0.5, mixup=0.5, name=0)# 使用模型处理 SOURCE 数据model(SOURCE)def test_all_model_yamls():"""Test YOLO model creation for all available YAML configurations in the `cfg/models` directory."""# 遍历 cfg/models 目录下所有的 YAML 配置文件for m in (ROOT / "cfg" / "models").rglob("*.yaml"):# 如果文件名包含 "rtdetr"if "rtdetr" in m.name:# 如果使用的是 Torch 版本 1.9 及以上if TORCH_1_9:# 创建 RTDETR 模型对象,传入 m.name 文件名,对 SOURCE 数据进行处理,图像大小为 640_ = RTDETR(m.name)(SOURCE, imgsz=640)  # 必须为 640else:# 创建 YOLO 模型对象,传入 m.name 文件名YOLO(m.name)def test_workflow():"""Test the complete workflow including training, validation, prediction, and exporting."""# 创建一个 YOLO 模型对象,使用指定的 MODELmodel = YOLO(MODEL)# 训练模型:数据为 coco8.yaml,训练周期为 1,图像大小为 32 像素,优化器选择 SGDmodel.train(data="coco8.yaml", epochs=1, imgsz=32, optimizer="SGD")# 进行模型验证,图像大小为 32 像素model.val(imgsz=32)# 对 SOURCE 数据进行预测,图像大小为 32 像素model.predict(SOURCE, imgsz=32)# 导出模型为 TorchScript 格式model.export(format="torchscript")def test_predict_callback_and_setup():"""Test callback functionality during YOLO prediction setup and execution."""def on_predict_batch_end(predictor):"""Callback function that handles operations at the end of a prediction batch."""# 获取 predictor.batch 的路径、图像和批量大小path, im0s, _ = predictor.batch# 将 im0s 转换为列表(如果不是),以便处理多图像情况im0s = im0s if isinstance(im0s, list) else [im0s]# 创建与预测结果、图像和批量大小相关联的元组列表bs = [predictor.dataset.bs for _ in range(len(path))]predictor.results = zip(predictor.results, im0s, bs)  # results is List[batch_size]# 创建一个 YOLO 模型对象,使用指定的 MODELmodel = YOLO(MODEL)# 添加 on_predict_batch_end 回调函数到模型中model.add_callback("on_predict_batch_end", on_predict_batch_end)# 加载推理数据源,获取数据集的批量大小dataset = load_inference_source(source=SOURCE)bs = dataset.bs  # noqa access predictor properties# 对数据集进行预测,流式处理,图像大小为 160 像素results = model.predict(dataset, stream=True, imgsz=160)  # source already setup# 遍历预测结果列表for r, im0, bs in results:# 打印图像形状信息print("test_callback", im0.shape)# 打印批量大小信息print("test_callback", bs)# 获取预测结果的边界框对象boxes = r.boxes  # Boxes object for bbox outputsprint(boxes)@pytest.mark.parametrize("model", MODELS)
def test_results(model):"""Ensure YOLO model predictions can be processed and printed in various formats."""# 使用指定模型 WEIGHTS_DIR / model 创建 YOLO 模型对象,并对 SOURCE 数据进行预测,图像大小为 160 像素results = YOLO(WEIGHTS_DIR / model)([SOURCE, SOURCE], imgsz=160)# 遍历预测结果列表for r in results:# 将结果转换为 CPU 上的 numpy 数组r = r.cpu().numpy()# 打印 numpy 数组的属性信息及路径print(r, len(r), r.path)  # print numpy attributes# 将结果转换为 CPU 上的 torch.float32 类型r = r.to(device="cpu", dtype=torch.float32)# 将结果保存为文本文件,保存置信度信息r.save_txt(txt_file=TMP / "runs/tests/label.txt", save_conf=True)# 将结果中的区域裁剪保存到指定目录r.save_crop(save_dir=TMP / "runs/tests/crops/")# 将结果转换为 JSON 格式,并进行规范化处理r.tojson(normalize=True)# 绘制结果的图像,返回 PIL 图像r.plot(pil=True)# 绘制结果的置信度图及边界框信息r.plot(conf=True, boxes=True)# 再次打印结果及路径信息print(r, len(r), r.path)  # print after methodsdef test_labels_and_crops():# 这个函数是空的,未提供代码pass"""Test output from prediction args for saving YOLO detection labels and crops; ensures accurate saving."""# 定义图片列表,包括源路径和指定的图像文件路径imgs = [SOURCE, ASSETS / "zidane.jpg"]# 使用预训练的 YOLO 模型处理图像列表,设置图像大小为160,保存检测结果的文本和裁剪图像results = YOLO(WEIGHTS_DIR / "yolov8n.pt")(imgs, imgsz=160, save_txt=True, save_crop=True)# 保存路径为结果中第一个元素的保存目录save_path = Path(results[0].save_dir)# 遍历每个结果for r in results:# 提取图像文件名作为标签文件名的基础im_name = Path(r.path).stem# 提取每个检测框的类别索引,转换为整数列表cls_idxs = r.boxes.cls.int().tolist()# 检查标签文件路径是否存在labels = save_path / f"labels/{im_name}.txt"assert labels.exists()  # 断言标签文件存在# 检查检测结果的数量是否与标签文件中的行数匹配assert len(r.boxes.data) == len([line for line in labels.read_text().splitlines() if line])# 获取所有裁剪图像的路径crop_dirs = list((save_path / "crops").iterdir())crop_files = [f for p in crop_dirs for f in p.glob("*")]# 断言每个类别索引对应的裁剪目录在裁剪目录中存在assert all(r.names.get(c) in {d.name for d in crop_dirs} for c in cls_idxs)# 断言裁剪文件数量与检测框数量相匹配assert len([f for f in crop_files if im_name in f.name]) == len(r.boxes.data)
@pytest.mark.skipif(not ONLINE, reason="environment is offline")
# 标记为跳过测试,如果环境处于离线状态
def test_data_utils():"""Test utility functions in ultralytics/data/utils.py, including dataset stats and auto-splitting."""# 导入需要测试的函数和模块from ultralytics.data.utils import HUBDatasetStats, autosplitfrom ultralytics.utils.downloads import zip_directory# from ultralytics.utils.files import WorkingDirectory# with WorkingDirectory(ROOT.parent / 'tests'):# 遍历任务列表,进行测试for task in TASKS:# 构建数据文件的路径,例如 coco8.zipfile = Path(TASK2DATA[task]).with_suffix(".zip")  # i.e. coco8.zip# 下载数据文件download(f"https://github.com/ultralytics/hub/raw/main/example_datasets/{file}", unzip=False, dir=TMP)# 创建数据集统计对象stats = HUBDatasetStats(TMP / file, task=task)# 生成数据集统计信息的 JSON 文件stats.get_json(save=True)# 处理图像数据stats.process_images()# 自动划分数据集autosplit(TMP / "coco8")# 压缩指定路径下的文件夹zip_directory(TMP / "coco8/images/val")  # zip@pytest.mark.skipif(not ONLINE, reason="environment is offline")
# 标记为跳过测试,如果环境处于离线状态
def test_data_converter():"""Test dataset conversion functions from COCO to YOLO format and class mappings."""# 导入需要测试的函数from ultralytics.data.converter import coco80_to_coco91_class, convert_coco# 下载 COCO 数据集的实例文件file = "instances_val2017.json"download(f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{file}", dir=TMP)# 将 COCO 数据集转换为 YOLO 格式convert_coco(labels_dir=TMP, save_dir=TMP / "yolo_labels", use_segments=True, use_keypoints=False, cls91to80=True)# 将 COCO80 类别映射为 COCO91 类别coco80_to_coco91_class()def test_data_annotator():"""Automatically annotate data using specified detection and segmentation models."""# 导入自动标注数据的函数from ultralytics.data.annotator import auto_annotate# 使用指定的检测和分割模型自动标注数据auto_annotate(ASSETS,det_model=WEIGHTS_DIR / "yolov8n.pt",sam_model=WEIGHTS_DIR / "mobile_sam.pt",output_dir=TMP / "auto_annotate_labels",)def test_events():"""Test event sending functionality."""# 导入事件发送功能模块from ultralytics.hub.utils import Events# 创建事件对象events = Events()events.enabled = Truecfg = copy(DEFAULT_CFG)  # does not require deepcopycfg.mode = "test"# 发送事件events(cfg)def test_cfg_init():"""Test configuration initialization utilities from the 'ultralytics.cfg' module."""# 导入配置初始化相关的函数from ultralytics.cfg import check_dict_alignment, copy_default_cfg, smart_value# 检查字典对齐性with contextlib.suppress(SyntaxError):check_dict_alignment({"a": 1}, {"b": 2})# 复制默认配置copy_default_cfg()# 删除复制的配置文件(Path.cwd() / DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml")).unlink(missing_ok=False)# 对多个值应用智能化处理[smart_value(x) for x in ["none", "true", "false"]]def test_utils_init():"""Test initialization utilities in the Ultralytics library."""# 导入初始化工具函数from ultralytics.utils import get_git_branch, get_git_origin_url, get_ubuntu_version, is_github_action_running# 获取 Ubuntu 版本信息get_ubuntu_version()# 检查是否在 GitHub Action 环境下运行is_github_action_running()# 获取 Git 仓库的远程 URLget_git_origin_url()# 获取 Git 分支信息get_git_branch()def test_utils_checks():"""Test various utility checks for filenames, git status, requirements, image sizes, and versions."""# 导入各种检查函数from ultralytics.utils import checks# 检查 YOLOv5u 文件名格式checks.check_yolov5u_filename("yolov5n.pt")# 检查 Git 仓库状态checks.git_describe(ROOT)# 检查项目的要求是否符合 requirements.txt 中指定的依赖checks.check_requirements()  # check requirements.txt# 检查图像大小是否在指定范围内,确保宽度和高度均不超过 600 像素checks.check_imgsz([600, 600], max_dim=1)# 检查是否可以显示图像,若不能显示则发出警告checks.check_imshow(warn=True)# 检查指定模块的版本是否符合要求,这里检查 ultralytics 模块是否至少是 8.0.0 版本checks.check_version("ultralytics", "8.0.0")# 打印当前设置和参数,用于调试和确认运行时的配置checks.print_args()
@pytest.mark.skipif(WINDOWS, reason="Windows profiling is extremely slow (cause unknown)")
# 如果在 Windows 下运行,跳过此测试,原因是 Windows 上的性能分析非常缓慢(原因不明)
def test_utils_benchmarks():"""Benchmark model performance using 'ProfileModels' from 'ultralytics.utils.benchmarks'."""# 导入性能分析工具 'ProfileModels' 来评估模型性能from ultralytics.utils.benchmarks import ProfileModels# 使用 ProfileModels 类来对 'yolov8n.yaml' 模型进行性能分析,设置图像大小为 32,最小运行时间为 1 秒,运行 3 次,预热 1 次ProfileModels(["yolov8n.yaml"], imgsz=32, min_time=1, num_timed_runs=3, num_warmup_runs=1).profile()def test_utils_torchutils():"""Test Torch utility functions including profiling and FLOP calculations."""# 导入相关模块和函数进行测试,包括性能分析和 FLOP 计算from ultralytics.nn.modules.conv import Convfrom ultralytics.utils.torch_utils import get_flops_with_torch_profiler, profile, time_sync# 创建一个随机张量作为输入x = torch.randn(1, 64, 20, 20)# 创建一个 Conv 模型实例m = Conv(64, 64, k=1, s=2)# 使用 profile 函数对模型 m 进行性能分析,运行 3 次profile(x, [m], n=3)# 使用 get_flops_with_torch_profiler 函数获取模型 m 的 FLOPget_flops_with_torch_profiler(m)# 执行时间同步操作time_sync()@pytest.mark.slow
@pytest.mark.skipif(not ONLINE, reason="environment is offline")
# 如果处于离线环境,跳过此测试
def test_utils_downloads():"""Test file download utilities from ultralytics.utils.downloads."""# 导入文件下载工具函数 get_google_drive_file_infofrom ultralytics.utils.downloads import get_google_drive_file_info# 调用 get_google_drive_file_info 函数下载特定 Google Drive 文件的信息get_google_drive_file_info("https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link")def test_utils_ops():"""Test utility operations functions for coordinate transformation and normalization."""# 导入坐标转换和归一化等操作函数from ultralytics.utils.ops import (ltwh2xywh,ltwh2xyxy,make_divisible,xywh2ltwh,xywh2xyxy,xywhn2xyxy,xywhr2xyxyxyxy,xyxy2ltwh,xyxy2xywh,xyxy2xywhn,xyxyxyxy2xywhr,)# 使用 make_divisible 函数,确保 17 能够被 8 整除make_divisible(17, torch.tensor([8]))# 创建随机框坐标张量boxes = torch.rand(10, 4)  # xywh# 检查通过 xywh2xyxy 和 xyxy2xywh 函数的转换后的张量是否相等torch.allclose(boxes, xyxy2xywh(xywh2xyxy(boxes)))# 检查通过 xywhn2xyxy 和 xyxy2xywhn 函数的转换后的张量是否相等torch.allclose(boxes, xyxy2xywhn(xywhn2xyxy(boxes)))# 检查通过 ltwh2xywh 和 xywh2ltwh 函数的转换后的张量是否相等torch.allclose(boxes, ltwh2xywh(xywh2ltwh(boxes)))# 检查通过 xyxy2ltwh 和 ltwh2xyxy 函数的转换后的张量是否相等torch.allclose(boxes, xyxy2ltwh(ltwh2xyxy(boxes)))# 创建带有方向信息的随机框坐标张量boxes = torch.rand(10, 5)  # xywhr for OBB# 随机生成方向信息boxes[:, 4] = torch.randn(10) * 30# 检查通过 xywhr2xyxyxyxy 和 xyxyxyxy2xywhr 函数的转换后的张量是否相等,相对误差容忍度为 1e-3torch.allclose(boxes, xyxyxyxy2xywhr(xywhr2xyxyxyxy(boxes)), rtol=1e-3)def test_utils_files():"""Test file handling utilities including file age, date, and paths with spaces."""# 导入文件处理工具函数,包括文件年龄、日期和带空格路径的处理from ultralytics.utils.files import file_age, file_date, get_latest_run, spaces_in_path# 获取指定文件的年龄file_age(SOURCE)# 获取指定文件的日期file_date(SOURCE)# 获取根目录下运行记录的最新一次运行get_latest_run(ROOT / "runs")# 创建一个带有空格路径的临时目录path = TMP / "path/with spaces"path.mkdir(parents=True, exist_ok=True)# 在带有空格路径的临时目录中执行 spaces_in_path 函数,返回处理后的新路径并打印with spaces_in_path(path) as new_path:print(new_path)@pytest.mark.slow
def test_utils_patches_torch_save():"""Test torch_save backoff when _torch_save raises RuntimeError to ensure robustness."""# 导入测试函数和 mockfrom unittest.mock import MagicMock, patch# 导入要测试的函数 torch_savefrom ultralytics.utils.patches import torch_save# 创建一个 mock 对象,模拟 RuntimeError 异常mock = MagicMock(side_effect=RuntimeError)# 使用 patch 替换 _torch_save 函数,使其在调用时抛出 RuntimeError 异常with patch("ultralytics.utils.patches._torch_save", new=mock):# 断言调用 torch_save 函数时会抛出 RuntimeError 异常with pytest.raises(RuntimeError):torch_save(torch.zeros(1), TMP / "test.pt")# 断言,验证 mock 对象的方法被调用的次数是否等于 4assert mock.call_count == 4, "torch_save was not attempted the expected number of times"
def test_nn_modules_conv():"""Test Convolutional Neural Network modules including CBAM, Conv2, and ConvTranspose."""from ultralytics.nn.modules.conv import CBAM, Conv2, ConvTranspose, DWConvTranspose2d, Focusc1, c2 = 8, 16  # 输入通道数和输出通道数x = torch.zeros(4, c1, 10, 10)  # BCHW,创建一个大小为4x8x10x10的张量(批量大小x通道数x高度x宽度)# 运行所有未在测试中涵盖的模块DWConvTranspose2d(c1, c2)(x)  # 使用DWConvTranspose2d进行转置卷积操作ConvTranspose(c1, c2)(x)  # 使用ConvTranspose进行转置卷积操作Focus(c1, c2)(x)  # 使用Focus模块处理输入CBAM(c1)(x)  # 使用CBAM模块处理输入# 合并操作m = Conv2(c1, c2)  # 创建Conv2对象m.fuse_convs()  # 融合卷积操作m(x)  # 对输入x进行Conv2操作def test_nn_modules_block():"""Test various blocks in neural network modules including C1, C3TR, BottleneckCSP, C3Ghost, and C3x."""from ultralytics.nn.modules.block import C1, C3TR, BottleneckCSP, C3Ghost, C3xc1, c2 = 8, 16  # 输入通道数和输出通道数x = torch.zeros(4, c1, 10, 10)  # BCHW,创建一个大小为4x8x10x10的张量(批量大小x通道数x高度x宽度)# 运行所有未在测试中涵盖的模块C1(c1, c2)(x)  # 使用C1模块处理输入C3x(c1, c2)(x)  # 使用C3x模块处理输入C3TR(c1, c2)(x)  # 使用C3TR模块处理输入C3Ghost(c1, c2)(x)  # 使用C3Ghost模块处理输入BottleneckCSP(c1, c2)(x)  # 使用BottleneckCSP模块处理输入@pytest.mark.skipif(not ONLINE, reason="environment is offline")
def test_hub():"""Test Ultralytics HUB functionalities (e.g. export formats, logout)."""from ultralytics.hub import export_fmts_hub, logoutfrom ultralytics.hub.utils import smart_requestexport_fmts_hub()  # 调用导出格式函数logout()  # 执行注销操作smart_request("GET", "https://github.com", progress=True)  # 发起一个GET请求至GitHub@pytest.fixture
def image():"""Load and return an image from a predefined source using OpenCV."""return cv2.imread(str(SOURCE))  # 使用OpenCV从预定义源加载并返回一张图像@pytest.mark.parametrize("auto_augment, erasing, force_color_jitter",[(None, 0.0, False),("randaugment", 0.5, True),("augmix", 0.2, False),("autoaugment", 0.0, True),],
)
def test_classify_transforms_train(image, auto_augment, erasing, force_color_jitter):"""Tests classification transforms during training with various augmentations to ensure proper functionality."""from ultralytics.data.augment import classify_augmentationstransform = classify_augmentations(size=224,mean=(0.5, 0.5, 0.5),std=(0.5, 0.5, 0.5),scale=(0.08, 1.0),ratio=(3.0 / 4.0, 4.0 / 3.0),hflip=0.5,vflip=0.5,auto_augment=auto_augment,hsv_h=0.015,hsv_s=0.4,hsv_v=0.4,force_color_jitter=force_color_jitter,erasing=erasing,)transformed_image = transform(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))assert transformed_image.shape == (3, 224, 224)  # 断言转换后图像的形状为(3, 224, 224)assert torch.is_tensor(transformed_image)  # 断言转换后图像是一个PyTorch张量assert transformed_image.dtype == torch.float32  # 断言转换后图像的数据类型为torch.float32@pytest.mark.slow
@pytest.mark.skipif(not ONLINE, reason="environment is offline")
def test_model_tune():"""Tune YOLO model for performance improvement."""YOLO("yolov8n-pose.pt").tune(data="coco8-pose.yaml", plots=False, imgsz=32, epochs=1, iterations=2, device="cpu")# 使用 YOLO 模型加载 "yolov8n-cls.pt" 权重文件,并进行调参和微调YOLO("yolov8n-cls.pt").tune(data="imagenet10", plots=False, imgsz=32, epochs=1, iterations=2, device="cpu")
# 定义测试函数,用于测试模型嵌入(embeddings)
def test_model_embeddings():"""Test YOLO model embeddings."""# 创建 YOLO 检测模型对象,使用指定模型model_detect = YOLO(MODEL)# 创建 YOLO 分割模型对象,使用指定权重文件model_segment = YOLO(WEIGHTS_DIR / "yolov8n-seg.pt")# 分别测试批次大小为1和2的情况for batch in [SOURCE], [SOURCE, SOURCE]:  # test batch size 1 and 2# 断言检测模型返回的嵌入特征长度与批次大小相同assert len(model_detect.embed(source=batch, imgsz=32)) == len(batch)# 断言分割模型返回的嵌入特征长度与批次大小相同assert len(model_segment.embed(source=batch, imgsz=32)) == len(batch)# 使用 pytest.mark.skipif 标记,如果条件满足,则跳过该测试
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="YOLOWorld with CLIP is not supported in Python 3.12")
# 定义测试函数,测试支持 CLIP 的 YOLO 模型
def test_yolo_world():"""Tests YOLO world models with CLIP support, including detection and training scenarios."""# 创建 YOLO World 模型对象,加载指定模型model = YOLO("yolov8s-world.pt")  # no YOLOv8n-world model yet# 设置模型的分类类别为 ["tree", "window"]model.set_classes(["tree", "window"])# 运行模型进行目标检测,设定置信度阈值为 0.01model(SOURCE, conf=0.01)# 创建 YOLO Worldv2 模型对象,加载指定模型model = YOLO("yolov8s-worldv2.pt")  # no YOLOv8n-world model yet# 从预训练模型开始训练,最后阶段包括评估# 使用 dota8.yaml,该文件少量类别以减少 CLIP 模型推理时间model.train(data="dota8.yaml",epochs=1,imgsz=32,cache="disk",close_mosaic=1,)# 测试 WorWorldTrainerFromScratchfrom ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch# 创建 YOLO Worldv2 模型对象,加载指定模型model = YOLO("yolov8s-worldv2.yaml")  # no YOLOv8n-world model yet# 从头开始训练模型model.train(data={"train": {"yolo_data": ["dota8.yaml"]}, "val": {"yolo_data": ["dota8.yaml"]}},epochs=1,imgsz=32,cache="disk",close_mosaic=1,trainer=WorldTrainerFromScratch,)# 定义测试函数,测试 YOLOv10 模型的训练、验证和预测步骤,使用最小配置
def test_yolov10():"""Test YOLOv10 model training, validation, and prediction steps with minimal configurations."""# 创建 YOLOv10n 模型对象,加载指定模型配置文件model = YOLO("yolov10n.yaml")# 训练模型,使用 coco8.yaml 数据集,训练1轮,图像尺寸为32,使用磁盘缓存,关闭马赛克model.train(data="coco8.yaml", epochs=1, imgsz=32, close_mosaic=1, cache="disk")# 验证模型,使用 coco8.yaml 数据集,图像尺寸为32model.val(data="coco8.yaml", imgsz=32)# 进行预测,图像尺寸为32,保存文本输出和裁剪后的图像,进行数据增强model.predict(imgsz=32, save_txt=True, save_crop=True, augment=True)# 对给定的 SOURCE 数据进行预测model(SOURCE)

.\yolov8\tests\test_solutions.py

# 导入需要的库和模块
import cv2  # OpenCV库,用于图像和视频处理
import pytest  # 测试框架pytest# 从ultralytics包中导入YOLO对象检测模型和解决方案
from ultralytics import YOLO, solutions
# 从ultralytics.utils.downloads模块中导入安全下载函数
from ultralytics.utils.downloads import safe_download# 主要解决方案演示视频的下载链接
MAJOR_SOLUTIONS_DEMO = "https://github.com/ultralytics/assets/releases/download/v0.0.0/solutions_ci_demo.mp4"
# 运动监控解决方案演示视频的下载链接
WORKOUTS_SOLUTION_DEMO = "https://github.com/ultralytics/assets/releases/download/v0.0.0/solution_ci_pose_demo.mp4"# 使用pytest.mark.slow标记的测试函数,测试主要解决方案
@pytest.mark.slow
def test_major_solutions():"""Test the object counting, heatmap, speed estimation and queue management solution."""# 下载主要解决方案演示视频safe_download(url=MAJOR_SOLUTIONS_DEMO)# 加载YOLO模型,用于目标检测model = YOLO("yolov8n.pt")# 获取YOLO模型的类别名称names = model.names# 打开主要解决方案演示视频cap = cv2.VideoCapture("solutions_ci_demo.mp4")assert cap.isOpened(), "Error reading video file"# 设置感兴趣区域的四个顶点坐标region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]# 初始化解决方案对象:目标计数器、热度图、速度估计器和队列管理器counter = solutions.ObjectCounter(reg_pts=region_points, names=names, view_img=False)heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, names=names, view_img=False)speed = solutions.SpeedEstimator(reg_pts=region_points, names=names, view_img=False)queue = solutions.QueueManager(names=names, reg_pts=region_points, view_img=False)# 循环处理视频中的每一帧while cap.isOpened():success, im0 = cap.read()if not success:break# 备份原始图像original_im0 = im0.copy()# 使用YOLO模型进行目标跟踪tracks = model.track(im0, persist=True, show=False)# 调用解决方案对象的方法处理每一帧图像并获取结果_ = counter.start_counting(original_im0.copy(), tracks)_ = heatmap.generate_heatmap(original_im0.copy(), tracks)_ = speed.estimate_speed(original_im0.copy(), tracks)_ = queue.process_queue(original_im0.copy(), tracks)# 释放视频流cap.release()# 关闭所有窗口cv2.destroyAllWindows()# 使用pytest.mark.slow标记的测试函数,测试AI健身监控解决方案
@pytest.mark.slow
def test_aigym():"""Test the workouts monitoring solution."""# 下载运动监控解决方案演示视频safe_download(url=WORKOUTS_SOLUTION_DEMO)# 加载YOLO模型,用于姿态检测model = YOLO("yolov8n-pose.pt")# 打开运动监控解决方案演示视频cap = cv2.VideoCapture("solution_ci_pose_demo.mp4")assert cap.isOpened(), "Error reading video file"# 初始化AI健身监控对象gym_object = solutions.AIGym(line_thickness=2, pose_type="squat", kpts_to_check=[5, 11, 13])# 循环处理视频中的每一帧while cap.isOpened():success, im0 = cap.read()if not success:break# 使用YOLO模型进行姿态检测results = model.track(im0, verbose=False)# 调用AI健身监控对象的方法处理每一帧图像并获取结果_ = gym_object.start_counting(im0, results)# 释放视频流cap.release()# 关闭所有窗口cv2.destroyAllWindows()# 使用pytest.mark.slow标记的测试函数,测试实例分割解决方案
@pytest.mark.slow
def test_instance_segmentation():"""Test the instance segmentation solution."""# 从ultralytics.utils.plotting模块中导入Annotator和colorsfrom ultralytics.utils.plotting import Annotator, colors# 加载YOLO模型,用于实例分割model = YOLO("yolov8n-seg.pt")# 获取YOLO模型的类别名称names = model.names# 打开主要解决方案演示视频(假设这里的视频与前面的测试相同)cap = cv2.VideoCapture("solutions_ci_demo.mp4")assert cap.isOpened(), "Error reading video file"# 循环检查视频流是否打开,如果打开则继续执行while cap.isOpened():# 从视频流中读取一帧图像,同时返回读取状态和图像数据success, im0 = cap.read()# 如果读取不成功(可能是视频流已经结束),则退出循环if not success:break# 使用模型对当前帧图像进行预测,返回预测结果results = model.predict(im0)# 创建一个注解器对象,用于在图像上绘制标注annotator = Annotator(im0, line_width=2)# 如果预测结果中包含实例的掩码信息if results[0].masks is not None:# 获取预测结果中每个实例的类别和掩码信息clss = results[0].boxes.cls.cpu().tolist()masks = results[0].masks.xy# 遍历每个实例的掩码和类别,为其添加边界框和标签for mask, cls in zip(masks, clss):# 根据类别获取对应的颜色,并设置是否使用模糊效果color = colors(int(cls), True)# 在图像上绘制带有边界框的实例掩码,并添加类别标签annotator.seg_bbox(mask=mask, mask_color=color, label=names[int(cls)])# 释放视频流资源cap.release()# 关闭所有 OpenCV 窗口,释放图形界面资源cv2.destroyAllWindows()
# 使用 pytest 的标记 @pytest.mark.slow 来标记这个测试函数为慢速测试
@pytest.mark.slow
# 定义一个测试函数,用于测试 Streamlit 预测的实时推理解决方案
def test_streamlit_predict():"""Test streamlit predict live inference solution."""# 调用 solutions 模块中的 inference 函数进行测试solutions.inference()

.\yolov8\tests\__init__.py

# 导入Ultralytics YOLO的相关模块和函数,该项目使用AGPL-3.0许可证# 从ultralytics.utils模块中导入常量和函数
from ultralytics.utils import ASSETS, ROOT, WEIGHTS_DIR, checks, is_dir_writeable# 设置用于测试的常量
# MODEL代表YOLO模型的权重文件路径,包含空格
MODEL = WEIGHTS_DIR / "path with spaces" / "yolov8n.pt"  # test spaces in path# CFG是YOLO配置文件的文件名
CFG = "yolov8n.yaml"# SOURCE是用于测试的示例图片文件路径
SOURCE = ASSETS / "bus.jpg"# TMP是用于存储测试文件的临时目录路径
TMP = (ROOT / "../tests/tmp").resolve()  # temp directory for test files# 检查临时目录TMP是否可写
IS_TMP_WRITEABLE = is_dir_writeable(TMP)# 检查CUDA是否可用
CUDA_IS_AVAILABLE = checks.cuda_is_available()# 获取CUDA设备的数量
CUDA_DEVICE_COUNT = checks.cuda_device_count()# 导出所有的常量和变量,以便在模块外部使用
__all__ = ("MODEL","CFG","SOURCE","TMP","IS_TMP_WRITEABLE","CUDA_IS_AVAILABLE","CUDA_DEVICE_COUNT",
)

Models

Welcome to the Ultralytics Models directory! Here you will find a wide variety of pre-configured model configuration files (*.yamls) that can be used to create custom YOLO models. The models in this directory have been expertly crafted and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image segmentation tasks.

These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms, from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this directory provides a great starting point for your custom model development needs.

To get started, simply browse through the models in this directory and find one that best suits your needs. Once you've selected a model, you can use the provided *.yaml file to train and deploy your custom YOLO model with ease. See full details at the Ultralytics Docs, and if you need help or have any questions, feel free to reach out to the Ultralytics team for support. So, don't wait, start creating your custom YOLO model now!

Usage

Model *.yaml files may be used directly in the Command Line Interface (CLI) with a yolo command:

# Train a YOLOv8n model using the coco8 dataset for 100 epochs
yolo task=detect mode=train model=yolov8n.yaml data=coco8.yaml epochs=100

They may also be used directly in a Python environment, and accept the same arguments as in the CLI example above:

from ultralytics import YOLO# Initialize a YOLOv8n model from a YAML configuration file
model = YOLO("model.yaml")# If a pre-trained model is available, use it instead
# model = YOLO("model.pt")# Display model information
model.info()# Train the model using the COCO8 dataset for 100 epochs
model.train(data="coco8.yaml", epochs=100)

Pre-trained Model Architectures

Ultralytics supports many model architectures. Visit Ultralytics Models to view detailed information and usage. Any of these models can be used by loading their configurations or pretrained checkpoints if available.

Contribute New Models

Have you trained a new YOLO variant or achieved state-of-the-art performance with specific tuning? We'd love to showcase your work in our Models section! Contributions from the community in the form of new models, architectures, or optimizations are highly valued and can significantly enrich our repository.

By contributing to this section, you're helping us offer a wider array of model choices and configurations to the community. It's a fantastic way to share your knowledge and expertise while making the Ultralytics YOLO ecosystem even more versatile.

To get started, please consult our Contributing Guide for step-by-step instructions on how to submit a Pull Request (PR) 🛠️. Your contributions are eagerly awaited!

Let's join hands to extend the range and capabilities of the Ultralytics YOLO models 🙏!

.\yolov8\ultralytics\cfg\__init__.py

# 导入必要的库和模块
import contextlib  # 提供上下文管理工具的模块
import shutil  # 提供高级文件操作功能的模块
import subprocess  # 用于执行外部命令的模块
import sys  # 提供与 Python 解释器及其环境相关的功能
from pathlib import Path  # 提供处理路径的类和函数
from types import SimpleNamespace  # 提供创建简单命名空间的类
from typing import Dict, List, Union  # 提供类型提示支持# 从Ultralytics的utils模块中导入多个工具和变量
from ultralytics.utils import (ASSETS,  # 资源目录的路径DEFAULT_CFG,  # 默认配置文件名DEFAULT_CFG_DICT,  # 默认配置字典DEFAULT_CFG_PATH,  # 默认配置文件的路径LOGGER,  # 日志记录器RANK,  # 运行的排名ROOT,  # 根目录路径RUNS_DIR,  # 运行结果保存的目录路径SETTINGS,  # 设置信息SETTINGS_YAML,  # 设置信息的YAML文件路径TESTS_RUNNING,  # 是否正在运行测试的标志IterableSimpleNamespace,  # 可迭代的简单命名空间__version__,  # Ultralytics工具包的版本信息checks,  # 检查函数colorstr,  # 带有颜色的字符串处理函数deprecation_warn,  # 弃用警告函数yaml_load,  # 加载YAML文件的函数yaml_print,  # 打印YAML内容的函数
)# 定义有效的任务和模式集合
MODES = {"train", "val", "predict", "export", "track", "benchmark"}  # 可执行的模式集合
TASKS = {"detect", "segment", "classify", "pose", "obb"}  # 可执行的任务集合# 将任务映射到其对应的数据文件
TASK2DATA = {"detect": "coco8.yaml","segment": "coco8-seg.yaml","classify": "imagenet10","pose": "coco8-pose.yaml","obb": "dota8.yaml",
}# 将任务映射到其对应的模型文件
TASK2MODEL = {"detect": "yolov8n.pt","segment": "yolov8n-seg.pt","classify": "yolov8n-cls.pt","pose": "yolov8n-pose.pt","obb": "yolov8n-obb.pt",
}# 将任务映射到其对应的指标文件
TASK2METRIC = {"detect": "metrics/mAP50-95(B)","segment": "metrics/mAP50-95(M)","classify": "metrics/accuracy_top1","pose": "metrics/mAP50-95(P)","obb": "metrics/mAP50-95(B)",
}# 从TASKS集合中提取模型文件集合
MODELS = {TASK2MODEL[task] for task in TASKS}# 获取命令行参数,如果不存在则设置为空列表
ARGV = sys.argv or ["", ""]# 定义CLI帮助信息,说明如何使用Ultralytics 'yolo'命令
CLI_HELP_MSG = f"""Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo' commands use the following syntax:yolo TASK MODE ARGSWhere   TASK (optional) is one of {TASKS}MODE (required) is one of {MODES}ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg'1. Train a detection model for 10 epochs with an initial learning_rate of 0.01yolo train data=coco8.yaml model=yolov8n.pt epochs=10 lr0=0.012. Predict a YouTube video using a pretrained segmentation model at image size 320:yolo predict model=yolov8n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=3203. Val a pretrained detection model at batch-size 1 and image size 640:yolo val model=yolov8n.pt data=coco8.yaml batch=1 imgsz=6404. Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)yolo export model=yolov8n-cls.pt format=onnx imgsz=224,1285. Explore your datasets using semantic search and SQL with a simple GUI powered by Ultralytics Explorer APIyolo explorer data=data.yaml model=yolov8n.pt6. Streamlit real-time object detection on your webcam with Ultralytics YOLOv8yolo streamlit-predict7. Run special commands:yolo helpyolo checksyolo versionyolo settingsyolo copy-cfgyolo cfgDocs: https://docs.ultralytics.comCommunity: https://community.ultralytics.com
"""GitHub: https://github.com/ultralytics/ultralytics"""GitHub: https://github.com/ultralytics/ultralytics# 在代码中添加一个字符串文档注释,指向项目的GitHub页面"""
# Define keys for arg type checks
CFG_FLOAT_KEYS = {  # integer or float arguments, i.e. x=2 and x=2.0"warmup_epochs","box","cls","dfl","degrees","shear","time","workspace","batch",
}
CFG_FRACTION_KEYS = {  # fractional float arguments with 0.0<=values<=1.0"dropout","lr0","lrf","momentum","weight_decay","warmup_momentum","warmup_bias_lr","label_smoothing","hsv_h","hsv_s","hsv_v","translate","scale","perspective","flipud","fliplr","bgr","mosaic","mixup","copy_paste","conf","iou","fraction",
}
CFG_INT_KEYS = {  # integer-only arguments"epochs","patience","workers","seed","close_mosaic","mask_ratio","max_det","vid_stride","line_width","nbs","save_period",
}
CFG_BOOL_KEYS = {  # boolean-only arguments"save","exist_ok","verbose","deterministic","single_cls","rect","cos_lr","overlap_mask","val","save_json","save_hybrid","half","dnn","plots","show","save_txt","save_conf","save_crop","save_frames","show_labels","show_conf","visualize","augment","agnostic_nms","retina_masks","show_boxes","keras","optimize","int8","dynamic","simplify","nms","profile","multi_scale",
}def cfg2dict(cfg):"""Converts a configuration object to a dictionary.Args:cfg (str | Path | Dict | SimpleNamespace): Configuration object to be converted. Can be a file path,a string, a dictionary, or a SimpleNamespace object.Returns:(Dict): Configuration object in dictionary format.Examples:Convert a YAML file path to a dictionary:>>> config_dict = cfg2dict('config.yaml')Convert a SimpleNamespace to a dictionary:>>> from types import SimpleNamespace>>> config_sn = SimpleNamespace(param1='value1', param2='value2')>>> config_dict = cfg2dict(config_sn)Pass through an already existing dictionary:>>> config_dict = cfg2dict({'param1': 'value1', 'param2': 'value2'})Notes:- If cfg is a path or string, it's loaded as YAML and converted to a dictionary.- If cfg is a SimpleNamespace object, it's converted to a dictionary using vars().- If cfg is already a dictionary, it's returned unchanged."""if isinstance(cfg, (str, Path)):cfg = yaml_load(cfg)  # load dict from YAML file or stringelif isinstance(cfg, SimpleNamespace):cfg = vars(cfg)  # convert SimpleNamespace to dictionaryreturn cfgdef get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None):"""Load and merge configuration data from a file or dictionary, with optional overrides.Args:cfg (str | Path | Dict | SimpleNamespace): Configuration source to load from.Defaults to DEFAULT_CFG_DICT if not provided.overrides (Dict): Optional dictionary containing configuration overrides.Returns:(Dict): Merged configuration dictionary.Notes:- cfg can be a YAML file path, string, dictionary, or SimpleNamespace object.- If overrides are provided, they overwrite values from cfg."""# 将 cfg 转换为字典形式,统一处理配置数据来源为不同类型的情况(文件路径、字典、SimpleNamespace 对象)cfg = cfg2dict(cfg)# 合并 overridesif overrides:# 将 overrides 转换为字典形式overrides = cfg2dict(overrides)# 如果 cfg 中没有 "save_dir" 键,则在合并过程中忽略 "save_dir" 键if "save_dir" not in cfg:overrides.pop("save_dir", None)  # 特殊的覆盖键,忽略处理# 检查 cfg 和 overrides 字典的对齐性,确保正确性check_dict_alignment(cfg, overrides)# 合并 cfg 和 overrides 字典,以 overrides 为优先cfg = {**cfg, **overrides}  # 合并 cfg 和 overrides 字典(优先使用 overrides)# 对于数字类型的 "project" 和 "name" 进行特殊处理,转换为字符串for k in "project", "name":if k in cfg and isinstance(cfg[k], (int, float)):cfg[k] = str(cfg[k])# 如果配置中 "name" 等于 "model",则将其更新为 "model" 键对应值的第一个点之前的部分if cfg.get("name") == "model":cfg["name"] = cfg.get("model", "").split(".")[0]# 发出警告信息,提示自动更新 "name" 为新值LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")# 对配置数据进行类型和值的检查check_cfg(cfg)# 返回包含合并配置的 IterableSimpleNamespace 实例return IterableSimpleNamespace(**cfg)
# 验证和修正 Ultralytics 库的配置参数类型和值def check_cfg(cfg, hard=True):"""Checks configuration argument types and values for the Ultralytics library.This function validates the types and values of configuration arguments, ensuring correctness and convertingthem if necessary. It checks for specific key types defined in global variables such as CFG_FLOAT_KEYS,CFG_FRACTION_KEYS, CFG_INT_KEYS, and CFG_BOOL_KEYS.Args:cfg (Dict): Configuration dictionary to validate.hard (bool): If True, raises exceptions for invalid types and values; if False, attempts to convert them.Examples:>>> config = {...     'epochs': 50,     # valid integer...     'lr0': 0.01,      # valid float...     'momentum': 1.2,  # invalid float (out of 0.0-1.0 range)...     'save': 'true',   # invalid bool... }>>> check_cfg(config, hard=False)>>> print(config){'epochs': 50, 'lr0': 0.01, 'momentum': 1.2, 'save': False}  # corrected 'save' keyNotes:- The function modifies the input dictionary in-place.- None values are ignored as they may be from optional arguments.- Fraction keys are checked to be within the range [0.0, 1.0]."""# 遍历配置字典中的每个键值对for k, v in cfg.items():# 忽略值为 None 的情况,因为它们可能是可选参数的结果if v is not None:# 如果键在浮点数键集合中,但值不是 int 或 float 类型if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):# 如果 hard 为 True,则抛出类型错误异常,否则尝试将值转换为 float 类型if hard:raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")cfg[k] = float(v)# 如果键在分数键集合中elif k in CFG_FRACTION_KEYS:# 如果值不是 int 或 float 类型,进行类型检查和可能的转换if not isinstance(v, (int, float)):if hard:raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")cfg[k] = v = float(v)# 检查分数值是否在 [0.0, 1.0] 范围内,否则抛出值错误异常if not (0.0 <= v <= 1.0):raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.")# 如果键在整数键集合中,但值不是 int 类型elif k in CFG_INT_KEYS and not isinstance(v, int):if hard:raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')")cfg[k] = int(v)# 如果键在布尔键集合中,但值不是 bool 类型elif k in CFG_BOOL_KEYS and not isinstance(v, bool):if hard:raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")cfg[k] = bool(v)def get_save_dir(args, name=None):"""# 根据参数和默认设置确定输出目录路径。# 判断是否存在 args 中的 save_dir 属性,若存在则直接使用该路径if getattr(args, "save_dir", None):save_dir = args.save_direlse:# 如果不存在 save_dir 属性,则从 ultralytics.utils.files 中导入 increment_path 函数from ultralytics.utils.files import increment_path# 根据条件设定 project 的路径,若在测试环境中(TESTS_RUNNING 为真),则使用默认路径,否则使用 RUNS_DIRproject = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task# 根据参数或默认值设置 name 的值,优先级顺序是提供的 name > args.name > args.modename = name or args.name or f"{args.mode}"# 使用 increment_path 函数生成一个递增的路径,以确保路径的唯一性,根据 exist_ok 参数决定是否创建新路径save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in {-1, 0} else True)# 返回生成的路径作为 Path 对象return Path(save_dir)
def _handle_deprecation(custom):"""Handles deprecated configuration keys by mapping them to current equivalents with deprecation warnings.Args:custom (Dict): Configuration dictionary potentially containing deprecated keys.Examples:>>> custom_config = {"boxes": True, "hide_labels": "False", "line_thickness": 2}>>> _handle_deprecation(custom_config)>>> print(custom_config){'show_boxes': True, 'show_labels': True, 'line_width': 2}Notes:This function modifies the input dictionary in-place, replacing deprecated keys with their currentequivalents. It also handles value conversions where necessary, such as inverting boolean values for'hide_labels' and 'hide_conf'."""# 遍历输入字典的副本,以便安全地修改原字典for key in custom.copy().keys():# 如果发现 'boxes' 键,发出弃用警告,并将其映射到 'show_boxes'if key == "boxes":deprecation_warn(key, "show_boxes")custom["show_boxes"] = custom.pop("boxes")# 如果发现 'hide_labels' 键,发出弃用警告,并根据值将其映射到 'show_labels'if key == "hide_labels":deprecation_warn(key, "show_labels")custom["show_labels"] = custom.pop("hide_labels") == "False"# 如果发现 'hide_conf' 键,发出弃用警告,并根据值将其映射到 'show_conf'if key == "hide_conf":deprecation_warn(key, "show_conf")custom["show_conf"] = custom.pop("hide_conf") == "False"# 如果发现 'line_thickness' 键,发出弃用警告,并将其映射到 'line_width'if key == "line_thickness":deprecation_warn(key, "line_width")custom["line_width"] = custom.pop("line_thickness")# 返回更新后的自定义配置字典return customdef check_dict_alignment(base: Dict, custom: Dict, e=None):"""Checks alignment between custom and base configuration dictionaries, handling deprecated keys and providing errormessages for mismatched keys.Args:base (Dict): The base configuration dictionary containing valid keys.custom (Dict): The custom configuration dictionary to be checked for alignment.e (Exception | None): Optional error instance passed by the calling function.Raises:SystemExit: If mismatched keys are found between the custom and base dictionaries.Examples:>>> base_cfg = {'epochs': 50, 'lr0': 0.01, 'batch_size': 16}>>> custom_cfg = {'epoch': 100, 'lr': 0.02, 'batch_size': 32}>>> try:...     check_dict_alignment(base_cfg, custom_cfg)... except SystemExit:...     print("Mismatched keys found")Notes:- Suggests corrections for mismatched keys based on similarity to valid keys.- Automatically replaces deprecated keys in the custom configuration with updated equivalents.- Prints detailed error messages for each mismatched key to help users correct their configurations."""# 处理自定义配置中的弃用键,将其更新为当前版本的等效键custom = _handle_deprecation(custom)# 获取基础配置和自定义配置的键集合base_keys, custom_keys = (set(x.keys()) for x in (base, custom))# 找出自定义配置中存在但基础配置中不存在的键mismatched = [k for k in custom_keys if k not in base_keys]# 如果存在不匹配的情况,则执行以下代码块if mismatched:# 导入模块 difflib 中的 get_close_matches 函数from difflib import get_close_matches# 初始化空字符串,用于存储错误信息string = ""# 遍历所有不匹配的项for x in mismatched:# 使用 get_close_matches 函数寻找在 base_keys 中与 x 最接近的匹配项matches = get_close_matches(x, base_keys)  # key list# 将匹配项转换为字符串,如果 base 中存在对应项,则添加其值matches = [f"{k}={base[k]}" if base.get(k) is not None else k for k in matches]# 如果有找到匹配项,生成匹配信息字符串match_str = f"Similar arguments are i.e. {matches}." if matches else ""# 构造错误信息字符串,指出不是有效 YOLO 参数的项及其可能的匹配项string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"# 抛出 SyntaxError 异常,包含错误信息和 CLI_HELP_MSG 的帮助信息raise SyntaxError(string + CLI_HELP_MSG) from e
# 处理命令行参数列表中隔离的 '=',合并相关参数
def merge_equals_args(args: List[str]) -> List[str]:"""Merges arguments around isolated '=' in a list of strings, handling three cases:1. ['arg', '=', 'val'] becomes ['arg=val'],2. ['arg=', 'val'] becomes ['arg=val'],3. ['arg', '=val'] becomes ['arg=val'].Args:args (List[str]): A list of strings where each element represents an argument.Returns:(List[str]): A list of strings where the arguments around isolated '=' are merged.Examples:>>> args = ["arg1", "=", "value", "arg2=", "value2", "arg3", "=value3"]>>> merge_equals_args(args)['arg1=value', 'arg2=value2', 'arg3=value3']"""new_args = []for i, arg in enumerate(args):if arg == "=" and 0 < i < len(args) - 1:  # merge ['arg', '=', 'val']new_args[-1] += f"={args[i + 1]}"del args[i + 1]elif arg.endswith("=") and i < len(args) - 1 and "=" not in args[i + 1]:  # merge ['arg=', 'val']new_args.append(f"{arg}{args[i + 1]}")del args[i + 1]elif arg.startswith("=") and i > 0:  # merge ['arg', '=val']new_args[-1] += argelse:new_args.append(arg)return new_args# 处理 Ultralytics HUB 命令行接口 (CLI) 命令,用于认证
def handle_yolo_hub(args: List[str]) -> None:"""Handles Ultralytics HUB command-line interface (CLI) commands for authentication.This function processes Ultralytics HUB CLI commands such as login and logout. It should be called when executing ascript with arguments related to HUB authentication.Args:args (List[str]): A list of command line arguments. The first argument should be either 'login'or 'logout'. For 'login', an optional second argument can be the API key.Examples:```bashyolo hub login YOUR_API_KEY```Notes:- The function imports the 'hub' module from ultralytics to perform login and logout operations.- For the 'login' command, if no API key is provided, an empty string is passed to the login function.- The 'logout' command does not require any additional arguments."""from ultralytics import hubif args[0] == "login":key = args[1] if len(args) > 1 else ""# 使用提供的 API 密钥登录到 Ultralytics HUBhub.login(key)elif args[0] == "logout":# 从 Ultralytics HUB 注销hub.logout()# 处理 YOLO 设置命令行接口 (CLI) 命令
def handle_yolo_settings(args: List[str]) -> None:"""Handles YOLO settings command-line interface (CLI) commands.This function processes YOLO settings CLI commands such as reset and updating individual settings. It should becalled when executing a script with arguments related to YOLO settings management.Args:args (List[str]): A list of command line arguments for YOLO settings management."""url = "https://docs.ultralytics.com/quickstart/#ultralytics-settings"  # 帮助文档的URLtry:# 如果有任何参数if any(args):# 如果第一个参数是"reset"if args[0] == "reset":SETTINGS_YAML.unlink()  # 删除设置文件SETTINGS.reset()  # 创建新的设置LOGGER.info("Settings reset successfully")  # 提示用户设置已成功重置else:  # 否则,保存一个新的设置# 生成键值对字典,解析每个参数new = dict(parse_key_value_pair(a) for a in args)# 检查新设置和现有设置的对齐情况check_dict_alignment(SETTINGS, new)# 更新设置SETTINGS.update(new)LOGGER.info(f"💡 Learn about settings at {url}")  # 提示用户查看设置文档yaml_print(SETTINGS_YAML)  # 打印当前的设置到YAML文件except Exception as e:# 捕获异常并记录警告信息,提醒用户查看帮助文档LOGGER.warning(f"WARNING ⚠️ settings error: '{e}'. Please see {url} for help.")
# 检查并确保 'streamlit' 包的版本符合要求(至少为1.29.0)
checks.check_requirements("streamlit>=1.29.0")
# 输出日志信息,指示正在加载 Explorer 仪表板
LOGGER.info("💡 Loading Explorer dashboard...")
# 定义运行 Streamlit 的命令行参数列表
cmd = ["streamlit", "run", ROOT / "data/explorer/gui/dash.py", "--server.maxMessageSize", "2048"]
# 将命令行参数转换成字典形式,解析其中的键值对
new = dict(parse_key_value_pair(a) for a in args)
# 检查并对齐参数字典的默认值与自定义值
check_dict_alignment(base={k: DEFAULT_CFG_DICT[k] for k in ["model", "data"]}, custom=new)
# 遍历自定义参数字典,将其键值对添加到命令行参数列表中
for k, v in new.items():cmd += [k, v]
# 运行拼装好的命令行参数列表,启动 Streamlit 应用
subprocess.run(cmd)Notes:- Split the input string `pair` into two parts based on the first '=' character.- Remove leading and trailing whitespace from both `k` (key) and `v` (value).- Raise an assertion error if `v` (value) becomes empty after stripping."""k, v = pair.split("=", 1)  # split on first '=' signk, v = k.strip(), v.strip()  # remove spacesassert v, f"missing '{k}' value"return k, smart_value(v)
# Ultralytics入口函数,用于解析和执行命令行参数
def entrypoint(debug=""):"""Ultralytics entrypoint function for parsing and executing command-line arguments.This function serves as the main entry point for the Ultralytics CLI, parsing command-line arguments andexecuting the corresponding tasks such as training, validation, prediction, exporting models, and more.Args:debug (str): Space-separated string of command-line arguments for debugging purposes.Examples:Train a detection model for 10 epochs with an initial learning_rate of 0.01:>>> entrypoint("train data=coco8.yaml model=yolov8n.pt epochs=10 lr0=0.01")Predict a YouTube video using a pretrained segmentation model at image size 320:>>> entrypoint("predict model=yolov8n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320")Validate a pretrained detection model at batch-size 1 and image size 640:>>> entrypoint("val model=yolov8n.pt data=coco8.yaml batch=1 imgsz=640")Notes:- If no arguments are passed, the function will display the usage help message.- For a list of all available commands and their arguments, see the provided help messages and theUltralytics documentation at https://docs.ultralytics.com."""# 解析调试参数,若未传入参数则使用全局变量ARGVargs = (debug.split(" ") if debug else ARGV)[1:]# 若没有传入参数,则打印使用帮助信息并返回if not args:  # no arguments passedLOGGER.info(CLI_HELP_MSG)return# 定义特殊命令及其对应的操作special = {"help": lambda: LOGGER.info(CLI_HELP_MSG),  # 打印帮助信息"checks": checks.collect_system_info,  # 收集系统信息"version": lambda: LOGGER.info(__version__),  # 打印版本信息"settings": lambda: handle_yolo_settings(args[1:]),  # 处理设置命令"cfg": lambda: yaml_print(DEFAULT_CFG_PATH),  # 打印默认配置路径"hub": lambda: handle_yolo_hub(args[1:]),  # 处理hub命令"login": lambda: handle_yolo_hub(args),  # 处理登录命令"copy-cfg": copy_default_cfg,  # 复制默认配置文件"explorer": lambda: handle_explorer(args[1:]),  # 处理explorer命令"streamlit-predict": lambda: handle_streamlit_inference(),  # 处理streamlit预测命令}# 将特殊命令合并到完整的参数字典中,包括默认配置、任务和模式full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}# 定义特殊命令的常见误用,例如-h, -help, --help等,添加到特殊命令字典中special.update({k[0]: v for k, v in special.items()})  # 单数形式special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith("s")})  # 单数形式special = {**special, **{f"-{k}": v for k, v in special.items()}, **{f"--{k}": v for k, v in special.items()}}# 初始化覆盖参数字典overrides = {}# 遍历合并等号周围的参数,并进行处理for a in merge_equals_args(args):if a.startswith("--"):# 警告:参数'a'不需要前导破折号'--',更新为'{a[2:]}'。LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")a = a[2:]if a.endswith(","):# 警告:参数'a'不需要尾随逗号',',更新为'{a[:-1]}'。LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")a = a[:-1]if "=" in a:try:# 解析键值对(a),并处理特定情况下的覆盖k, v = parse_key_value_pair(a)if k == "cfg" and v is not None:  # 如果传递了自定义yaml路径LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}")# 更新覆盖字典,排除键为'cfg'的条目overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != "cfg"}else:overrides[k] = vexcept (NameError, SyntaxError, ValueError, AssertionError) as e:# 检查覆盖参数时出现异常check_dict_alignment(full_args_dict, {a: ""}, e)elif a in TASKS:overrides["task"] = aelif a in MODES:overrides["mode"] = aelif a.lower() in special:# 如果参数在特殊命令中,则执行对应的操作并返回special[a.lower()]()returnelif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):# 对于默认布尔参数,例如'yolo show',自动设为Trueoverrides[a] = Trueelif a in DEFAULT_CFG_DICT:# 抛出语法错误,提示缺少等号以设置参数值raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")else:# 检查参数字典对齐性,处理未知参数情况check_dict_alignment(full_args_dict, {a: ""})# 检查参数字典的键对齐性,确保没有漏掉任何参数check_dict_alignment(full_args_dict, overrides)# 获取覆盖参数中的模式(mode)mode = overrides.get("mode")if mode is None:# 如果 mode 参数为 None,则使用默认值 'predict' 或从 DEFAULT_CFG 中获取的默认模式mode = DEFAULT_CFG.mode or "predict"# 发出警告日志,指示 'mode' 参数缺失,并显示可用的模式列表 MODESLOGGER.warning(f"WARNING ⚠️ 'mode' argument is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")elif mode not in MODES:# 如果 mode 参数不在预定义的模式列表 MODES 中,则抛出 ValueError 异常raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")# Task# 从 overrides 字典中弹出 'task' 键对应的值task = overrides.pop("task", None)if task:if task not in TASKS:# 如果提供的 task 不在 TASKS 列表中,则抛出 ValueError 异常raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")if "model" not in overrides:# 如果 'model' 不在 overrides 中,则设置 'model' 为 TASK2MODEL[task]overrides["model"] = TASK2MODEL[task]# Model# 从 overrides 字典中弹出 'model' 键对应的值,如果不存在,则使用 DEFAULT_CFG 中的默认模型model = overrides.pop("model", DEFAULT_CFG.model)if model is None:# 如果 model 仍为 None,则使用默认模型 'yolov8n.pt',并发出警告日志model = "yolov8n.pt"LOGGER.warning(f"WARNING ⚠️ 'model' argument is missing. Using default 'model={model}'.")# 更新 overrides 字典中的 'model' 键为当前的 model 值overrides["model"] = model# 获取模型文件的基本文件名,并转换为小写stem = Path(model).stem.lower()# 根据模型文件名的特征选择合适的模型类if "rtdetr" in stem:  # 猜测架构from ultralytics import RTDETR# 使用 RTDETR 类初始化模型对象,没有指定 task 参数model = RTDETR(model)elif "fastsam" in stem:from ultralytics import FastSAM# 使用 FastSAM 类初始化模型对象model = FastSAM(model)elif "sam" in stem:from ultralytics import SAM# 使用 SAM 类初始化模型对象model = SAM(model)else:from ultralytics import YOLO# 使用 YOLO 类初始化模型对象,并传入 task 参数model = YOLO(model, task=task)if isinstance(overrides.get("pretrained"), str):# 如果 overrides 中的 'pretrained' 是字符串类型,则加载预训练模型model.load(overrides["pretrained"])# Task Update# 如果指定的 task 与 model 的 task 不一致,则更新 taskif task != model.task:if task:# 发出警告日志,指示传入的 task 与模型的 task 不匹配LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")task = model.task# Mode# 根据 mode 执行不同的逻辑if mode in {"predict", "track"} and "source" not in overrides:# 如果 mode 是 'predict' 或 'track',并且 overrides 中没有 'source',则使用默认的数据源 ASSETSoverrides["source"] = DEFAULT_CFG.source or ASSETSLOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.")elif mode in {"train", "val"}:if "data" not in overrides and "resume" not in overrides:# 如果 mode 是 'train' 或 'val',并且 overrides 中没有 'data' 和 'resume',则使用默认的数据配置overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.")elif mode == "export":if "format" not in overrides:# 如果 mode 是 'export',并且 overrides 中没有 'format',则使用默认的导出格式 'torchscript'overrides["format"] = DEFAULT_CFG.format or "torchscript"LOGGER.warning(f"WARNING ⚠️ 'format' argument is missing. Using default 'format={overrides['format']}'.")# 在模型对象上调用指定的 mode 方法,传入 overrides 字典中的参数getattr(model, mode)(**overrides)  # default args from model# Show help# 输出提示信息,指示用户查阅模式相关的文档LOGGER.info(f"💡 Learn more at https://docs.ultralytics.com/modes/{mode}")
# Special modes --------------------------------------------------------------------------------------------------------
def copy_default_cfg():"""Copies the default configuration file and creates a new one with '_copy' appended to its name.This function duplicates the existing default configuration file (DEFAULT_CFG_PATH) and saves itwith '_copy' appended to its name in the current working directory. It provides a convenient wayto create a custom configuration file based on the default settings.Examples:>>> copy_default_cfg()# Output: default.yaml copied to /path/to/current/directory/default_copy.yaml# Example YOLO command with this new custom cfg:#   yolo cfg='/path/to/current/directory/default_copy.yaml' imgsz=320 batch=8Notes:- The new configuration file is created in the current working directory.- After copying, the function prints a message with the new file's location and an exampleYOLO command demonstrating how to use the new configuration file.- This function is useful for users who want to modify the default configuration withoutaltering the original file."""# 创建新文件路径,将默认配置文件复制到当前工作目录并在文件名末尾添加 '_copy'new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml")# 使用 shutil 库的 copy2 函数复制 DEFAULT_CFG_PATH 指定的文件到新的文件路径shutil.copy2(DEFAULT_CFG_PATH, new_file)# 记录信息到日志,包括已复制的文件路径和示例 YOLO 命令,指导如何使用新的配置文件LOGGER.info(f"{DEFAULT_CFG_PATH} copied to {new_file}\n"f"Example YOLO command with this new custom cfg:\n    yolo cfg='{new_file}' imgsz=320 batch=8")if __name__ == "__main__":# Example: entrypoint(debug='yolo predict model=yolov8n.pt')# 当作为主程序运行时,调用 entrypoint 函数并传递一个空的 debug 参数entrypoint(debug="")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ryyt.cn/news/56843.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈,一经查实,立即删除!

相关文章

Yolov8-源码解析-二十八-

Yolov8 源码解析(二十八) .\yolov8\ultralytics\data\base.py # Ultralytics YOLO 🚀, AGPL-3.0 licenseimport glob # 导入用于获取文件路径的模块 import math # 导入数学函数模块 import os # 导入操作系统功能模块 import random # 导入生成随机数的模块 from copy…

Yolov8-源码解析-八-

Yolov8 源码解析(八)comments: true description: Learn how to manage and optimize queues using Ultralytics YOLOv8 to reduce wait times and increase efficiency in various real-world applications. keywords: queue management, YOLOv8, Ultralytics, reduce wait …

FLUX 源码解析(全)

.\flux\demo_gr.py # 导入操作系统相关模块 import os # 导入时间相关模块 import time # 从 io 模块导入 BytesIO 类 from io import BytesIO # 导入 UUID 生成模块 import uuid# 导入 PyTorch 库 import torch # 导入 Gradio 库 import gradio as gr # 导入 NumPy 库 import …

【优技教育】Oracle 19c OCP 082题库(第13题)- 2024年修正版

【优技教育】Oracle 19c OCP 082题库(Q 13题)- 2024年修正版 考试科目:1Z0-082 考试题量:90 通过分数:60% 考试时间:150min 本文为(CUUG 原创)整理并解析,转发请注明出处,禁止抄袭及未经注明出处的转载。 原文地址:http://www.cuug.com.cn/ocp/082kaoshitiku/3817564823…

最快捷查看电脑启动项内容

很多人好奇很多电脑的默认启动项从哪里的看,其实就在运行窗口开两个命令就行了。 第一个,看先用户端设置的启动项: shell:Startup 这个是针对当前登录用户的。 第二个,查看电脑最高权限的通用启动项shell:Common Startup 这个是针对所有用户的。 操作的方式很简单 就是把要…

react中使用echarts关系图

一,工作需求,展示几类数据关系,可缩放大小,可拖拽位置,在节点之间的连线上展示相关日期,每个节点展示本身信息,并且要求每个关系节点能点击。 实现情况如图所示:二,实现过程中遇到的问题: 关系图完美呈现,但关系节点点击后,整个关系图会杂乱无章的浮动,导致不知道…

易基因:中国农大田见晖教授团队揭示DNA甲基化保护早期胚胎线粒体基因组稳定性|项目文章

大家好,这里是专注表观组学十余年,领跑多组学科研服务的易基因。 在早期哺乳动物胚胎中,线粒体氧化代谢增强是着床后生存和发育的重要特征;着床前期的线粒体重塑是正常胚胎发生的关键事件。在这些变化中,氧化磷酸化(OXPHOS)增强对于支持着床后胚胎的高能量需求至关重要,…

WebShell流量特征检测_哥斯拉篇

80后用菜刀,90后用蚁剑,95后用冰蝎和哥斯拉,以phpshell连接为例,本文主要是对这四款经典的webshell管理工具进行流量分析和检测。 什么是一句话木马? 1、定义 顾名思义就是执行恶意指令的木马,通过技术手段上传到指定服务器并可以正常访问,将我们需要服务器执行的命令上…