diff --git a/configs/sot/stark/stark-st1_r50_8xb16-500e_got10k-lasot-trackingnet-coco_base.py b/configs/sot/stark/stark-st1_r50_8xb16-500e_got10k-lasot-trackingnet-coco_base.py
index d963d714f..d23754842 100644
--- a/configs/sot/stark/stark-st1_r50_8xb16-500e_got10k-lasot-trackingnet-coco_base.py
+++ b/configs/sot/stark/stark-st1_r50_8xb16-500e_got10k-lasot-trackingnet-coco_base.py
@@ -27,51 +27,44 @@
         out_channels=256,
         kernel_size=1,
         act_cfg=None),
+    encoder=dict(
+        num_layers=6,
+        layer_cfg=dict(  # DetrTransformerEncoderLayer
+            self_attn_cfg=dict(  # MultiheadAttention
+                embed_dims=256,
+                num_heads=8,
+                dropout=0.1,
+                batch_first=True),
+            ffn_cfg=dict(
+                embed_dims=256,
+                feedforward_channels=2048,
+                num_fcs=2,
+                ffn_drop=0.1,
+                act_cfg=dict(type='ReLU', inplace=True)))),
+    decoder=dict(
+        num_layers=6,
+        layer_cfg=dict(  # DetrTransformerDecoderLayer
+            self_attn_cfg=dict(  # MultiheadAttention
+                embed_dims=256,
+                num_heads=8,
+                dropout=0.1,
+                batch_first=True),
+            cross_attn_cfg=dict(  # MultiheadAttention
+                embed_dims=256,
+                num_heads=8,
+                dropout=0.1,
+                batch_first=True),
+            ffn_cfg=dict(
+                embed_dims=256,
+                feedforward_channels=2048,
+                num_fcs=2,
+                ffn_drop=0.1,
+                act_cfg=dict(type='ReLU', inplace=True))),
+        return_intermediate=False),
+    num_queries=1,
+    positional_encoding=dict(num_feats=128, normalize=True),
     head=dict(
         type='StarkHead',
-        num_querys=1,
-        transformer=dict(
-            type='StarkTransformer',
-            encoder=dict(
-                type='mmdet.DetrTransformerEncoder',
-                num_layers=6,
-                transformerlayers=dict(
-                    type='BaseTransformerLayer',
-                    attn_cfgs=[
-                        dict(
-                            type='MultiheadAttention',
-                            embed_dims=256,
-                            num_heads=8,
-                            attn_drop=0.1,
-                            dropout_layer=dict(type='Dropout', drop_prob=0.1))
-                    ],
-                    ffn_cfgs=dict(
-                        feedforward_channels=2048,
-                        embed_dims=256,
-                        ffn_drop=0.1),
-                    operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
-            decoder=dict(
-                type='mmdet.DetrTransformerDecoder',
-                return_intermediate=False,
-                num_layers=6,
-                transformerlayers=dict(
-                    type='BaseTransformerLayer',
-                    attn_cfgs=dict(
-                        type='MultiheadAttention',
-                        embed_dims=256,
-                        num_heads=8,
-                        attn_drop=0.1,
-                        dropout_layer=dict(type='Dropout', drop_prob=0.1)),
-                    ffn_cfgs=dict(
-                        feedforward_channels=2048,
-                        embed_dims=256,
-                        ffn_drop=0.1),
-                    operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
-                                     'ffn', 'norm'))),
-        ),
-        positional_encoding=dict(
-            type='mmdet.SinePositionalEncoding', num_feats=128,
-            normalize=True),
         bbox_head=dict(
             type='CornerPredictorHead',
             inplanes=256,
diff --git a/configs/vis/mask2former/mask2former_r101_8xb2-8e_youtubevis2019.py b/configs/vis/mask2former/mask2former_r101_8xb2-8e_youtubevis2019.py
index 04e7460e8..77c18bde5 100644
--- a/configs/vis/mask2former/mask2former_r101_8xb2-8e_youtubevis2019.py
+++ b/configs/vis/mask2former/mask2former_r101_8xb2-8e_youtubevis2019.py
@@ -7,6 +7,7 @@
             checkpoint='torchvision://resnet101')),
     init_cfg=dict(
         type='Pretrained',
-        checkpoint='https://download.openmmlab.com/mmdetection/v2.0/'
-        'mask2former/mask2former_r101_lsj_8x2_50e_coco/'
-        'mask2former_r101_lsj_8x2_50e_coco_20220426_100250-c50b6fa6.pth'))
+        checkpoint='https://download.openmmlab.com/mmdetection/v3.0/'
+        'mask2former/mask2former_r101_8xb2-lsj-50e_coco-panoptic'
+        '/mask2former_r101_8xb2-lsj-50e_coco-'
+        'panoptic_20220329_225104-c74d4d71.pth'))
diff --git a/configs/vis/mask2former/mask2former_r101_8xb2-8e_youtubevis2021.py b/configs/vis/mask2former/mask2former_r101_8xb2-8e_youtubevis2021.py
index c890adbae..e4133711c 100644
--- a/configs/vis/mask2former/mask2former_r101_8xb2-8e_youtubevis2021.py
+++ b/configs/vis/mask2former/mask2former_r101_8xb2-8e_youtubevis2021.py
@@ -7,6 +7,7 @@
             checkpoint='torchvision://resnet101')),
     init_cfg=dict(
         type='Pretrained',
-        checkpoint='https://download.openmmlab.com/mmdetection/v2.0/'
-        'mask2former/mask2former_r101_lsj_8x2_50e_coco/'
-        'mask2former_r101_lsj_8x2_50e_coco_20220426_100250-c50b6fa6.pth'))
+        checkpoint='https://download.openmmlab.com/mmdetection/v3.0/'
+        'mask2former/mask2former_r101_8xb2-lsj-50e_coco-panoptic'
+        '/mask2former_r101_8xb2-lsj-50e_'
+        'coco-panoptic_20220329_225104-c74d4d71.pth'))
diff --git a/configs/vis/mask2former/mask2former_r50_8xb2-8e_youtubevis2019.py b/configs/vis/mask2former/mask2former_r50_8xb2-8e_youtubevis2019.py
index c503f828b..abc88d926 100644
--- a/configs/vis/mask2former/mask2former_r50_8xb2-8e_youtubevis2019.py
+++ b/configs/vis/mask2former/mask2former_r50_8xb2-8e_youtubevis2019.py
@@ -39,64 +39,47 @@
             num_outs=3,
             norm_cfg=dict(type='GN', num_groups=32),
             act_cfg=dict(type='ReLU'),
-            encoder=dict(
-                type='mmdet.DetrTransformerEncoder',
+            encoder=dict(  # DeformableDetrTransformerEncoder
                 num_layers=6,
-                transformerlayers=dict(
-                    type='BaseTransformerLayer',
-                    attn_cfgs=dict(
-                        type='MultiScaleDeformableAttention',
+                layer_cfg=dict(  # DeformableDetrTransformerEncoderLayer
+                    self_attn_cfg=dict(  # MultiScaleDeformableAttention
                         embed_dims=256,
                         num_heads=8,
                         num_levels=3,
                         num_points=4,
-                        im2col_step=128,
                         dropout=0.0,
-                        batch_first=False,
-                        norm_cfg=None,
-                        init_cfg=None),
-                    ffn_cfgs=dict(
-                        type='FFN',
+                        batch_first=True),
+                    ffn_cfg=dict(
                         embed_dims=256,
                         feedforward_channels=1024,
                         num_fcs=2,
                         ffn_drop=0.0,
-                        act_cfg=dict(type='ReLU', inplace=True)),
-                    operation_order=('self_attn', 'norm', 'ffn', 'norm')),
-                init_cfg=None),
-            positional_encoding=dict(
-                type='mmdet.SinePositionalEncoding',
-                num_feats=128,
-                normalize=True),
+                        act_cfg=dict(type='ReLU', inplace=True)))),
+            positional_encoding=dict(num_feats=128, normalize=True),
             init_cfg=None),
         enforce_decoder_input_project=False,
         positional_encoding=dict(
             type='SinePositionalEncoding3D', num_feats=128, normalize=True),
         transformer_decoder=dict(
-            type='mmdet.DetrTransformerDecoder',
             return_intermediate=True,
             num_layers=9,
-            transformerlayers=dict(
-                type='mmdet.DetrTransformerDecoderLayer',
-                attn_cfgs=dict(
-                    type='MultiheadAttention',
+            layer_cfg=dict(  # Mask2FormerTransformerDecoderLayer
+                self_attn_cfg=dict(  # MultiheadAttention
                     embed_dims=256,
                     num_heads=8,
-                    attn_drop=0.0,
-                    proj_drop=0.0,
-                    dropout_layer=None,
-                    batch_first=False),
-                ffn_cfgs=dict(
+                    dropout=0.0,
+                    batch_first=True),
+                cross_attn_cfg=dict(  # MultiheadAttention
+                    embed_dims=256,
+                    num_heads=8,
+                    dropout=0.0,
+                    batch_first=True),
+                ffn_cfg=dict(
                     embed_dims=256,
                     feedforward_channels=2048,
                     num_fcs=2,
-                    act_cfg=dict(type='ReLU', inplace=True),
                     ffn_drop=0.0,
-                    dropout_layer=None,
-                    add_identity=True),
-                feedforward_channels=2048,
-                operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
-                                 'ffn', 'norm')),
+                    act_cfg=dict(type='ReLU', inplace=True))),
             init_cfg=None),
         loss_cls=dict(
             type='mmdet.CrossEntropyLoss',
@@ -138,9 +121,10 @@
                 sampler=dict(type='mmdet.MaskPseudoSampler'))),
     init_cfg=dict(
         type='Pretrained',
-        checkpoint='https://download.openmmlab.com/mmdetection/v2.0/'
-        'mask2former/mask2former_r50_lsj_8x2_50e_coco/'
-        'mask2former_r50_lsj_8x2_50e_coco_20220506_191028-8e96e88b.pth'))
+        checkpoint='https://download.openmmlab.com/mmdetection/v3.0/'
+        'mask2former/mask2former_r50_8xb2-lsj-50e_coco-panoptic/'
+        'mask2former_r50_8xb2-lsj-50e_'
+        'coco-panoptic_20230114_094547-7add5fa8.pth'))
 # optimizer
 embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
diff --git a/configs/vis/mask2former/mask2former_swin-l-p4-w12-384-in21k_8xb2-8e_youtubevis2021.py b/configs/vis/mask2former/mask2former_swin-l-p4-w12-384-in21k_8xb2-8e_youtubevis2021.py
index d4c70ba56..68b36aca9 100644
--- a/configs/vis/mask2former/mask2former_swin-l-p4-w12-384-in21k_8xb2-8e_youtubevis2021.py
+++ b/configs/vis/mask2former/mask2former_swin-l-p4-w12-384-in21k_8xb2-8e_youtubevis2021.py
@@ -29,10 +29,11 @@
     init_cfg=dict(
         type='Pretrained',
         checkpoint=  # noqa: E251
-        'https://download.openmmlab.com/mmdetection/v2.0/mask2former/'
-        'mask2former_swin-l-p4-w12-384-in21k_lsj_16x1_100e_coco-panoptic/'
-        'mask2former_swin-l-p4-w12-384-in21k_lsj_16x1_100e_coco-panoptic_'
-        '20220407_104949-d4919c44.pth'))
+        'https://download.openmmlab.com/mmdetection/v3.0'
+        '/mask2former/mask2former_swin-l-p4-w12-384-'
+        'in21k_16xb1-lsj-100e_coco-panoptic'
+        '/mask2former_swin-l-p4-w12-384-in21k_16xb1-lsj-100e'
+        '_coco-panoptic_20220407_104949-82f8d28d.pth'))

 # set all layers in backbone to lr_mult=0.1
 # set all norm layers, position_embeding,
diff --git a/mmtrack/models/sot/stark.py b/mmtrack/models/sot/stark.py
index d85753022..a50b064f6 100644
--- a/mmtrack/models/sot/stark.py
+++ b/mmtrack/models/sot/stark.py
@@ -5,8 +5,11 @@
 import torch
 import torch.nn.functional as F
+from mmdet.models.layers import (DetrTransformerDecoder,
+                                 DetrTransformerEncoder,
+                                 SinePositionalEncoding)
 from mmdet.structures.bbox.transforms import bbox_xyxy_to_cxcywh
-from torch import Tensor
+from torch import Tensor, nn
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.modules.conv import _ConvNd

@@ -43,7 +46,11 @@ class Stark(BaseSingleObjectTracker):
     def __init__(self,
                  backbone: dict,
                  neck: Optional[dict] = None,
-                 head: Optional[dict] = None,
+                 encoder: OptConfigType = None,
+                 decoder: OptConfigType = None,
+                 head: OptConfigType = None,
+                 positional_encoding: OptConfigType = None,
+                 num_queries: int = 100,
                  pretrains: Optional[dict] = None,
                  train_cfg: Optional[dict] = None,
                  test_cfg: Optional[dict] = None,
@@ -71,6 +78,29 @@ def __init__(self,
         if frozen_modules is not None:
             self.freeze_module(frozen_modules)

+        self.encoder = encoder
+        self.decoder = decoder
+        self.positional_encoding = positional_encoding
+        self.num_queries = num_queries
+        self._init_layers()
+
+    def _init_layers(self) -> None:
+        """Initialize layers except for backbone, neck and bbox_head."""
+        self.positional_encoding = SinePositionalEncoding(
+            **self.positional_encoding)
+        self.encoder = DetrTransformerEncoder(**self.encoder)
+        self.decoder = DetrTransformerDecoder(**self.decoder)
+        self.embed_dims = self.encoder.embed_dims
+        # NOTE: The embed_dims is typically passed from the inside out.
+        # For example in DETR, the embed_dims is passed as
+        # self_attn -> the first encoder layer -> encoder -> detector.
+        self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)
+
+        num_feats = self.positional_encoding.num_feats
+        assert num_feats * 2 == self.embed_dims, \
+            'embed_dims should be exactly 2 times of num_feats. ' \
+            f'Found {self.embed_dims} and {num_feats}.'
+
     def init_weights(self):
         """Initialize the weights of modules in single object tracker."""
         # We don't use the `init_weights()` function in BaseModule, since it
@@ -87,6 +117,11 @@ def init_weights(self):
         if self.with_head:
             self.head.init_weights()

+        for coder in self.encoder, self.decoder:
+            for p in coder.parameters():
+                if p.dim() > 1:
+                    nn.init.xavier_uniform_(p)
+
     def extract_feat(self, img: Tensor) -> Tensor:
         """Extract the features of the input image.

@@ -300,6 +335,82 @@ def loss(self, inputs: dict, data_samples: List[TrackDataSample],
             x_dict = dict(feat=x_feat, mask=search_padding_mask[:, 0])
             head_inputs.append(x_dict)

-        losses = self.head.loss(head_inputs, data_samples)
+        outs_dec, enc_mem = self.forward_transformer(head_inputs)
+        losses = self.head.loss(head_inputs, outs_dec, enc_mem, data_samples)

         return losses
+
+    def forward_transformer(self, inputs):
+        # 1. preprocess inputs for transformer
+        all_inputs = []
+        for input in inputs:
+            feat = input['feat'][0]
+            feat_size = feat.shape[-2:]
+            mask = F.interpolate(
+                input['mask'][None].float(), size=feat_size).to(torch.bool)[0]
+            pos_embed = self.positional_encoding(mask)
+            all_inputs.append(dict(feat=feat, mask=mask, pos_embed=pos_embed))
+        all_inputs = self.head._merge_template_search(all_inputs)
+
+        # 2. forward transformer head
+        # outs_dec is in (1, bs, num_query, c) shape
+        # enc_mem is in (feats_flatten_len, bs, c) shape
+        outs_dec, enc_mem = self.transformer(
+            all_inputs['feat'].permute(1, 0, 2), all_inputs['mask'],
+            self.query_embedding.weight,
+            all_inputs['pos_embed'].permute(1, 0, 2))
+        return outs_dec, enc_mem
+
+    def transformer(self, x: Tensor, mask: Tensor, query_embed: Tensor,
+                    pos_embed: Tensor) -> Tuple[Tensor, Tensor]:
+        """Forward function for the STARK transformer.
+
+        The difference from the transformer module in `MMCV` is the input
+        shape: the sizes of template feature maps and search feature maps
+        are different, so we must flatten and concatenate them outside this
+        module, whereas `MMCV` flattens the input features inside the
+        transformer module.
+
+        Args:
+            x (Tensor): Input query with shape (bs, feats_flatten_len, c)
+                where c = embed_dims.
+            mask (Tensor): The key_padding_mask used for encoder and decoder,
+                with shape (bs, feats_flatten_len).
+            query_embed (Tensor): The query embedding for decoder, with shape
+                (num_query, c).
+            pos_embed (Tensor): The positional encoding for encoder and
+                decoder, with shape (bs, feats_flatten_len, c).
+
+            Here, 'feats_flatten_len' = z_feat_h*z_feat_w*2 + \
+                x_feat_h*x_feat_w.
+            'z_feat_h' and 'z_feat_w' denote the height and width of the
+            template features respectively.
+            'x_feat_h' and 'x_feat_w' denote the height and width of search
+            features respectively.
+        Returns:
+            tuple[Tensor, Tensor]: results of decoder containing the
+                following tensors.
+                - out_dec: Output from decoder. If return_intermediate_dec \
+                    is True, output has shape [num_dec_layers, bs, \
+                    num_query, embed_dims], else has shape [1, bs, \
+                    num_query, embed_dims]. Here, return_intermediate_dec \
+                    is False.
+                - enc_mem: Output results from encoder, with shape \
+                    (feats_flatten_len, bs, embed_dims).
+ """ + bs, _, _ = x.shape + query_embed = query_embed.unsqueeze(1).repeat( + bs, 1, 1) # [num_query, embed_dims] -> [num_query, bs, embed_dims] + + enc_mem = self.encoder( + query=x, query_pos=pos_embed, key_padding_mask=mask) + target = torch.zeros_like(query_embed) + # out_dec: [num_dec_layers, num_query, bs, embed_dims] + out_dec = self.decoder( + query=target, + key=enc_mem, + value=enc_mem, + key_pos=pos_embed, + query_pos=query_embed, + key_padding_mask=mask) + enc_mem = enc_mem.permute(1, 0, 2) + return out_dec, enc_mem diff --git a/mmtrack/models/track_heads/mask2former_head.py b/mmtrack/models/track_heads/mask2former_head.py index c668a07c4..d20522cb5 100644 --- a/mmtrack/models/track_heads/mask2former_head.py +++ b/mmtrack/models/track_heads/mask2former_head.py @@ -9,6 +9,7 @@ from mmcv.ops import point_sample from mmdet.models.dense_heads import AnchorFreeHead from mmdet.models.dense_heads import MaskFormerHead as MMDET_MaskFormerHead +from mmdet.models.layers import Mask2FormerTransformerDecoder from mmdet.models.utils import get_uncertain_point_coords_with_randomness from mmdet.structures.mask import mask2bbox from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig, reduce_mean @@ -105,10 +106,10 @@ def __init__(self, self.num_queries = num_queries self.num_frames = num_frames self.num_transformer_feat_level = num_transformer_feat_level - self.num_heads = transformer_decoder.transformerlayers. \ - attn_cfgs.num_heads + self.num_heads = transformer_decoder.layer_cfg. \ + self_attn_cfg.num_heads self.num_transformer_decoder_layers = transformer_decoder.num_layers - assert pixel_decoder.encoder.transformerlayers.attn_cfgs.num_levels \ + assert pixel_decoder.encoder.layer_cfg.self_attn_cfg.num_levels \ == num_transformer_feat_level pixel_decoder_ = copy.deepcopy(pixel_decoder) pixel_decoder_.update( @@ -116,7 +117,8 @@ def __init__(self, feat_channels=feat_channels, out_channels=out_channels) self.pixel_decoder = MODELS.build(pixel_decoder_) - self.transformer_decoder = MODELS.build(transformer_decoder) + self.transformer_decoder = Mask2FormerTransformerDecoder( + **transformer_decoder) self.decoder_embed_dims = self.transformer_decoder.embed_dims self.decoder_input_projs = ModuleList() @@ -429,7 +431,7 @@ def _forward_head( """Forward for head part which is called after every decoder layer. Args: - decoder_out (Tensor): in shape (num_queries, batch_size, c). + decoder_out (Tensor): in shape (batch_size, num_queries, c). mask_feature (Tensor): in shape (batch_size, t, c, h, w). attn_mask_target_size (tuple[int, int]): target attention mask size. @@ -446,7 +448,6 @@ def _forward_head( (batch_size * num_heads, num_queries, h, w). 
""" decoder_out = self.transformer_decoder.post_norm(decoder_out) - decoder_out = decoder_out.transpose(0, 1) # shape (batch_size, num_queries, c) cls_pred = self.cls_embed(decoder_out) # shape (batch_size, num_queries, c) @@ -454,7 +455,6 @@ def _forward_head( # shape (batch_size, num_queries, t, h, w) mask_pred = torch.einsum('bqc,btchw->bqthw', mask_embed, mask_feature) b, q, t, _, _ = mask_pred.shape - attn_mask = F.interpolate( mask_pred.flatten(0, 1), attn_mask_target_size, @@ -496,6 +496,7 @@ def forward(self, x: List[Tensor], """ mask_features, multi_scale_memorys = self.pixel_decoder(x) bt, c_m, h_m, w_m = mask_features.shape + batch_size = bt // self.num_frames if self.training else 1 t = bt // batch_size mask_features = mask_features.view(batch_size, t, c_m, h_m, w_m) @@ -524,11 +525,12 @@ def forward(self, x: List[Tensor], 3).permute(1, 3, 0, 2).flatten(0, 1) decoder_inputs.append(decoder_input) decoder_positional_encodings.append(decoder_positional_encoding) - # shape (num_queries, c) -> (num_queries, batch_size, c) - query_feat = self.query_feat.weight.unsqueeze(1).repeat( - (1, batch_size, 1)) - query_embed = self.query_embed.weight.unsqueeze(1).repeat( - (1, batch_size, 1)) + + # shape (num_queries, c) -> (batch_size, num_queries, c) + query_feat = self.query_feat.weight.unsqueeze(0).repeat( + (batch_size, 1, 1)) + query_embed = self.query_embed.weight.unsqueeze(0).repeat( + (batch_size, 1, 1)) cls_pred_list = [] mask_pred_list = [] @@ -542,17 +544,16 @@ def forward(self, x: List[Tensor], # if a mask is all True(all background), then set it all False. attn_mask[torch.where( attn_mask.sum(-1) == attn_mask.shape[-1])] = False - # cross_attn + self_attn layer = self.transformer_decoder.layers[i] - attn_masks = [attn_mask, None] query_feat = layer( query=query_feat, - key=decoder_inputs[level_idx], - value=decoder_inputs[level_idx], + key=decoder_inputs[level_idx].permute(1, 0, 2), + value=decoder_inputs[level_idx].permute(1, 0, 2), query_pos=query_embed, - key_pos=decoder_positional_encodings[level_idx], - attn_masks=attn_masks, + key_pos=decoder_positional_encodings[level_idx].permute( + 1, 0, 2), + cross_attn_mask=attn_mask, query_key_padding_mask=None, # here we do not apply masking on padded region key_padding_mask=None) diff --git a/mmtrack/models/track_heads/stark_head.py b/mmtrack/models/track_heads/stark_head.py index 3dbf1adce..1750982b7 100644 --- a/mmtrack/models/track_heads/stark_head.py +++ b/mmtrack/models/track_heads/stark_head.py @@ -5,14 +5,12 @@ import torch import torch.nn.functional as F from mmcv.cnn.bricks import ConvModule -from mmcv.cnn.bricks.transformer import build_positional_encoding -from mmdet.models.layers import Transformer from mmengine.model import BaseModule from mmengine.structures import InstanceData from torch import Tensor, nn from mmtrack.registry import MODELS -from mmtrack.utils import InstanceList, OptConfigType, SampleList +from mmtrack.utils import InstanceList, SampleList @MODELS.register_module() @@ -176,93 +174,6 @@ def forward(self, x: Tensor) -> Tensor: return x.view(-1, 1) -@MODELS.register_module() -class StarkTransformer(Transformer): - """The transformer head used in STARK. `STARK. - - `_. - - This module follows the official DETR implementation. - See `paper: End-to-End Object Detection with Transformers - `_ for details. - - Args: - encoder (`mmengine.ConfigDict` | Dict): Config of - TransformerEncoder. Defaults to None. - decoder ((`mmengine.ConfigDict` | Dict)): Config of - TransformerDecoder. 
-            TransformerDecoder. Defaults to None.
-        init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization.
-            Defaults to None.
-    """
-
-    def __init__(
-        self,
-        encoder: OptConfigType = None,
-        decoder: OptConfigType = None,
-        init_cfg: OptConfigType = None,
-    ):
-        super(StarkTransformer, self).__init__(
-            encoder=encoder, decoder=decoder, init_cfg=init_cfg)
-
-    def forward(self, x: Tensor, mask: Tensor, query_embed: Tensor,
-                pos_embed: Tensor) -> Tuple[Tensor, Tensor]:
-        """Forward function for `StarkTransformer`.
-
-        The difference with transofrmer module in `MMCV` is the input shape.
-        The sizes of template feature maps and search feature maps are
-        different. Thus, we must flatten and concatenate them outside this
-        module. The `MMCV` flatten the input features inside tranformer module.
-
-        Args:
-            x (Tensor): Input query with shape (feats_flatten_len, bs, c)
-                where c = embed_dims.
-            mask (Tensor): The key_padding_mask used for encoder and decoder,
-                with shape (bs, feats_flatten_len).
-            query_embed (Tensor): The query embedding for decoder, with shape
-                (num_query, c).
-            pos_embed (Tensor): The positional encoding for encoder and
-                decoder, with shape (feats_flatten_len, bs, c).
-
-            Here, 'feats_flatten_len' = z_feat_h*z_feat_w*2 + \
-                x_feat_h*x_feat_w.
-            'z_feat_h' and 'z_feat_w' denote the height and width of the
-            template features respectively.
-            'x_feat_h' and 'x_feat_w' denote the height and width of search
-            features respectively.
-        Returns:
-            tuple[Tensor, Tensor]: results of decoder containing the following
-                tensor.
-                - out_dec: Output from decoder. If return_intermediate_dec \
-                    is True, output has shape [num_dec_layers, bs,
-                    num_query, embed_dims], else has shape [1, bs, \
-                    num_query, embed_dims].
-                    Here, return_intermediate_dec=False
-                - enc_mem: Output results from encoder, with shape \
-                    (feats_flatten_len, bs, embed_dims).
-        """
-        _, bs, _ = x.shape
-        query_embed = query_embed.unsqueeze(1).repeat(
-            1, bs, 1)  # [num_query, embed_dims] -> [num_query, bs, embed_dims]
-
-        enc_mem = self.encoder(
-            query=x,
-            key=None,
-            value=None,
-            query_pos=pos_embed,
-            query_key_padding_mask=mask)
-        target = torch.zeros_like(query_embed)
-        # out_dec: [num_dec_layers, num_query, bs, embed_dims]
-        out_dec = self.decoder(
-            query=target,
-            key=enc_mem,
-            value=enc_mem,
-            key_pos=pos_embed,
-            query_pos=query_embed,
-            key_padding_mask=mask)
-        out_dec = out_dec.transpose(1, 2)
-        return out_dec, enc_mem
-
-
 @MODELS.register_module()
 class StarkHead(BaseModule):
     """STARK head module for bounding box regression and prediction of
@@ -273,11 +184,6 @@ class StarkHead(BaseModule):
     `STARK <https://arxiv.org/abs/2103.17154>`_.

     Args:
-        num_query (int): Number of query in transformer.
-        transformer (obj:`mmengine.ConfigDict`|dict): Config for transformer.
-            Default: None.
-        positional_encoding (obj:`mmengine.ConfigDict`|dict):
-            Config for position encoding.
         bbox_head (obj:`mmengine.ConfigDict`|dict, optional): Config for
             bbox head. Defaults to None.
         cls_head (obj:`mmengine.ConfigDict`|dict, optional): Config for
@@ -297,12 +203,6 @@ class StarkHead(BaseModule):
     """

     def __init__(self,
-                 num_query=1,
-                 transformer=None,
-                 positional_encoding=dict(
-                     type='SinePositionalEncoding',
-                     num_feats=128,
-                     normalize=True),
                  bbox_head=None,
                  cls_head=None,
                  loss_cls=dict(
@@ -318,9 +218,6 @@ def __init__(self,
                  frozen_modules=None,
                  **kwargs):
         super(StarkHead, self).__init__(init_cfg=init_cfg)
-        self.transformer = MODELS.build(transformer)
-        self.positional_encoding = build_positional_encoding(
-            positional_encoding)
         assert bbox_head is not None
         self.bbox_head = MODELS.build(bbox_head)
         if cls_head is None:
@@ -332,10 +229,6 @@ def __init__(self,
             # the stage-2 training
             self.cls_head = MODELS.build(cls_head)
             self.loss_cls = MODELS.build(loss_cls)
-        self.embed_dims = self.transformer.embed_dims
-        self.num_query = num_query
-        self.query_embedding = nn.Embedding(self.num_query, self.embed_dims)
-
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
         self.fp16_enabled = False
@@ -350,10 +243,6 @@ def __init__(self,
                 for param in m.parameters():
                     param.requires_grad = False

-    def init_weights(self):
-        """Parameters initialization."""
-        self.transformer.init_weights()
-
     def _merge_template_search(self, inputs: List[Dict[str, Tensor]]) -> dict:
         """Merge the data of template and search images.
         The merge includes 3 steps: flatten, premute and concatenate.
@@ -433,7 +322,7 @@ def forward_bbox_head(self, feat: Tensor, enc_mem: Tensor) -> Tensor:
         outputs_coord = self.bbox_head(bbox_feat)
         return outputs_coord

-    def forward(self, inputs: List[dict]) -> dict:
+    def forward(self, outs_dec, enc_mem) -> dict:
         """"
         Args:
             inputs (list[dict(tuple(Tensor))]): The list contains the
@@ -451,25 +340,6 @@ def forward(self, inputs: List[dict]) -> dict:
                 [tl_x, tl_y, br_x, br_y] format
                 - 'pred_logit': (Tensor) of shape (bs * num_query, 1)
         """
-        # 1. preprocess inputs for transformer
-        all_inputs = []
-        for input in inputs:
-            feat = input['feat'][0]
-            feat_size = feat.shape[-2:]
-            mask = F.interpolate(
-                input['mask'][None].float(), size=feat_size).to(torch.bool)[0]
-            pos_embed = self.positional_encoding(mask)
-            all_inputs.append(dict(feat=feat, mask=mask, pos_embed=pos_embed))
-        all_inputs = self._merge_template_search(all_inputs)
-
-        # 2. forward transformer head
-        # outs_dec is in (1, bs, num_query, c) shape
-        # enc_mem is in (feats_flatten_len, bs, c) shape
-        outs_dec, enc_mem = self.transformer(all_inputs['feat'],
-                                             all_inputs['mask'],
-                                             self.query_embedding.weight,
-                                             all_inputs['pos_embed'])
-
         # 3. forward bbox head and classification head
         if not self.training:
             pred_logits = None
@@ -652,7 +522,7 @@ def _bbox_clip(self,

     # TODO: unify the `sefl.predict`, `self.loss` and so on in all the heads of
     # SOT.
-    def loss(self, inputs: List[dict], data_samples: SampleList,
+    def loss(self, inputs, outs_dec, enc_mem, data_samples: SampleList,
              **kwargs) -> dict:
         """Compute loss.
@@ -671,7 +541,7 @@ def loss(self, inputs: List[dict], data_samples: SampleList,
         Returns:
             dict[str, Tensor]: a dictionary of loss components.
         """
-        outs = self(inputs)
+        outs = self(outs_dec, enc_mem)

         batch_gt_instances = []
         batch_img_metas = []
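
Note: the config rewrites above all follow the same mmdet 3.x pattern: the registry-built `transformerlayers` / `attn_cfgs` / `operation_order` dicts are gone, and each encoder/decoder class now consumes a typed `layer_cfg` directly, with batch-first attention. A minimal sketch of the mapping, assuming mmdet>=3.0 is installed; the variable names are illustrative, not part of the patch:

# Sketch, assuming mmdet>=3.0. Shows how an old-style encoder config maps
# onto the layer_cfg style used in the patch above.
from mmdet.models.layers import DetrTransformerEncoder

# mmdet 2.x style: explicit layer type, attn_cfgs and operation_order.
old_encoder_cfg = dict(
    type='mmdet.DetrTransformerEncoder',
    num_layers=6,
    transformerlayers=dict(
        type='BaseTransformerLayer',
        attn_cfgs=[
            dict(
                type='MultiheadAttention',
                embed_dims=256,
                num_heads=8,
                attn_drop=0.1,
                dropout_layer=dict(type='Dropout', drop_prob=0.1))
        ],
        ffn_cfgs=dict(
            embed_dims=256, feedforward_channels=2048, ffn_drop=0.1),
        operation_order=('self_attn', 'norm', 'ffn', 'norm')))

# mmdet 3.x style: the layer class is fixed by the encoder class, so only
# hyper-parameters remain, and attention runs batch-first.
new_encoder_cfg = dict(
    num_layers=6,
    layer_cfg=dict(
        self_attn_cfg=dict(
            embed_dims=256, num_heads=8, dropout=0.1, batch_first=True),
        ffn_cfg=dict(
            embed_dims=256,
            feedforward_channels=2048,
            num_fcs=2,
            ffn_drop=0.1,
            act_cfg=dict(type='ReLU', inplace=True))))

encoder = DetrTransformerEncoder(**new_encoder_cfg)
print(encoder.embed_dims)  # 256, propagated up from self_attn_cfg

This inside-out propagation of `embed_dims` is also what `Stark._init_layers` relies on, together with the `num_feats * 2 == embed_dims` assertion tying the positional encoding (num_feats=128) to the 256-dim transformer.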
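
Note: `Stark.transformer()` is now batch-first end to end, which is why `forward_transformer` permutes `feat` and `pos_embed` to (bs, feats_flatten_len, c) before the call. A shape walkthrough wiring the same encoder/decoder pair, again assuming mmdet>=3.0; the sizes (two 8x8 templates plus one 16x16 search region) are illustrative:

# Sketch, assuming mmdet>=3.0; sizes are illustrative, not from the patch.
import torch
from mmdet.models.layers import DetrTransformerDecoder, DetrTransformerEncoder

common = dict(
    self_attn_cfg=dict(embed_dims=256, num_heads=8, batch_first=True),
    ffn_cfg=dict(embed_dims=256, feedforward_channels=2048))
encoder = DetrTransformerEncoder(num_layers=1, layer_cfg=common)
decoder = DetrTransformerDecoder(
    num_layers=1,
    layer_cfg=dict(
        cross_attn_cfg=dict(embed_dims=256, num_heads=8, batch_first=True),
        **common),
    return_intermediate=False)

bs, c = 2, 256
seq = 8 * 8 * 2 + 16 * 16  # z_feat_h*z_feat_w*2 + x_feat_h*x_feat_w = 384
x = torch.randn(bs, seq, c)                    # flattened template + search
pos_embed = torch.randn(bs, seq, c)
mask = torch.zeros(bs, seq, dtype=torch.bool)  # no padding
query_embed = torch.randn(1, c)                # num_queries=1 in STARK

# Mirrors Stark.transformer():
query_pos = query_embed.unsqueeze(0).repeat(bs, 1, 1)  # (bs, 1, c)
enc_mem = encoder(query=x, query_pos=pos_embed, key_padding_mask=mask)
out_dec = decoder(
    query=torch.zeros_like(query_pos),
    key=enc_mem,
    value=enc_mem,
    query_pos=query_pos,
    key_pos=pos_embed,
    key_padding_mask=mask)
print(enc_mem.shape)  # (bs, seq, c): permuted back before forward_bbox_head
print(out_dec.shape)  # (num_dec_layers, bs, nq, c); leading dim is 1 here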
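
Note: on the Mask2Former head, the patch keeps the guard that resets all-True `attn_mask` rows to all False before calling the decoder layer. The reason is that a fully-masked query attends to no keys, so the attention softmax sees -inf everywhere and returns NaN. A small self-contained demonstration in plain PyTorch (sizes illustrative):

# Sketch: why all-True attention mask rows must be reset to False.
import torch
from torch import nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=1, batch_first=True)
query = torch.randn(1, 2, 8)   # 2 queries
key = torch.randn(1, 4, 8)     # 4 keys

attn_mask = torch.zeros(2, 4, dtype=torch.bool)
attn_mask[0] = True            # first query: every key masked out
out, _ = mha(query, key, key, attn_mask=attn_mask)
print(out[0, 0])               # NaNs for the fully-masked query

# The guard used in the patch above:
attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
out, _ = mha(query, key, key, attn_mask=attn_mask)
print(out[0, 0])               # finite values again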