Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 63 additions & 4 deletions mmdet/datasets/dataset_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ def __init__(self,
for dataset in self.datasets:
meta_keys |= dataset.metainfo.keys()
# if the metainfo of multiple datasets are the same, use metainfo
# of the first dataset, else the metainfo is a list with metainfo
# of all the datasets
# of the first dataset, else try to merge the class lists from
# different datasets into a unified metainfo dict
is_all_same = True
self._metainfo_first = self.datasets[0].metainfo
for i, dataset in enumerate(self.datasets, 1):
Expand All @@ -241,20 +241,79 @@ def __init__(self,
if is_all_same:
self._metainfo = self.datasets[0].metainfo
else:
self._metainfo = [dataset.metainfo for dataset in self.datasets]
self._metainfo = self._merge_metainfo()

self._fully_initialized = False
if not lazy_init:
self.full_init()

if is_all_same:
if isinstance(self._metainfo, dict):
self._metainfo.update(
dict(cumulative_sizes=self.cumulative_sizes))
else:
for i, dataset in enumerate(self.datasets):
self._metainfo[i].update(
dict(cumulative_sizes=self.cumulative_sizes))

def _merge_metainfo(self) -> Union[dict, List[dict]]:
"""Merge metainfo from all sub-datasets.

When datasets have different ``classes`` (e.g. concatenating a
CocoDataset with a VOCDataset), merge the class lists into a
unified set so that downstream evaluation receives a single
consistent metainfo dict instead of a list.

Returns:
Union[dict, List[dict]]: Merged metainfo dict when class
merging succeeds, otherwise a list of per-dataset
metainfo dicts (original fallback behaviour).
"""
# Check whether all datasets carry a 'classes' key
all_have_classes = all(
'classes' in ds.metainfo for ds in self.datasets)
if not all_have_classes:
return [dataset.metainfo for dataset in self.datasets]

# Build the merged class list preserving insertion order
merged_classes: List[str] = []
seen: set = set()
for dataset in self.datasets:
for cls_name in dataset.metainfo['classes']:
if cls_name not in seen:
merged_classes.append(cls_name)
seen.add(cls_name)

# Build a merged palette that matches the merged class order.
# Re-use palette colours from whichever dataset already defines
# a colour for a given class; generate a deterministic fallback
# colour for any class that has no palette entry.
cls_to_palette: dict = {}
for dataset in self.datasets:
ds_classes = dataset.metainfo.get('classes', ())
ds_palette = dataset.metainfo.get('palette', None)
if ds_palette is not None and len(ds_palette) == len(ds_classes):
for cls_name, colour in zip(ds_classes, ds_palette):
if cls_name not in cls_to_palette:
cls_to_palette[cls_name] = colour

merged_palette: List[tuple] = []
for i, cls_name in enumerate(merged_classes):
if cls_name in cls_to_palette:
merged_palette.append(cls_to_palette[cls_name])
else:
# deterministic fallback colour derived from class index
merged_palette.append(
((i * 97 + 23) % 256,
(i * 53 + 89) % 256,
(i * 131 + 47) % 256))

# Start from the first dataset's metainfo and override class
# related fields with the merged versions.
merged = copy.deepcopy(self.datasets[0].metainfo)
merged['classes'] = tuple(merged_classes)
merged['palette'] = merged_palette
return merged

def get_dataset_source(self, idx: int) -> int:
    """Return the index of the sub-dataset that sample ``idx`` comes from."""
    source_idx, _ = self._get_ori_dataset_idx(idx)
    return source_idx
79 changes: 79 additions & 0 deletions tests/test_datasets/test_dataset_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
from unittest.mock import MagicMock

from mmdet.datasets.dataset_wrappers import ConcatDataset


def _make_mock_dataset(classes, palette=None):
"""Create a mock dataset with the given classes and palette."""
ds = MagicMock()
meta = {'classes': tuple(classes)}
if palette is not None:
meta['palette'] = list(palette)
ds.metainfo = meta
ds.__len__ = MagicMock(return_value=5)
return ds


class TestConcatDataset(unittest.TestCase):
    """Tests for ``ConcatDataset._merge_metainfo`` behaviour."""

    @staticmethod
    def _build_concat(*datasets):
        """Create a ConcatDataset shell, bypassing the real ``__init__``."""
        wrapper = ConcatDataset.__new__(ConcatDataset)
        wrapper.datasets = list(datasets)
        wrapper.ignore_keys = []
        return wrapper

    def test_merge_metainfo_same_classes(self):
        """Identical class lists across datasets keep a single metainfo."""
        wrapper = self._build_concat(
            _make_mock_dataset(['cat', 'dog']),
            _make_mock_dataset(['cat', 'dog']))
        merged = wrapper._merge_metainfo()
        # Same classes: merging yields one dict with the shared classes.
        self.assertIsInstance(merged, dict)
        self.assertEqual(merged['classes'], ('cat', 'dog'))

    def test_merge_metainfo_different_classes(self):
        """Overlapping-but-different class lists are merged in order."""
        wrapper = self._build_concat(
            _make_mock_dataset(
                ['cat', 'dog'],
                palette=[(255, 0, 0), (0, 255, 0)]),
            _make_mock_dataset(
                ['dog', 'bird'],
                palette=[(0, 255, 0), (0, 0, 255)]))
        merged = wrapper._merge_metainfo()
        self.assertIsInstance(merged, dict)
        # 'cat' from the first dataset, 'dog' shared, 'bird' from the second.
        self.assertEqual(merged['classes'], ('cat', 'dog', 'bird'))
        # Palette colours carry over from whichever dataset defined them.
        self.assertEqual(merged['palette'][0], (255, 0, 0))  # cat, dataset 1
        self.assertEqual(merged['palette'][1], (0, 255, 0))  # dog, dataset 1
        self.assertEqual(merged['palette'][2], (0, 0, 255))  # bird, dataset 2

    def test_merge_metainfo_no_overlap(self):
        """Disjoint class lists concatenate in dataset order."""
        wrapper = self._build_concat(
            _make_mock_dataset(['a', 'b']),
            _make_mock_dataset(['c', 'd']))
        merged = wrapper._merge_metainfo()
        self.assertIsInstance(merged, dict)
        self.assertEqual(merged['classes'], ('a', 'b', 'c', 'd'))

    def test_merge_metainfo_missing_classes_key(self):
        """A dataset without 'classes' triggers the list-of-dicts fallback."""
        no_classes = MagicMock()
        no_classes.metainfo = {'custom_key': 'value'}
        wrapper = self._build_concat(
            _make_mock_dataset(['cat']), no_classes)
        merged = wrapper._merge_metainfo()
        # Merging is abandoned: one metainfo dict per dataset is returned.
        self.assertIsInstance(merged, list)
        self.assertEqual(len(merged), 2)


# Allow running this test module directly (outside a pytest invocation).
if __name__ == '__main__':
    unittest.main()