多标签分类新建模方法
Swift Lv6

常见的多标签分类方法是同时生成多个标签的logits,然后接一个sigmoid激活函数做二分类。该方法简单直接,但忽略了标签之间的相关性。虽然业界针对该问题提出了很多解决思路,但大多是任务特定,通用性不强,也不够优雅。

Transformer decoder倒是可以序列输出多个标签,但却加入了位置偏差。而标签之间是没有位置关系的,谁先谁后无所谓,只要输出全就行。这样也导致数据集不好构造。

C-Tran

General Multi-label Image Classification with Transformers 这篇论文提供了新思路,类似BERT的MLM预训练任务:通过在输入端对多个标签做随机mask,然后预测被mask的标签,从而强制模型去学习标签之间的依赖关系:

model

模型细节:
detail

params

  • Label Embeddings: 可学习的参数矩阵,由模型隐式学习到标签的语义信息和标签间依赖。有点像DETR的query
  • State Embeddings: 控制标签的mask比例,这样就跟标签学习实现了解耦,也方便在推理阶段注入全比例mask

实验结果

不说了,全是sota:

exp


Powered by Hexo & Theme Keep
Unique Visitor Page View