一、Introduction
任务定义 SGG: 给定一张图片,抽取三元组:主体(subjects)、客体(objects)、关系(relations)。其中主体、客体用bounding box框出来 PSG: SGG是用bounding box将主体、客体标出来,PSG用全景分割(panoptic segmentation)来替代bounding box
1.1 该论文试图解决什么问题?
以往的scene graph generation任中,关系存在长尾问题,本文提出HiLo架构可以有效解决该问题。
1.2 以往方法存在的问题
- 关系的类别有一个长尾效应问题,以往的方法更倾向于预测高频的关系(成为biasd methods)
- 主体-客体对的关系存在语义重叠(有多种语义关系),以往的方法倾向于只预测一种
二、Method
2.1 biased & unbiased method
biased方法:指经过统计,有些关系出现的次数是远远高于其他关系的,那么模型在预测的时候会倾向于高频关系的预测,具有这种特性的方法称为biased method。
以下是biased method、unbiased method和本文的方法预测的差异
- biased method: 预测的结果是向looking at、 beside这种常见的高频的关系
- unbiased method: 预测的结果主要的是向chasing、playing这类低频的词
- HiLo:既有低频的关系也有高频关系
2.2 overview
整体结构如下(还是比较复杂的)
先看中间的结构,该结构来自于mask2former,mask2former的思想又来自于maskfomer和DETR,所以想要清楚的了解该结构,需要把这3篇论文看一下。下面只是简述。
图(b)解释
- 该网络结构分为上下两个分支,其中上面(H-L)部分用于预测低频关系,下面(L-H)部分用预测高频关系。
- Triplet Query: 源自DETR,相当于可学习的位置编码;固定数量(mask2former中取100);经过decoder后和Pixel Decoder得到的feature相乘,得到N个mask
- Task Heads: 这里需要产生3个类别(subject、object、related)的预测,网络结构:three linear classifiers ;2个mask(subject和object的mask)的预测, 网络结构:2个MLP后得到的embeding与feature相乘得到mask
- Masked relation attention: 该结果没有出现在图中,但是这个mask attention是mask2former相较于maskformer最大的创新点,核心思想就是在计算注意力事,每个object只和做注意力计算,而不是和全图做注意力
该处loss如下:
$$L_{baseline}=\lambda_1 \cdot L_{so_{cls}}+ \lambda_2 \cdot L_{so_mask}+ \lambda_2 \cdot L_{re\_{cls}}$$
图(a)解释
b中两个分支的数据其实来自一这里,先将原始数据的关系排序(降序)
- H-L data: H-L data指:High to low, 简单理解,向上平移一下,那么最高频的关系去掉,最低频的关系变成2个,那么生成的数据的关系相较于之前变低频了
- L-H data: 与上相反
Relation augmentation
除了上面的数据处理,HiLo还提出一个数据增强的方式。因为原始的数据标注中可能丢掉了很多关系,该方式可以增加大量的关系,分为两种情况
- 如果标注的数据中已经标注了关系,那么对subject-object对预测所有关系(训练数据中所有的关系)的概率,选择预测结果为标注的关系的概率作为阈值,其它所有的关系概率大家该阈值的,那么就添加该关系到训练数据中
- 如果本来没有标注关系,那么阈值就选取预测类别为no-relation类别的概率,预测的关系概率大于该阈值的添加到训练数据中
图(c)解释
在图b中虽然分为高频和低频两个分支,但是对应同一张图,预测的结果应该是一样的(包括subject、object的类别和mask)。基于这个思想,这里产生两个监督信号。一个针对类别(Subject-object consistency loss),一个针对于mask(Relation consistency loss)。 loss如下
- Subject-object consistency loss
- Relation consistency loss
注:该处RIE详细操作不是很明白,感觉有点像上面H-L data、L-Hdata的处理方式一样,把最高频和最低频的关系去掉,再计算loss
total loss $$L=L_{base}+L_{obj}+L_{rel}$$
Experiments
- 在PSG数据集上与其他方法比较
- 效果展示