Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mmdet/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
CustomSampleSizeSampler, GroupMultiSourceSampler,
MultiSourceSampler, TrackAspectRatioBatchSampler,
TrackImgSampler)
from .utils import get_loading_pipeline
from .utils import get_loading_pipeline, validate_pipeline_order
from .v3det import V3DetDataset
from .voc import VOCDataset
from .wider_face import WIDERFaceDataset
Expand All @@ -49,5 +49,6 @@
'BaseSegDataset', 'ADE20KSegDataset', 'CocoSegDataset',
'ADE20KInstanceDataset', 'iSAIDDataset', 'V3DetDataset', 'ConcatDataset',
'ODVGDataset', 'MDETRStyleRefCocoDataset', 'DODDataset',
'CustomSampleSizeSampler', 'Flickr30kDataset'
'CustomSampleSizeSampler', 'Flickr30kDataset',
'validate_pipeline_order'
]
6 changes: 6 additions & 0 deletions mmdet/datasets/base_det_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from mmengine.utils import is_abs

from ..registry import DATASETS
from .utils import validate_pipeline_order


@DATASETS.register_module()
Expand Down Expand Up @@ -48,6 +49,11 @@ def __init__(self,
'please use `backend_args` instead, please refer to'
'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501
)
# Validate pipeline transform order before building the dataset
# so that users get early warnings about suspicious configurations.
pipeline = kwargs.get('pipeline', None)
if pipeline is not None:
validate_pipeline_order(pipeline)
super().__init__(*args, **kwargs)

def full_init(self) -> None:
Expand Down
143 changes: 143 additions & 0 deletions mmdet/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,153 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmcv.transforms import LoadImageFromFile

from mmdet.datasets.transforms import LoadAnnotations, LoadPanopticAnnotations
from mmdet.registry import TRANSFORMS

# Categories used by validate_pipeline_order to detect suspicious ordering.
# Each transform is assigned to one category; the expected order is the
# numeric order of those categories. Placing a higher-category transform
# before a lower-category one triggers a warning.
_TRANSFORM_CATEGORIES = {
# --- Loading (must come first) ---
'LoadImageFromFile': 0,
'LoadMultiChannelImageFromFiles': 0,
'LoadAnnotations': 1,
'LoadPanopticAnnotations': 1,
'LoadProposals': 1,
# --- Spatial augmentation ---
'Resize': 10,
'FixScaleResize': 10,
'FixShapeResize': 10,
'ResizeShortestEdge': 10,
'RandomResize': 10,
'RandomCrop': 11,
'RandomCenterCropPad': 11,
'MinIoURandomCrop': 11,
'Expand': 11,
'RandomFlip': 12,
'RandomShift': 12,
'RandomAffine': 12,
'Mosaic': 13,
'CachedMosaic': 13,
'MixUp': 13,
'CachedMixUp': 13,
'CopyPaste': 13,
# --- Pad ---
'Pad': 20,
# --- Pixel-level augmentation (after spatial) ---
'PhotoMetricDistortion': 15,
'YOLOXHSVRandomAug': 15,
'Albu': 15,
'InstaBoost': 15,
# --- Normalize (after all augmentation) ---
'Normalize': 30,
# --- Formatting (must come last) ---
'DefaultFormatBundle': 40,
'PackDetInputs': 40,
'PackTrackInputs': 40,
'PackReIDInputs': 40,
'Collect': 41,
}

# Specific pair-wise rules: (earlier, later) means ``earlier`` must come
# before ``later`` in the pipeline. Each entry carries a human-readable
# reason shown in the warning message.
_ORDER_RULES = [
('LoadImageFromFile', 'LoadAnnotations',
'Annotations should be loaded after the image.'),
('LoadImageFromFile', 'Resize',
'Image must be loaded before resizing.'),
('Resize', 'Pad',
'Padding should happen after resizing; otherwise the resize may '
'undo the padding.'),
('Resize', 'Normalize',
'Normalize should be applied after spatial transforms like Resize.'),
('RandomFlip', 'Normalize',
'Normalize should come after spatial augmentations like RandomFlip.'),
('RandomFlip', 'Pad',
'Pad should come after RandomFlip to avoid padding artifacts being '
'flipped.'),
('Normalize', 'PackDetInputs',
'Formatting / packing must be the last step.'),
('Normalize', 'DefaultFormatBundle',
'Formatting / packing must be the last step.'),
('Pad', 'Normalize',
'Normalize should be applied after Pad so that padded values are '
'also normalized. If you intentionally normalize before padding, '
'you can ignore this warning.'),
]


def validate_pipeline_order(pipeline):
"""Check that the transforms in *pipeline* follow a sensible order.

The function emits :py:class:`UserWarning` messages for each suspicious
ordering it detects. It never raises an exception so that existing
configs keep working.

Args:
pipeline (list[dict]): Data pipeline config list, where each
element is a ``dict(type=..., ...)``.

Example:
>>> pipeline = [
... dict(type='LoadImageFromFile'),
... dict(type='Normalize', mean=[0]*3, std=[1]*3),
... dict(type='RandomFlip', prob=0.5), # after Normalize!
... dict(type='PackDetInputs'),
... ]
>>> validate_pipeline_order(pipeline) # emits a warning
"""
if not pipeline:
return

# Build a mapping from transform name to its *first* index in the
# pipeline so that we can check pair-wise rules.
first_index = {}
for idx, step in enumerate(pipeline):
name = step.get('type', '')
if name and name not in first_index:
first_index[name] = idx

# --- pair-wise rule checks ---
for early, late, reason in _ORDER_RULES:
early_idx = first_index.get(early)
late_idx = first_index.get(late)
if early_idx is not None and late_idx is not None:
if early_idx > late_idx:
warnings.warn(
f"In the data pipeline, '{early}' (index {early_idx}) "
f"appears after '{late}' (index {late_idx}), which is "
f"likely incorrect. {reason} Please review the "
f"transform order in your config.",
UserWarning,
)

# --- generic category-order check ---
prev_name = None
prev_cat = -1
for step in pipeline:
name = step.get('type', '')
cat = _TRANSFORM_CATEGORIES.get(name)
if cat is None:
# Unknown transform – skip.
continue
if cat < prev_cat:
warnings.warn(
f"In the data pipeline, '{name}' (category {cat}) appears "
f"after '{prev_name}' (category {prev_cat}). The "
f"recommended order is: Loading -> Spatial augmentation "
f"-> Pixel augmentation -> Pad -> Normalize -> "
f"Formatting. Please verify the transform order in "
f"your config.",
UserWarning,
)
prev_name = name
prev_cat = cat


def get_loading_pipeline(pipeline):
"""Only keep loading image and annotations related configuration.
Expand Down
110 changes: 110 additions & 0 deletions tests/test_datasets/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from unittest import TestCase

from mmdet.datasets.utils import validate_pipeline_order


class TestValidatePipelineOrder(TestCase):
"""Tests for :func:`validate_pipeline_order`."""

def test_correct_order_no_warning(self):
"""A standard pipeline in the correct order should not warn."""
pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', prob=0.5),
dict(type='Pad', size_divisor=32),
dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
dict(type='PackDetInputs'),
]
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
validate_pipeline_order(pipeline)
assert len(w) == 0, [str(x.message) for x in w]

def test_normalize_before_random_flip_warns(self):
"""Normalize before RandomFlip should trigger a warning."""
pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', scale=(1333, 800)),
dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
dict(type='RandomFlip', prob=0.5),
dict(type='PackDetInputs'),
]
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
validate_pipeline_order(pipeline)
msgs = [str(x.message) for x in w]
assert any('RandomFlip' in m and 'Normalize' in m for m in msgs)

def test_pad_before_resize_warns(self):
"""Pad appearing before Resize is suspicious."""
pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Pad', size_divisor=32),
dict(type='Resize', scale=(1333, 800)),
dict(type='PackDetInputs'),
]
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
validate_pipeline_order(pipeline)
msgs = [str(x.message) for x in w]
assert any('Pad' in m and 'Resize' in m for m in msgs)

def test_empty_pipeline_no_error(self):
"""An empty pipeline should not crash."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
validate_pipeline_order([])
assert len(w) == 0

def test_minimal_pipeline_no_warning(self):
"""A minimal valid pipeline should not warn."""
pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='PackDetInputs'),
]
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
validate_pipeline_order(pipeline)
assert len(w) == 0, [str(x.message) for x in w]

def test_unknown_transforms_ignored(self):
"""Transforms not in the known list should be silently skipped."""
pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='MyCustomTransform'),
dict(type='Resize', scale=(1333, 800)),
dict(type='PackDetInputs'),
]
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
validate_pipeline_order(pipeline)
assert len(w) == 0, [str(x.message) for x in w]

def test_issue_6106_bad_pipeline_warns(self):
"""The problematic pipeline from issue #6106 should trigger
warnings. The user had Normalize before Pad."""
pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=[(800, 200), (800, 700)]),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375], to_rgb=True),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
]
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
validate_pipeline_order(pipeline)
# Should warn about Pad after Normalize (or Normalize before Pad)
assert len(w) > 0, 'Expected warnings for issue #6106 pipeline'
msgs = [str(x.message) for x in w]
assert any('Pad' in m and 'Normalize' in m for m in msgs)