打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
代码阅读

这一篇我们来看一下损失函数的定义。

class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            focal_alpha: alpha in Focal Loss
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.focal_alpha = focal_alpha

该类定义前的注释指出DETR的损失包含两步:

  1. 计算模型输出和gt之间的二分图匹配;
  2. 对于匹配成功的数据对监督其类别和box

在初始化函数的参数里有一个matcher需要说明一下,这个是用来计算二分图匹配的nn.Module类:

class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self,
                 cost_class: float = 1,
                 cost_bbox: float = 1,
                 cost_giou: float = 1):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"

这种两个集合数目不同的二分图匹配问题一般选择少的一方作为匹配对的数目,其余表示匹配到背景。但二分图匹配还有一种选择方式是根据能量阈值选择匹配对数的方法。初始化函数的参数导入的是类别、box的L1差异以及giou差异在匹配能量中的占比。也就是说最终的匹配能量由这三部分组成

def forward(self, outputs, targets):  # Matcher的推理函数
      with torch.no_grad():
            bs, num_queries = outputs["pred_logits"].shape[:2]

            # We flatten to compute the cost matrices in a batch
            out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
            out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

            # Also concat the target labels and boxes
            tgt_ids = torch.cat([v["labels"] for v in targets])
            tgt_bbox = torch.cat([v["boxes"] for v in targets])

            # Compute the classification cost.  # 采用的focal loss
            alpha = 0.25
            gamma = 2.0
            neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
            pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
            cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

            # Compute the L1 cost between boxes
            cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

            # Compute the giou cost betwen boxes
            cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox),
                                             box_cxcywh_to_xyxy(tgt_bbox))

            # Final cost matrix
            C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
            C = C.view(bs, num_queries, -1).cpu()

            sizes = [len(v["boxes"]) for v in targets]  # batch中每个sample中目标的个数
            indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] # 相当于选择每个样本的sample与target的相似度矩阵进行二分匹配
            return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] # 长度为batchsize的元组list

这里首先需要注意的是整个推理过程是不参与梯度反向传导的。其次在刻画预测类别与gt的差异性时使用的是focal loss,且其参数

是固定的。最终对batch中每个样本使用匈牙利算法进行二分图匹配,获得对应的索引集合,输出格式是[(第一个样本配对的输出索引集合,第一个样本配对的gt索引集合), ...]

有个有意思的地方是,函数中没有采用循环方式分别针对每个样本计算能量矩阵,而是直接计算batch中所有的预测与所有的gt的能量矩阵,然后在通过索引的方式分别对每一个样本的能量矩阵块进行匈牙利匹配,不确定这种算法效率和循环比是否更有效

在使用Matcher获得匹配对之后,便可以对匹配对的回归和分类损失进行监督,SetCriterionforward函数主要是对最基础的输出(decoder的最后一层输出)计算了损失,可能的情况会计算辅助损失(decoder中每一个layer的输出)和two-stage的proposal损失(Encoder的最后一层对proposal的预测输出),不同情况下的损失计算方式是相同的,只是输入不同,这些损失包括labels,boxesmasks,我们主要看检测,所以我们这里忽略masks

loss_map中还有一个cardinality指标,注意到该值是不进行梯度反传的,只是用来作为模型性能度量的一个指标,表示预测的目标数与真实目标数的差异,其定义函数loss_cardinality中最重要的一句是

card_pred =  (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)  

这里card_pred表示预测为前景目标的mask,因为在制作target时,使用类别数表示背景,也即预测类别向量的最后一位表示其属于背景的概率。

在计算labels和box的损失时,出现一个函数_get_src_permutation_idx,这个函数主要是将Matcher返回的多个样本的匹配对索引拉平方便索引。举个例子,batch_size=2, query_num=4, 第一个样本的gt数位2, 第二个样本的gt数为3,那么matcher的返回可能是:
[([0,2], [0, 1]), ([1,3, 0], [2, 0, 1])], _get_src_permutation_idx的返回值idx为一个元组,即[0, 0,1,1,1](即每个匹配对对应的query所在的样本在batch中的索引)和[0, 2,1,3,0](即每个匹配对的query在每个样本所有query中的索引),这样的话
target_classes[idx] 表示选择对应的样本对应的query,进而进行gt赋值。
loss_label中需要注意的代码行:

        target_classes_onehot = target_classes_onehot[:,:,:-1]  # 最后一类是背景类
        loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]

表示针对于gt为背景的query,其gt是全零向量,因此采用的是sigmoid+F.binary_cross_entropy_with_logits 构建Focal loss,而不是softmax。
这里有个奇怪的地方是 loss_ce有一个系数query_num, 这是应为sigmoid_focal_loss输出有一个query_num上的mean操作,所以这里可以抵消。

loss_boxes操作类似,唯一需要注意的是调用box_ops.generalized_box_iou是返回的是

的矩阵,即src和gt任意两两配对,因此需要diag操作。


以上就是DETR损失函数部分的定义,很容易阅读。下一篇我们来看看数据集定义部分。

本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
【热】打开小程序,算一算2024你的财运
万字长文详解DETR:一个打开了CV界的潘多拉魔盒的存在
AAAI 2020 | DIoU和CIoU:IoU在目标检测中的正确打开方式
PolyLoss:一种将分类损失函数加入泰勒展开式的损失函数
FreeType 管理字形
(RegionProposal Network)RPN网络结构及详解
【技术综述】万字长文详解Faster RCNN源代码
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服