1. 数据集准备

mmdet的数据集支持 coco格式和 voc 格式, 但 voc 格式官方只自带了少量网络模型文件, 所以推荐使用 coco 格式的数据集

2. 修改mmdet/core/evalution/class_names.py和mmdet/datasets/coco.py中的标签为自建数据集的类别

class_names.py修改如下函数:

image-20211127103931054

coco.py修改如下函数:

image-20211128182844672

重要: 修改完 class_names.py 和 voc.py 之后一定要重新编译代码,否则验证输出仍然为原类别,且训练过程中指标异常, 在根目录 mmdetection 下执行命令:

python setup.py install

重新编译

3. 简单运行训练命令, 生成配置文件

python tools/train.py configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py --work-dir work_dirs

其中 work_dirs 为你的工作目录,训练产生的日志,模型,网络结构文件会存放于此

运行完命令后,会生成一个包含所有配置信息的配置文件在mmdetection/work_dirs文件夹下面,名称与你训练的命令指定的配置文件名称一样, 如下图:

image-20211127110724326

4. 修改工作目录下生成的模型配置文件 faster_rcnn_r50_fpn_1x_voc.py 的的相关参数

4.1 修改num_classes变量

全局搜索num_classes,将其值改为自建数据集的类别数(注意不包含背景)。把搜索到的num_classes全改掉, 如下图

image-20211127110809018

4.2 修改数据加载部分的信息

搜索 data_root, 先修改数据文件的根节点目录,

image-20211127110905400

然后依次修改下面代码中的训练集, 验证集, 测试集的数据集路径位置, 举例训练集如下:

image-20211127111036870

4.3 修改训练图片大小, 训练时的 batch_size, 学习率, epoch

图片输入大小修改主要修改 img_scale, 如下图, 修改为你的图片实际输入大小

image-20211127112822780

对于 batch_size, 主要由 GPU 数量与 samples_per_gpu 参数决定

workers_per_gpu: 读取数据时每个gpu分配的线程数 。一般设置为 2即可

samples_per_gpu: 每个gpu读取的图像数量,该参数和训练时的gpu数量决定了训练时的batch_size。如下图, 由于我只有一个gpu, 该参数设置为 2, 所以 batch_size为2

image-20211127113000293

学习率设置, 位置如下:

image-20211127121447050

重要提示: 默认学习速率为8个gpu。如果使用的GPU小于或大于8个,则需要设置学习速率与GPU个数成正比,例如4个GPU的学习速率为0.01,16个GPU的学习速率为0.04。
同理 1 个 GPU, samples_per_gpu= 2, 学习速率设置为 0.0025

计算公式: 批大小(gpu_num * samples_per_gpu) / 16 * 0.02

参考:【01】MMDetection学习记录(一) mmdetecion-学习率调整-线性缩放原则

4.4 使用预训练模型训练

需要提前下载预训练模型, 可以从如下链接中获取需要的模型: open-mmlab模型库

然后修改模型配置文件 faster_rcnn_r50_fpn_1x_voc.py 中的

load_from = None
# 修改为:
load_from = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'

其中具体目录要看放在哪个目录下

4.5 load_from, resume_from, pre_train 的区别

resume_from 同时加载模型权重(model weights) 和优化状态(optimizer status),且 epoch 是继承了指定 checkpoint 的信息. 一般用于意外终端的训练过程的恢复.

load_from 仅加载模型权重(model weights),训练过程的 epoch 是从 0 开始训练的, 相当于重新开始, 一般用于模型 finetuning(微调).

如果要使用, 两者都为’/work_dir/xxx/epoch_xxx.pth’ 格式,

加载顺序优先级: pretrained, resume_from> load_from , 其中如果加载了 resume_from的 断点文件, 那么久不会再加载 load_from 的文件

4.6 一些其他配置

在该文件的最下面几行有如下代码, 还可以修改其其他配置, 见注释

runner = dict(type='EpochBasedRunner', max_epochs=100)  # 训练轮次
checkpoint_config = dict(interval=1)    # 设置多久保存一次模型
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])  # 训练几次 iteration 保存一次日志
custom_hooks = [dict(type='NumClassCheckHook')]
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None    # 是否加载上一次训练的模型
resume_from = None 
workflow = [('train', 1)]
work_dir = 'work_dirs'  # 工作目录, 即训练结果存放位置, 在该文件夹下会生成新的配置文件
gpu_ids = range(0, 1)

5 训练

运行命令开始训练, 注意这里使用刚刚修改的在 work_dir 下生成的 faster_rcnn_r50_fpn_1x_coco.py 文件进行训练:

python tools/train.py work_dirs/faster_rcnn_r50_fpn_1x_coco.py

ps: 后面参数是你的配置文件的路径

这个工具接受以下参数:

  • --no-validate (不建议): 在训练期间关闭测试.
  • --work-dir ${WORK_DIR}: 覆盖工作目录.
  • --resume-from ${CHECKPOINT_FILE}: 从某个 checkpoint 文件继续训练.
  • --options 'Key=value': 覆盖使用的配置文件中的其他设置.

Github 项目 - mmdetection 模型训练

MMdetection官方中文文档1:使用已有模型在标准数据集上进行推理

6. 可能遇到的报错

6.1 ValueError: need at least one array to concatenate
  • 大概率是数据集有问题,可以检查数据集路径,检查json里边的类名和代码里的类名是否对应。
    我的问题就是 annotations.json 的coco标注文件中, 类别有错误, 修改后不再报错

参考: ValueError: need at least one array to concatenate的解决

6.2 在 Ubuntu 下遇到报错: AssertionError: The num_classes (3) in Shared2FCBBoxHead of MMDataParallel does not matches the length of CLASSES 80) in CocoDataset

在尝试之前修改 coco.py, class_names.py的类别并 执行python setup.py install仍然无法解决.

尝试如下方法成功: 这是因为其实跟重新编译一样,重新编译的原因就是因为环境里的源文件没有修改,所以你才会报错。mmdetection-master目录下只是一些python文件,真正运行程序时,运行的还是环境里的源文件,因为我们直接去环境里修改源文件。

假设我的conda安装在base下,因此去下面的目录下,分别修改两个文件中的类别为实际的类别:

~/anaconda3/lib/python3.8/site-packages/mmdet/datasets/coco.py
~/anaconda3/lib/python3.8/site-packages/mmdet/core/evaluation/class_names.py

修改位置及方法同上文的第 2 点

如果conda环境名为conda_env_name, 则目录可能为, 这个要具体看编译位置, 当运行报错时控制台会有显示

/anaconda3/envs/conda_env_name/lib/python3.7/site-packages/mmdet/core/evaluation/class_names.py

/anaconda3/envs/conda_env_name/lib/python3.7/site-packages/mmdet/datasets/coco.py

参考: AssertionError: The num_classes (3) in Shared2FCBBoxHead of MMDataParallel does not matches

主流程参考:mmdetection训练voc格式的自己的数据集

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐