Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions mmdet/models/losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ def cross_entropy(pred,
"""
# The default value of ignore_index is the same as F.cross_entropy
ignore_index = -100 if ignore_index is None else ignore_index

# label[0] are categorical labels of images
if isinstance(label, tuple) and len(label) > 1:
label = label[0]

# element-wise losses
loss = F.cross_entropy(
pred,
Expand Down Expand Up @@ -117,6 +122,10 @@ def binary_cross_entropy(pred,
# The default value of ignore_index is the same as F.cross_entropy
ignore_index = -100 if ignore_index is None else ignore_index

# label[0] are categorical labels of images
if isinstance(label, tuple) and len(label) > 1:
label = label[0]

if pred.dim() != label.dim():
label, weight, valid_mask = _expand_onehot_labels(
label, weight, pred.size(-1), ignore_index)
Expand Down Expand Up @@ -191,6 +200,11 @@ def mask_cross_entropy(pred,
assert ignore_index is None, 'BCE loss does not support ignore_index'
# TODO: handle these two reserved arguments
assert reduction == 'mean' and avg_factor is None

# label[0] are categorical labels of images
if isinstance(label, tuple) and len(label) > 1:
label = label[0]

num_rois = pred.size()[0]
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
pred_slice = pred[inds, label].squeeze(1)
Expand Down