项目概述
本项目是一个基于PyTorch框架开发的人脸口罩检测系统,能够识别图像中人物是否佩戴口罩,并区分三种状态:正确佩戴口罩(绿色框)、不正确佩戴口罩(橙色框)和未佩戴口罩(红色框)。该项目由开发者Abhinand(GitHub: abhinand5)创建,代码托管在GitHub上。
系统架构
系统采用Faster R-CNN(Region-based Convolutional Neural Network)作为基础检测模型,这是目标检测领域的一种高效算法。具体实现使用了PyTorch提供的fasterrcnn_resnet50_fpn
预训练模型,并在其基础上进行了定制化修改以适应口罩检测任务。
核心组件
- FaceMaskDetector类:系统的核心模型类,封装了模型的构建、训练、预测和保存/加载功能。
- FaceMaskDataset类:自定义数据集类,负责加载和处理Kaggle上的口罩检测数据集。
- 预测与可视化模块:包含图像预测和结果可视化的功能。
技术实现细节
模型构建
模型基于TorchVision提供的预训练Faster R-CNN模型进行构建。在FaceMaskDetector.build_model()
方法中,开发者替换了原始模型的分类器部分,以适应本项目中的三类检测任务(正确佩戴、不正确佩戴、未佩戴)。
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
pretrained=self.pretrained
)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(
in_features,
n_classes + 1 # +1 for background class
)
这种设计利用了迁移学习的优势,即在预训练模型的基础上进行微调,可以显著减少训练时间和提高模型性能。
数据处理
数据集来自Kaggle的"face-mask-detection"数据集,包含带有标注的图像。FaceMaskDataset
类负责加载图像和对应的XML格式标注文件,并将其转换为PyTorch张量格式。
标注文件解析过程中,系统将口罩佩戴状态分为三类:
- 1: 正确佩戴口罩
- 2: 不正确佩戴口罩
- 3: 未佩戴口罩
def get_label(self, obj):
if obj.find("name").text == "with_mask":
return 1
elif obj.find("name").text == "mask_weared_incorrect":
return 2
return 3
训练过程
训练过程在FaceMaskDetector.train()
方法中实现,使用了随机梯度下降(SGD)优化器,并加入了动量和权重衰减来提高模型泛化能力。
optimizer = torch.optim.SGD(
params, lr=learning_rate, momentum=0.9, weight_decay=0.0005
)
训练过程中使用了pkbar库来显示进度条,使训练过程更加直观。
预测与可视化
预测功能在FaceMaskDetector.predict()
方法中实现,可以处理单张或多张图像的预测任务。预测结果通过plot_result()
函数进行可视化,将检测框和类别信息绘制在原始图像上,并保存为PNG格式。
颜色编码方案:
- 绿色框:正确佩戴口罩
- 橙色框:不正确佩戴口罩
- 红色框:未佩戴口罩
使用说明
环境配置
项目依赖可以通过以下命令安装:
pip install -r requirements.txt
运行预测
预测单张图像的命令格式如下:
python detect_face_mask.py [PATH TO YOUR IMAGE] [FILE NAME FOR RESULT]
示例:
python detect_face_mask.py ./my_image.png my_result
模型训练
虽然文档中没有直接提供训练脚本的命令行接口,但从代码可以看出,训练过程可以通过设置相关参数来执行:
N_EPOCHS = 10
LEARNING_RATE = 0.005
model.train(n_epochs=N_EPOCHS, learning_rate=LEARNING_RATE)
模型与数据
由于模型文件和原始数据集较大,开发者将它们托管在Google Drive上:
总结
本项目展示了一个完整的基于深度学习的目标检测系统实现,从数据加载、模型构建、训练到预测和可视化的全流程。系统具有以下特点:
- 使用成熟的Faster R-CNN架构作为基础,保证了检测性能
- 通过迁移学习减少了训练时间和数据需求
- 清晰的代码结构和模块化设计提高了可维护性
- 直观的可视化结果便于理解和使用
该项目可作为计算机视觉初学者学习目标检测技术的良好示例,也可作为实际应用中人脸口罩检测的基础框架进行扩展和优化。
源码获取
完整代码已开源,包含详细的注释文档:
🔗 [GitCode仓库] https://gitcode.com/laonong-1024/python-automation-scripts
📥 [备用下载] https://pan.quark.cn/s/654cf649e5a6 提取码:f5VG