关联比赛: Spark“数字人体”AI挑战赛——脊柱疾病智能诊断大赛
triple-Z团队答题攻略
1 赛题分析
1.1 赛题回顾
本次比赛的任务是采用模型对核磁共振的脊柱图像进行智能检测。首先需要对5个椎体和6个椎间盘进行定位,这部分实际上就是11个关键点的检测任务;之后需要对每一个关键点对应的椎体/椎间盘进行疾病分类。因此,整个比赛的任务可以分解为关键点检测和关键点分类两类大问题。
1.2 赛题分析
在开始建模之前,我们需要对数据有清晰的认识。通过对初赛训练集的分析,我们认为本次比赛存在如下三点挑战:
相比自然图像的大数据集,数据量比较少(毕竟医学影像获取和标注成本都比较高)
类别间的样本不均衡(如下图左图)
标注噪声(如下图右图)
针对上述三个挑战,我们的解决方案如下:
设计了一个简洁高效的定位+分类的单阶段检测模型
调整损失函数以适应类别不均衡
数据增强、关键点标注抖动、基于先验统计的二分类阈值调整
2 模型方法
2.1 模型整体框架
我们首先将关键点定位分解为:粗定位和精回归两个子任务,并且在关键点检测的时候同时预测对应位置的关键点类别。模型整体框架如下图,主要由三部分组成:
backbone:采用resnet18[1]的前4个特征层(16倍下采样)提取图像的特征(注:这里去掉最后一个特征层不仅减少参数量且更容易训练,效果更好)
neck:通过几个特征融合块融合多尺度特征并且扩大模型的感受野
output:最后通过四个1X1卷积分别输出:关键点粗定位、关键点y轴上的细回归、关键点x轴上的细回归、关键点分类
我们这种设计方案,相比于主流的heatmap-based关键点检测方法,不需要上采样层,不需要resnet最后一个特征提取层,使得学习任务更加简单,因此参数量和计算量都会更少(实验部分会有详细的对比分析)。
2.2 neck模块细节
neck部分的DetNet block[2]、ASPP Block[3]和SPP Block[4]都是用来融合多尺度特征且扩大模型感受野的,具体设计如下,DetNet block与原论文保持一致,ASPP/SPP Block都做了适应性的修改。
2.3 关键点粗定位
对于5个椎体+6个椎间盘共11个关键点进行检测,可以通过输出11个通道的张量分别代表11个关键点的预测。对于第k个关键点的粗定位,我们用一个网格图来表示,该图分辨率为原图下采样之后的分辨率,粗定位图上的每一个网格的值为其中心点与关键点的关系度量,如下图:
2.4 关键点细回归
有了关键点的粗定位图,我们可以找到离目标关键点最近的网格中心点,但每一个网格对应原图是一个的区域,显然直接取中心点离目标关键点有一定的误差,因此我们需要额外的两个与粗定位分辨率一样的定位细回归图(x轴和y轴两个方向),其每一个网格的值为其中心点到关键点在x/y轴上面的偏移,如下图。这里直接用偏移量的话由于范围太大了模型不太好学,因此我们把偏移量除以,使得关键点附近的网格的值都分布在1附近。
2.5 关键点分类
模型在进行粗定位和细回归的时候实际上已经学到了椎体/椎间盘的特征和位置信息,因此我们直接通过一个并行分支对相应网格位置进行分类预测。如下图,对于一个网格,如果其为椎体,那么需要一个5维度的向量表示对5个椎体的二分类;如果是椎间盘,那么需要一个4*6维度的向量表示对6个椎间盘的四分类,此外椎体内疝出(v5)可以与其他四类共存,因此我们额外采用6维度的向量进行表示。综上,一个网格应该对应35维度的向量。
2.6 损失函数
损失函数是一个多任务学习损失函数,由三部分构成:
对于关键点粗定位图的损失函数,直接采用MSELoss进行学习;
对于关键点细回归图的损失函数,首先过滤出粗定位图中激活值高的网格,而后只对细回归图上的这部分网格计算损失函数,如下图左边图中红框框所示。实际上这种做法能够使得模型更加关注目标区域,减少对无关区域的关注;
对于关键点分类图的损失函数,同样过滤出需要关注的网格,而后对于每一个网格,若该关键点是椎体,则计算前5维度向量的损失;若是椎间盘,则计算后4*6+6维度的损失。为了适应类别不平衡的问题,对于BCELoss和CrossEntropyLoss都采用了类别样本数倒数作为相应类别的权重。
总的损失函数是上面各部分的加权求和,权重我们根据经验值直接设置的。
2.7 推理过程
推理的时候取关键点粗定位图值最高的网格(关系最密切的网格),再在细回归图和分类图取相同位置的网格,即可以解析出关键点坐标和关键点类别,如下图:
2.8 其他一些小技巧
我们统计了初赛训练集上不同位置椎体的类别分布,如下图,我们发现不同位置其v1/v2的比例有所不同,靠近胸椎的椎体病变概率相对低一些,而在靠近尾椎的椎体更容易产生病变,这个发现和我们的直觉是一致的。因此我们模型会根据这个先验统计结果进行二分类的阈值调整,调整后的阈值如下图。同样地,对于椎体内疝出,我们也做了同样的策略。
此外,为了使得模型鲁棒性更强,减少标注噪声的影像,我们还采用了多种数据增强,如下图。其中对于关键点随机抖动,具体做法是对于标签关键点,我们会加上一个随机的小偏移量,以模拟医生标注可能抖动的问题。
2.9 模型融合
医学影像处理在数据不多的情况下,一般采用多折交叉验证使得模型验证更加稳定,因此我们采用了5折交叉验证训练了5个模型,而后将他们在测试集上的结果进行平均,得到最终成绩。
3 实验结果
3.1 复赛成绩及模型性能测试
3.2 定位模型与主流模型的性能对比
由于复赛数据不可见,且每天一次提交,机会宝贵,因此我们额外在关键点检测的公开数据集[6]上进行我们模型与主流模型[5]的对比实验,如下图。证明了我们模型的有效性和高效性,且更容易收敛(小模型就可以达到不错的效果)。
4 总结及展望
4.1 可行性讨论
希望我们的模型能够实际运用到现实生活中,因此在设计模型的时候就考虑了模型的高效性,并且我们实验表明该模型能够达到实时性的要求。我们设想以后患者的影像图片可以和我们模型的预测结果一并传输到医生的电脑上,医生可以根据模型的预测进行一些修正和调整,但就可以节省了大量的手动查询和定位等时间。相当于把“是什么”的问题转化成了“是不是”,感觉一定程度可以提高医生的工作效率。
4.2 总结
我们本次比赛主要做了如下几个工作:
设计了一个简洁高效的定位+分类模型
设计了合适的损失函数有效地训练模型
充足实验证明模型的有效性和高效性
结合实际考虑模型的落地方式
4.3 展望
我们认为我们的工作还有如下几点提升空间:
横断面的数据没有利用到,而医生诊断的时候往往是依据矢状位+横断面,应该进一步模拟专家诊断的过程
我们团队没有相应的专业知识,应该加入更多的一些临床专家先验知识提高模型性能
我们模型实际上是一个多任务预测模型,因此模型结构存在优化的空间
多任务的学习也是一个热门研究方向,不同任务间的权重分配我们是采用经验值设置的,可以通过一些前沿工作进一步优化
5 核心代码
核心代码主要是4个输出张量的构建,如下
6 比赛经验总结及感想
在参加比赛的时候,不要急着搭建和训练模型,需要先对比赛数据进行分析和总结。比赛的数据一般更贴近实际运用,和平时科研所用的公开数据集还是有所区别,因此在分析数据的过程中可以总结数据的难点和挑战,对应的再选用或者自己设计合适的模型。
本次比赛赛题新颖、充满挑战,并且一有问题主办方都能够及时解决,同时让我们学习到了很多知识、提高了实践能力!
7 参考资料
[1] Deep Residual Learning for Image Recognition
[2] Detnet: A backbone network for object detection
[3] Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs
[4] Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition
[5] Simple Baselines for Human Pose Estimation and Tracking
[6] 2d human pose estimation: New benchmark and state of the art analysis
查看更多内容,欢迎访问天池技术圈官方地址:Spark"数字人体"AI挑战赛_脊柱疾病智能诊断大赛_GPU赛道亚军比赛攻略_triple-Z团队_天池技术圈-阿里云天池