diff --git a/mmdet/models/losses/cross_entropy_loss.py b/mmdet/models/losses/cross_entropy_loss.py index 49fac7743ce..ab149852f6f 100644 --- a/mmdet/models/losses/cross_entropy_loss.py +++ b/mmdet/models/losses/cross_entropy_loss.py @@ -39,6 +39,11 @@ def cross_entropy(pred, """ # The default value of ignore_index is the same as F.cross_entropy ignore_index = -100 if ignore_index is None else ignore_index + + # label[0] are categorical labels of images + if isinstance(label, tuple) and len(label) > 1: + label = label[0] + # element-wise losses loss = F.cross_entropy( pred, @@ -117,6 +122,10 @@ def binary_cross_entropy(pred, # The default value of ignore_index is the same as F.cross_entropy ignore_index = -100 if ignore_index is None else ignore_index + # label[0] are categorical labels of images + if isinstance(label, tuple) and len(label) > 1: + label = label[0] + if pred.dim() != label.dim(): label, weight, valid_mask = _expand_onehot_labels( label, weight, pred.size(-1), ignore_index) @@ -191,6 +200,11 @@ def mask_cross_entropy(pred, assert ignore_index is None, 'BCE loss does not support ignore_index' # TODO: handle these two reserved arguments assert reduction == 'mean' and avg_factor is None + + # label[0] are categorical labels of images + if isinstance(label, tuple) and len(label) > 1: + label = label[0] + num_rois = pred.size()[0] inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) pred_slice = pred[inds, label].squeeze(1)