-
Notifications
You must be signed in to change notification settings - Fork 9.8k
Expand file tree
/
Copy pathbase_det_dataset.py
More file actions
137 lines (118 loc) · 5.56 KB
/
base_det_dataset.py
File metadata and controls
137 lines (118 loc) · 5.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List, Optional
from mmengine.dataset import BaseDataset
from mmengine.fileio import load
from mmengine.utils import is_abs
from ..registry import DATASETS
from .utils import validate_pipeline_order
@DATASETS.register_module()
class BaseDetDataset(BaseDataset):
    """Base dataset for detection.

    Args:
        seg_map_suffix (str): Stored as ``self.seg_map_suffix``; presumably
            the file suffix of segmentation maps used by subclasses.
            Defaults to '.png'.
        proposal_file (str, optional): Proposals file path. A relative path
            is resolved against ``data_root`` when proposals are loaded.
            Defaults to None.
        file_client_args (dict, optional): Arguments to instantiate the
            corresponding backend in mmdet <= 3.0.0rc6. Deprecated: passing
            anything other than None raises ``RuntimeError``.
            Defaults to None.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
        return_classes (bool): Whether to return class information
            for open vocabulary-based algorithms. Defaults to False.
        caption_prompt (dict, optional): Prompt for captioning. Requires
            ``return_classes=True``. Defaults to None.
        *args, **kwargs: Forwarded to ``mmengine.dataset.BaseDataset``. If
            ``kwargs`` contains ``pipeline``, it is first checked with
            ``validate_pipeline_order``.
    """

    def __init__(self,
                 *args,
                 seg_map_suffix: str = '.png',
                 proposal_file: Optional[str] = None,
                 file_client_args: Optional[dict] = None,
                 backend_args: Optional[dict] = None,
                 return_classes: bool = False,
                 caption_prompt: Optional[dict] = None,
                 **kwargs) -> None:
        # These attributes must exist before ``super().__init__``: with
        # ``lazy_init=False`` the base class calls ``full_init`` during
        # construction, and ``full_init``/``load_proposals`` read them.
        self.seg_map_suffix = seg_map_suffix
        self.proposal_file = proposal_file
        self.backend_args = backend_args
        self.return_classes = return_classes
        self.caption_prompt = caption_prompt
        if self.caption_prompt is not None:
            assert self.return_classes, \
                'return_classes must be True when using caption_prompt'
        if file_client_args is not None:
            raise RuntimeError(
                'The `file_client_args` is deprecated, '
                'please use `backend_args` instead, please refer to '
                'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py'  # noqa: E501
            )
        # Validate pipeline transform order before building the dataset
        # so that users get early warnings about suspicious configurations.
        # NOTE(review): only a keyword ``pipeline`` is checked; one passed
        # positionally through ``*args`` is not validated here.
        pipeline = kwargs.get('pipeline')
        if pipeline is not None:
            validate_pipeline_order(pipeline)
        super().__init__(*args, **kwargs)

    def full_init(self) -> None:
        """Load annotation file and set ``BaseDataset._fully_initialized`` to
        True.

        If ``lazy_init=False``, ``full_init`` will be called during the
        instantiation and ``self._fully_initialized`` will be set to True. If
        ``obj._fully_initialized=False``, the class method decorated by
        ``force_full_init`` will call ``full_init`` automatically.

        Several steps to initialize annotation:

            - load_data_list: Load annotations from annotation file.
            - load_proposals: Load proposals from proposal file, if
              ``self.proposal_file`` is not None.
            - filter data information: Filter annotations according to
              filter_cfg.
            - slice_data: Slice dataset according to ``self._indices``.
            - serialize_data: Serialize ``self.data_list`` if
              ``self.serialize_data`` is True.
        """
        # Idempotent: a second call is a no-op once initialization finished.
        if self._fully_initialized:
            return
        # load data information
        self.data_list = self.load_data_list()
        # Proposals are attached before filtering so that the proposal file
        # can be matched one-to-one against the unfiltered ``data_list``.
        if self.proposal_file is not None:
            self.load_proposals()
        # filter illegal data, such as data that has no annotations.
        self.data_list = self.filter_data()
        # Get subset data according to indices.
        if self._indices is not None:
            self.data_list = self._get_unserialized_subset(self._indices)
        # serialize data_list
        if self.serialize_data:
            self.data_bytes, self.data_address = self._serialize_data()
        self._fully_initialized = True

    def load_proposals(self) -> None:
        """Load proposals from ``self.proposal_file`` into ``self.data_list``.

        A relative ``proposal_file`` is joined with ``self.data_root`` (when
        ``data_root`` is set) and the resolved path is written back to
        ``self.proposal_file``. The loaded object must be indexable by
        ``'<parent_dir>/<image_file_name>'`` for every image and must have
        the same length as ``data_list``. Each entry is stored under
        ``data_info['proposals']``. It is usually a ``dict`` or
        :obj:`InstanceData` containing keys such as:

            - bboxes (np.ndarray): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
            - scores (np.ndarray): Classification scores, has a shape
              (num_instances, ).

        Raises:
            AssertionError: If the number of proposal entries differs from
                the number of data samples.
            KeyError: If an image has no entry in the proposals file.
        """
        # TODO: Add Unit Test after fully support Dump-Proposal Metric
        # Only prefix with data_root when it is set; ``osp.join(None, ...)``
        # would raise TypeError (same rule mmengine uses for ann_file).
        if not is_abs(self.proposal_file) and self.data_root:
            self.proposal_file = osp.join(self.data_root, self.proposal_file)
        proposals_list = load(
            self.proposal_file, backend_args=self.backend_args)
        assert len(self.data_list) == len(proposals_list), (
            f'The number of proposals ({len(proposals_list)}) in '
            f'{self.proposal_file} does not match the number of data '
            f'samples ({len(self.data_list)}).')
        for data_info in self.data_list:
            # The proposals are keyed by "<parent_dir>/<basename>" of the
            # image path, e.g. "val2017/000000000139.jpg".
            img_dir, img_name = osp.split(data_info['img_path'])
            file_name = osp.join(osp.basename(img_dir), img_name)
            data_info['proposals'] = proposals_list[file_name]

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get the category labels of all instances in one image.

        Args:
            idx (int): Index of data.

        Returns:
            List[int]: The ``bbox_label`` of every instance in the image of
            the specified index, in instance order (duplicates are kept).
        """
        instances = self.get_data_info(idx)['instances']
        return [instance['bbox_label'] for instance in instances]