官网也提供了步骤,这里详细介绍下训练自己数据的过程以及中间遇到的一些问题。训练模型这里采用PointRCNN,具体的介绍参考:https://www.cnblogs.com/xiaxuexiaoab/p/18033887
一、准备数据集
数据集这一块我们需要准备好原始点云数据、物体目标标注文件、以及训练和验证对应的索引号,存放至OpenPCD/data/目录,其存放格式如下:
data
|---ImageSets
| |---train.txt // 存放索引号
| |---val.txt
|---labels
| |---idx.txt //第idx的标签
|---points
| |---idx.npy // 以npy形式存放点云数据
Points
相比于KITTI数据集,我们自己的数据通常只有点云坐标,也就是X,Y,Z信息,数据的坐标满足如下右手坐标系
labels
标签文件形式为c_x, c_y, c_z, dx, dy, dz, $\theta$,category_name
其中(c_x, c_y, c_z表示框的中心点, dx, dy, dz通常表示长宽高,$$\theta$$表示长轴与x轴夹角,以弧度制表示,逆时针为正, category_name表示物体类别)。
如:
# format: [x y z dx dy dz heading_angle category_name]
1.50 1.46 0.10 5.12 1.85 4.13 1.56 Vehicle
5.54 0.57 0.41 1.08 0.74 1.95 1.57 Pedestrian
- 偏转角
除了得到目标框的中心点和大小,还需要得到训练数据(标注工具可采用SUSTechPOINTS,SSE),我们还需要确定目标的偏转角度,角度\(\theta\)表示X轴正方向与物体方向的夹角,以弧度制表示,逆时针为正。如果物体没有明确的方向,可以采用长轴的一侧作为方向。
ImageSets
这个目录下主要存放训练和验证集的索引号,如00000, 00001等。
具体的可以参考博客给出的示例数据Tree
也可以通过百度云链接: https://pan.baidu.com/s/1FufaSbXb7z77k0PZm5h7hQ 提取码: tree
二、修改数据配置
依据数据集类别修改OpenPCD/pcdet/datasets/custom/custom_dataset.py
,将里面类别进行修改,可以进行验证
if __name__ == '__main__':import sysif sys.argv.__len__() > 1 and sys.argv[1] == 'create_custom_infos':import yamlfrom pathlib import Pathfrom easydict import EasyDictdataset_cfg = EasyDict(yaml.safe_load(open(sys.argv[2])))ROOT_DIR = (Path(__file__).resolve().parent / '../../../').resolve()create_custom_infos(dataset_cfg=dataset_cfg,class_names=['tree'],data_path=ROOT_DIR / 'data' / 'tree_data',save_path=ROOT_DIR / 'data' / 'tree_data',)
依据数据集修改OpenPCD/tools/cfgs/dataset_configs/custom_dataset.yaml
。
DATA_PATH: '../data/tree_data'POINT_CLOUD_RANGE: [0, -10.24, -1, 10.24, 10.24, 3]
# 这个地方只会在eval阶段会用到,所以如果自己不需要eval的话可以不加
MAP_CLASS_TO_KITTI: {'Tree': 'tree'
}
# 需要与自己的点云数据格式对应,一般不需要改
POINT_FEATURE_ENCODING: {encoding_type: absolute_coordinates_encoding,used_feature_list: ['x', 'y', 'z', 'intensity'],src_feature_list: ['x', 'y', 'z', 'intensity'],
}DATA_AUGMENTOR:DISABLE_AUG_LIST: ['placeholder']AUG_CONFIG_LIST:- NAME: gt_samplingUSE_ROAD_PLANE: FalseDB_INFO_PATH:- custom_dbinfos_train.pklPREPARE: {# 需要改成自己的数据集类别filter_by_min_points: ['Tree:5'],# filter_by_difficulty: [-1], # 这个地方如果不注释的话训练可能会报错,可以自己尝试一下}# 需要改成自己的数据集类别SAMPLE_GROUPS: [Tree:15']NUM_POINT_FEATURES: 4DATABASE_WITH_FAKELIDAR: FalseREMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]LIMIT_WHOLE_SCENE: True
验证:
运行以下代码
python -m pcdet.datasets.custom.custom_dataset create_custom_infos ../../../tools/cfgs/dataset_configs/custom_dataset.yaml
成功后数据目录下会生成custom_infos_val.pkl
, custom_infos_train.pkl
, custom_dbinfos_train.pkl
。
三、修改网络配置
在tools/cfgs/目录下选择对应模型的配置文件,如OpenPCD/tools/cfgs/kitti_models/pointrcnn.yaml
或 pv_rcnn.yaml
# 要改成自己的类别
CLASS_NAMES: ['Tree']DATA_CONFIG:_BASE_CONFIG_: cfgs/dataset_configs/custom_dataset.yamlDATA_PROCESSOR:# - NAME: mask_points_and_boxes_outside_range# REMOVE_OUTSIDE_BOXES: True- NAME: sample_pointsNUM_POINTS: {'train': 7000,'test': 7000 }- NAME: shuffle_pointsSHUFFLE_ENABLED: {'train': True,'test': False}
对于基于体素的网络来说,体素大小与点云范围有比例关系,需要设置成符号条件的数值才行!!!点云范围和VOXEL_SIZE要有一定关系,Z轴/VOXEL_SIZE=40 X、Y轴/VOXEL_SIZE是16的倍数即可。
四、训练及验证
训练
以PointRCNN为例,修改pointrcnn.yaml配置后,执行以下命令
python train.py --cfg_file ./cfgs/kitti_models/pointrcnn.yaml
测试
- 测试训练后的模型
python test.py --cfg_file ${CONFIG_FILE} --batch_size ${BATCH_SIZE} --ckpt ${CKPT}
- 测试所有训练的模型,并在Tensorboard上显示, 则添加--eval_all
python test.py --cfg_file ${CONFIG_FILE} --batch_size ${BATCH_SIZE} --eval_all
五、常见问题
- ValueError: attempted relative import beyond top-level package
运行custom_dataset.py的时候出现以下错误,pycharm里面可以正常跳转,但运行会报错
from ...ops.roiaware_pool3d import roiaware_pool3d_utils
ValueError: attempted relative import beyond top-level package
解决方法
...ops.roiaware_pool3d
替换为 pcdet.ops.roiaware_pool3d
相关解释:https://stackoverflow.com/questions/30669474/beyond-top-level-package-error-in-relative-import
- Error ModuleNotFoundError: No module named 'av2'
OpenPCDet/pcdet/datasets/argo2/argo2_dataset.py", line 7, in <module>
from av2.utils.io import read_feather
ModuleNotFoundError: No module named 'av2'
解决方法
注释掉OpenPCD/pcdet/datasets/__init__.py
的15行和27行
# from .argo2.argo2_dataset import Argo2Dataset
from .custom.custom_dataset import CustomDataset__all__ = {'DatasetTemplate': DatasetTemplate,'KittiDataset': KittiDataset,'NuScenesDataset': NuScenesDataset,'WaymoDataset': WaymoDataset,'PandasetDataset': PandasetDataset,'LyftDataset': LyftDataset,'ONCEDataset': ONCEDataset,'CustomDataset': CustomDataset,# 'Argo2Dataset': Argo2Dataset
}
相关解释:https://github.com/open-mmlab/OpenPCDet/issues/1574