From 50b3b3d9ccda8959eaf32d5729f880f376076736 Mon Sep 17 00:00:00 2001 From: majianhan Date: Wed, 1 Apr 2026 23:15:21 +0800 Subject: [PATCH] [Fix] Merge class lists when concatenating datasets with different types When concatenating datasets with different class sets (e.g. CocoDataset + VOCDataset), ConcatDataset previously stored metainfo as a list of dicts. This broke downstream evaluation code that expects metainfo to be a single dict with a 'classes' key, causing IndexError in _det2json when cat_ids[label] was accessed with out-of-range labels. This fix adds a _merge_metainfo method that merges class lists from all sub-datasets into a unified set, preserving palette colours from each dataset. The merged metainfo dict is used instead of the list fallback, so evaluation works correctly with heterogeneous dataset concatenation. Fixes #8890 --- mmdet/datasets/dataset_wrappers.py | 67 ++++++++++++++++- tests/test_datasets/test_dataset_wrappers.py | 79 ++++++++++++++++++++ 2 files changed, 142 insertions(+), 4 deletions(-) create mode 100644 tests/test_datasets/test_dataset_wrappers.py diff --git a/mmdet/datasets/dataset_wrappers.py b/mmdet/datasets/dataset_wrappers.py index d4e26e07c0f..3d58fa4a5c2 100644 --- a/mmdet/datasets/dataset_wrappers.py +++ b/mmdet/datasets/dataset_wrappers.py @@ -223,8 +223,8 @@ def __init__(self, for dataset in self.datasets: meta_keys |= dataset.metainfo.keys() # if the metainfo of multiple datasets are the same, use metainfo - # of the first dataset, else the metainfo is a list with metainfo - # of all the datasets + # of the first dataset, else try to merge the class lists from + # different datasets into a unified metainfo dict is_all_same = True self._metainfo_first = self.datasets[0].metainfo for i, dataset in enumerate(self.datasets, 1): @@ -241,13 +241,13 @@ def __init__(self, if is_all_same: self._metainfo = self.datasets[0].metainfo else: - self._metainfo = [dataset.metainfo for dataset in self.datasets] + self._metainfo = 
def _merge_metainfo(self) -> Union[dict, List[dict]]:
    """Merge the metainfo of all sub-datasets into a single dict.

    The ``classes`` of every dataset in ``self.datasets`` are unioned in
    first-seen order (all classes of the first dataset, then the not yet
    seen classes of the second, and so on) and a palette aligned with
    that order is built. A class takes the colour from the first dataset
    whose ``palette`` has exactly one entry per class; a class without
    such a colour gets a deterministic colour derived from its index in
    the merged class list.

    Note:
        Only metadata is merged here. The label indices stored in the
        sub-datasets are not read or rewritten by this method, so they
        keep referring to each dataset's own class order.
        NOTE(review): confirm that callers remap labels to the merged
        class order before evaluation.

    Returns:
        Union[dict, List[dict]]: A deep copy of the first dataset's
        metainfo whose ``classes`` (tuple) and ``palette`` (list) are
        replaced by the merged versions. If any dataset has no
        ``classes`` entry, or its ``classes`` is ``None``, the
        per-dataset metainfo dicts are returned as a list instead
        (original fallback behaviour).
    """
    # Read each dataset's ``metainfo`` exactly once. The attribute may be
    # a property that builds a new dict on every access, so one snapshot
    # avoids repeated work and guarantees that the ``classes`` and
    # ``palette`` of a dataset come from the same dict.
    metainfos: List[dict] = [dataset.metainfo for dataset in self.datasets]

    # A class list is required from every dataset. ``classes=None`` is
    # treated like a missing key so it takes the list fallback instead
    # of failing when the classes are iterated below.
    if any(meta.get('classes') is None for meta in metainfos):
        return metainfos

    # A dict serves as an ordered set: ``update`` appends unseen class
    # names and leaves the position of already known ones untouched.
    class_order: dict = {}
    for meta in metainfos:
        class_order.update(dict.fromkeys(meta['classes']))
    merged_classes = tuple(class_order)

    # First dataset that defines a colour for a class wins. A palette
    # whose length differs from its class list cannot be aligned with
    # the classes and is ignored.
    cls_to_palette: dict = {}
    for meta in metainfos:
        palette = meta.get('palette')
        if palette is None or len(palette) != len(meta['classes']):
            continue
        for cls_name, colour in zip(meta['classes'], palette):
            cls_to_palette.setdefault(cls_name, colour)

    merged_palette: list = []
    for i, cls_name in enumerate(merged_classes):
        if cls_name in cls_to_palette:
            merged_palette.append(cls_to_palette[cls_name])
        else:
            # deterministic fallback colour derived from the class index
            merged_palette.append(
                ((i * 97 + 23) % 256,
                 (i * 53 + 89) % 256,
                 (i * 131 + 47) % 256))

    # Start from the first dataset's metainfo and override the class
    # related fields; the deep copy keeps the source dict unmodified.
    merged = copy.deepcopy(metainfos[0])
    merged['classes'] = merged_classes
    merged['palette'] = merged_palette
    return merged
import unittest
from unittest.mock import MagicMock

from mmdet.datasets.dataset_wrappers import ConcatDataset


def _make_mock_dataset(classes, palette=None):
    """Create a mock dataset with the given classes and palette.

    Args:
        classes: Iterable of class names; stored as a tuple under the
            ``classes`` key of the mock's ``metainfo`` dict.
        palette: Optional iterable of colours; stored as a list under
            the ``palette`` key. The key is omitted when ``None``.

    Returns:
        MagicMock: Mock whose ``metainfo`` attribute is a plain dict and
        whose ``__len__`` returns 5.
    """
    ds = MagicMock()
    meta = {'classes': tuple(classes)}
    if palette is not None:
        meta['palette'] = list(palette)
    ds.metainfo = meta
    ds.__len__ = MagicMock(return_value=5)
    return ds


def _make_bare_concat(*datasets):
    """Build a ConcatDataset around ``datasets`` without running __init__.

    ``__new__`` is used so that no real sub-dataset has to be loaded; only
    the attributes assigned here exist on the returned instance.
    """
    concat = ConcatDataset.__new__(ConcatDataset)
    concat.datasets = list(datasets)
    concat.ignore_keys = []
    return concat


class TestConcatDataset(unittest.TestCase):
    """Unit tests for ``ConcatDataset._merge_metainfo``."""

    def test_merge_metainfo_same_classes(self):
        """Datasets with identical classes should keep a single metainfo."""
        ds1 = _make_mock_dataset(['cat', 'dog'])
        ds2 = _make_mock_dataset(['cat', 'dog'])
        merged = _make_bare_concat(ds1, ds2)._merge_metainfo()
        # Same classes: the merged class tuple is unchanged and the
        # palette carries one entry per class.
        self.assertIsInstance(merged, dict)
        self.assertEqual(merged['classes'], ('cat', 'dog'))
        self.assertEqual(len(merged['palette']), 2)

    def test_merge_metainfo_different_classes(self):
        """Datasets with different classes should produce a merged set."""
        ds1 = _make_mock_dataset(
            ['cat', 'dog'],
            palette=[(255, 0, 0), (0, 255, 0)])
        ds2 = _make_mock_dataset(
            ['dog', 'bird'],
            palette=[(0, 255, 0), (0, 0, 255)])
        merged = _make_bare_concat(ds1, ds2)._merge_metainfo()
        self.assertIsInstance(merged, dict)
        # 'cat' from ds1, 'dog' shared, 'bird' from ds2
        self.assertEqual(merged['classes'], ('cat', 'dog', 'bird'))
        # palette should carry over from each dataset
        self.assertEqual(merged['palette'][0], (255, 0, 0))  # cat from ds1
        self.assertEqual(merged['palette'][1], (0, 255, 0))  # dog from ds1
        self.assertEqual(merged['palette'][2], (0, 0, 255))  # bird from ds2

    def test_merge_metainfo_no_overlap(self):
        """Datasets with completely disjoint classes."""
        ds1 = _make_mock_dataset(['a', 'b'])
        ds2 = _make_mock_dataset(['c', 'd'])
        merged = _make_bare_concat(ds1, ds2)._merge_metainfo()
        self.assertIsInstance(merged, dict)
        self.assertEqual(merged['classes'], ('a', 'b', 'c', 'd'))

    def test_merge_metainfo_fallback_palette(self):
        """Classes without any palette entry get a generated colour."""
        ds1 = _make_mock_dataset(['a', 'b'])
        ds2 = _make_mock_dataset(['c', 'd'])
        concat = _make_bare_concat(ds1, ds2)
        merged = concat._merge_metainfo()
        self.assertEqual(len(merged['palette']), len(merged['classes']))
        for colour in merged['palette']:
            self.assertEqual(len(colour), 3)
            for channel in colour:
                self.assertIn(channel, range(256))
        # The fallback must be deterministic across calls.
        self.assertEqual(merged['palette'],
                         concat._merge_metainfo()['palette'])

    def test_merge_metainfo_palette_length_mismatch(self):
        """A palette that does not match its class list is ignored."""
        ds1 = _make_mock_dataset(['cat', 'dog'], palette=[(255, 0, 0)])
        ds2 = _make_mock_dataset(['dog'], palette=[(0, 255, 0)])
        merged = _make_bare_concat(ds1, ds2)._merge_metainfo()
        self.assertEqual(merged['classes'], ('cat', 'dog'))
        # ds1's one-entry palette cannot be aligned with two classes,
        # so 'dog' takes its colour from ds2.
        self.assertEqual(merged['palette'][1], (0, 255, 0))
        self.assertEqual(len(merged['palette']), 2)

    def test_merge_metainfo_keeps_sources_unchanged(self):
        """Merging must not modify the first dataset's metainfo."""
        ds1 = _make_mock_dataset(['cat'], palette=[(255, 0, 0)])
        ds2 = _make_mock_dataset(['dog'], palette=[(0, 255, 0)])
        _make_bare_concat(ds1, ds2)._merge_metainfo()
        self.assertEqual(
            ds1.metainfo,
            {'classes': ('cat', ), 'palette': [(255, 0, 0)]})

    def test_merge_metainfo_missing_classes_key(self):
        """If a dataset has no 'classes' key, fall back to list of dicts."""
        ds1 = _make_mock_dataset(['cat'])
        ds2 = MagicMock()
        ds2.metainfo = {'custom_key': 'value'}
        merged = _make_bare_concat(ds1, ds2)._merge_metainfo()
        # Should fall back to list
        self.assertIsInstance(merged, list)
        self.assertEqual(len(merged), 2)


if __name__ == '__main__':
    unittest.main()