import warnings

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
                                         TransformerLayerSequence)
from mmengine.model import BaseModule
from mmengine.model.weight_init import xavier_init

from mmdet3d.registry import MODELS, TASK_UTILS


@MODELS.register_module()
class PETRTransformer(BaseModule):
    """Implements the DETR transformer.

    Following the official DETR implementation, this module is copy-pasted
    from torch.nn.Transformer with modifications:

    * positional encodings are passed in MultiheadAttention
    * extra LN at the end of encoder is removed
    * decoder returns a stack of activations from all decoding layers

    See `paper: End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        encoder (`mmcv.ConfigDict` | dict): Config of
            TransformerEncoder. Defaults to None.
        decoder (`mmcv.ConfigDict` | dict): Config of
            TransformerDecoder. Defaults to None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
        cross (bool): Defaults to False.
    """

    def __init__(self, encoder=None, decoder=None, init_cfg=None, cross=False):
        super(PETRTransformer, self).__init__(init_cfg=init_cfg)
        if encoder is not None:
            self.encoder = MODELS.build(encoder)
        else:
            self.encoder = None
        self.decoder = MODELS.build(decoder)
        self.embed_dims = self.decoder.embed_dims
        self.cross = cross

    def init_weights(self):
        # initialize every multi-dimensional parameter with Xavier uniform
        for m in self.modules():
            if hasattr(m, 'weight') and m.weight.dim() > 1:
                xavier_init(m, distribution='uniform')
        self._is_init = True

    def forward(self, x, mask, query_embed, pos_embed, reg_branch=None):
        """Forward function for `Transformer`.

        Args:
            x (Tensor): Input multi-view features with shape
                [bs, n, c, h, w], where c = embed_dims and n is the
                number of views.
            mask (Tensor): The key_padding_mask used for encoder and decoder,
                with shape [bs, n, h, w].
            query_embed (Tensor): The query embedding for decoder, with shape
                [num_query, c].
            pos_embed (Tensor): The positional encoding for encoder and
                decoder, with the same shape as `x`.
            reg_branch (nn.Module, optional): Regression branch forwarded to
                the decoder. Defaults to None.

        Returns:
            tuple[Tensor]: results of decoder containing the following tensor.

                - out_dec: Output from decoder. If return_intermediate_dec \
                    is True output has shape [num_dec_layers, bs, \
                    num_query, embed_dims], else has shape [1, bs, \
                    num_query, embed_dims].
                - memory: Output results from encoder, with shape \
                    [bs, n, embed_dims, h, w].
        """
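        # Flatten the multi-view feature map and its positional encoding to
        # the sequence-first layout expected by the decoder:
        # [bs, n, c, h, w] -> [n * h * w, bs, c].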
        bs, n, c, h, w = x.shape
        memory = x.permute(1, 3, 4, 0, 2).reshape(-1, bs, c)
        pos_embed = pos_embed.permute(1, 3, 4, 0, 2).reshape(-1, bs, c)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        mask = mask.view(bs, -1)
        target = torch.zeros_like(query_embed)
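        # Decode with zero-initialized targets; the decoder output is
        # typically stacked per decoding layer as
        # [num_dec_layers, num_query, bs, embed_dims].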
        out_dec = self.decoder(
            query=target,
            key=memory,
            value=memory,
            key_pos=pos_embed,
            query_pos=query_embed,
            key_padding_mask=mask,
            reg_branch=reg_branch,
        )
        out_dec = out_dec.transpose(1, 2)
        memory = memory.reshape(n, h, w, bs, c).permute(3, 0, 4, 1, 2)
        return out_dec, memory


@MODELS.register_module()
class PETRDNTransformer(BaseModule):
    """Implements the DETR transformer.

    Following the official DETR implementation, this module is copy-pasted
    from torch.nn.Transformer with modifications:

    * positional encodings are passed in MultiheadAttention
    * extra LN at the end of encoder is removed
    * decoder returns a stack of activations from all decoding layers

    See `paper: End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        encoder (`mmcv.ConfigDict` | dict): Config of
            TransformerEncoder. Defaults to None.
        decoder (`mmcv.ConfigDict` | dict): Config of
            TransformerDecoder. Defaults to None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
        cross (bool): Defaults to False.
    """

    def __init__(self, encoder=None, decoder=None, init_cfg=None, cross=False):
        super(PETRDNTransformer, self).__init__(init_cfg=init_cfg)
        if encoder is not None:
            self.encoder = MODELS.build(encoder)
        else:
            self.encoder = None
        self.decoder = MODELS.build(decoder)
        self.embed_dims = self.decoder.embed_dims
        self.cross = cross

    def init_weights(self):
        # initialize every multi-dimensional parameter with Xavier uniform
        for m in self.modules():
            if hasattr(m, 'weight') and m.weight.dim() > 1:
                xavier_init(m, distribution='uniform')
        self._is_init = True

    def forward(self,
                x,
                mask,
                query_embed,
                pos_embed,
                attn_masks=None,
                reg_branch=None):
        """Forward function for `Transformer`.

        Args:
            x (Tensor): Input multi-view features with shape
                [bs, n, c, h, w], where c = embed_dims and n is the
                number of views.
            mask (Tensor): The key_padding_mask used for encoder and decoder,
                with shape [bs, n, h, w].
            query_embed (Tensor): The query embedding for decoder, with shape
                [bs, num_query, c].
            pos_embed (Tensor): The positional encoding for encoder and
                decoder, with the same shape as `x`.
            attn_masks (Tensor, optional): Attention mask applied to the
                decoder self-attention. Defaults to None.
            reg_branch (nn.Module, optional): Regression branch forwarded to
                the decoder. Defaults to None.

        Returns:
            tuple[Tensor]: results of decoder containing the following tensor.

                - out_dec: Output from decoder. If return_intermediate_dec \
                    is True output has shape [num_dec_layers, bs, \
                    num_query, embed_dims], else has shape [1, bs, \
                    num_query, embed_dims].
                - memory: Output results from encoder, with shape \
                    [bs, n, embed_dims, h, w].
        """
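        # Flatten the multi-view features and positional encodings to the
        # sequence-first layout expected by the decoder
        # ([bs, n, c, h, w] -> [n * h * w, bs, c]); query_embed arrives
        # batch-first here, so it is only transposed to [num_query, bs, c].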
        bs, n, c, h, w = x.shape
        memory = x.permute(1, 3, 4, 0, 2).reshape(-1, bs, c)
        pos_embed = pos_embed.permute(1, 3, 4, 0, 2).reshape(-1, bs, c)
        query_embed = query_embed.transpose(0, 1)
        mask = mask.view(bs, -1)
        target = torch.zeros_like(query_embed)
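        # The attn_masks list maps to the attention modules in
        # operation_order: the first entry masks the decoder self-attention
        # (e.g. to keep denoising query groups separate), while the
        # cross-attention over image tokens is left unmasked.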
        out_dec = self.decoder(
            query=target,
            key=memory,
            value=memory,
            key_pos=pos_embed,
            query_pos=query_embed,
            key_padding_mask=mask,
            attn_masks=[attn_masks, None],
            reg_branch=reg_branch,
        )
        out_dec = out_dec.transpose(1, 2)
        memory = memory.reshape(n, h, w, bs, c).permute(3, 0, 4, 1, 2)
        return out_dec, memory


@MODELS.register_module()
class PETRTransformerDecoderLayer(BaseTransformerLayer):
    """Implements decoder layer in DETR transformer.

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict):
            Configs for self_attention or cross_attention, the order
            should be consistent with it in `operation_order`. If it is
            a dict, it would be expanded to the number of attention
            modules in `operation_order`.
        feedforward_channels (int): The hidden dimension for FFNs.
        ffn_dropout (float): Probability of an element to be zeroed
            in ffn. Default: 0.0.
        operation_order (tuple[str]): The execution order of operations
            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
            Default: None.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='ReLU', inplace=True).
        norm_cfg (dict): Config dict for normalization layer.
            Default: `LN`.
        ffn_num_fcs (int): The number of fully-connected layers in FFNs.
            Default: 2.
        with_cp (bool): Whether to use gradient checkpointing for this
            layer during training. Default: True.
    """

    def __init__(self,
                 attn_cfgs,
                 feedforward_channels,
                 ffn_dropout=0.0,
                 operation_order=None,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 ffn_num_fcs=2,
                 with_cp=True,
                 **kwargs):
        super(PETRTransformerDecoderLayer, self).__init__(
            attn_cfgs=attn_cfgs,
            feedforward_channels=feedforward_channels,
            ffn_dropout=ffn_dropout,
            operation_order=operation_order,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            ffn_num_fcs=ffn_num_fcs,
            **kwargs)
        # the layer must contain six operations drawn from self_attn, norm,
        # cross_attn and ffn, e.g.
        # ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')
        assert len(operation_order) == 6
        assert set(operation_order) == set(
            ['self_attn', 'norm', 'cross_attn', 'ffn'])
        self.use_checkpoint = with_cp
    def _forward(self,
                 query,
                 key=None,
                 value=None,
                 query_pos=None,
                 key_pos=None,
                 attn_masks=None,
                 query_key_padding_mask=None,
                 key_padding_mask=None):
        """Forward function for `TransformerDecoderLayer`.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        x = super(PETRTransformerDecoderLayer, self).forward(
            query,
            key=key,
            value=value,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_masks=attn_masks,
            query_key_padding_mask=query_key_padding_mask,
            key_padding_mask=key_padding_mask,
        )
        return x

    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `TransformerDecoderLayer`.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
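        # Checkpointing trades compute for memory: activations inside the
        # layer are not stored during the forward pass and are recomputed
        # in the backward pass, so it is only enabled while training.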
        if self.use_checkpoint and self.training:
            x = cp.checkpoint(
                self._forward,
                query,
                key,
                value,
                query_pos,
                key_pos,
                attn_masks,
                query_key_padding_mask,
                key_padding_mask,
            )
        else:
            x = self._forward(
                query,
                key=key,
                value=value,
                query_pos=query_pos,
                key_pos=key_pos,
                attn_masks=attn_masks,
                query_key_padding_mask=query_key_padding_mask,
                key_padding_mask=key_padding_mask)
        return x


@MODELS.register_module()
class PETRMultiheadAttention(BaseModule):
    """A wrapper for ``torch.nn.MultiheadAttention``.

    This module implements MultiheadAttention with identity connection,
    and positional encoding is also passed as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): When it is True, Key, Query and Value are shape of
            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
            Default: False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        super(PETRMultiheadAttention, self).__init__(init_cfg)
        if 'dropout' in kwargs:
            warnings.warn(
                'The argument `dropout` in MultiheadAttention '
                'has been deprecated; now you can separately '
                'set `attn_drop`(float), `proj_drop`(float) '
                'and `dropout_layer`(dict).', DeprecationWarning)
            attn_drop = kwargs['dropout']
            dropout_layer['drop_prob'] = kwargs.pop('dropout')

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.batch_first = batch_first

        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)

        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = MODELS.build(
            dropout_layer) if dropout_layer else nn.Identity()
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `MultiheadAttention`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims].
                If None, the ``query`` will be used. Defaults to None.
            value (Tensor): The value tensor with the same shape as `key`,
                as in `nn.MultiheadAttention.forward`. If None, the `key`
                will be used. Defaults to None.
            identity (Tensor): This tensor, with the same shape as `query`,
                will be used for the identity link. If None, `query` will
                be used. Defaults to None.
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `query`. If not None, it will be added
                to `query` before the attention. Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. If not None, it will be added to `key`
                before the attention. If None and `query_pos` has the same
                shape as `key`, then `query_pos` will be used for `key_pos`.
                Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.

        Returns:
            Tensor: forwarded results with shape
            [num_queries, bs, embed_dims]
            if self.batch_first is False, else
            [bs, num_queries, embed_dims].
        """
        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos as key_pos only when their shapes match
                if query_pos.shape == key.shape:
                    key_pos = query_pos
                else:
                    warnings.warn(f'position encoding of key is '
                                  f'missing in {self.__class__.__name__}.')
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos
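        # ``torch.nn.MultiheadAttention`` here consumes tensors in
        # (num_query, batch, embed_dims) order, so batch-first inputs are
        # transposed before the attention and the output is transposed back.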
        if self.batch_first:
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)

        out = self.attn(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))


@MODELS.register_module()
class PETRTransformerEncoder(TransformerLayerSequence):
    """TransformerEncoder of DETR.

    Args:
        post_norm_cfg (dict): Config of the last normalization layer.
            Default: `LN`. Only used when `self.pre_norm` is `True`.
    """

    def __init__(self, *args, post_norm_cfg=dict(type='LN'), **kwargs):
        super(PETRTransformerEncoder, self).__init__(*args, **kwargs)
        if post_norm_cfg is not None:
            self.post_norm = build_norm_layer(
                post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None
        else:
            assert not self.pre_norm, f'Use prenorm in ' \
                                      f'{self.__class__.__name__}, ' \
                                      f'please specify post_norm_cfg'
            self.post_norm = None
    def forward(self, *args, **kwargs):
        """Forward function for `TransformerEncoder`.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        x = super(PETRTransformerEncoder, self).forward(*args, **kwargs)
        if self.post_norm is not None:
            x = self.post_norm(x)
        return x


@MODELS.register_module()
class PETRTransformerDecoder(TransformerLayerSequence):
    """Implements the decoder in DETR transformer.

    Args:
        return_intermediate (bool): Whether to return intermediate outputs.
        post_norm_cfg (dict): Config of the last normalization layer.
            Default: `LN`.
    """

    def __init__(self,
                 *args,
                 post_norm_cfg=dict(type='LN'),
                 return_intermediate=False,
                 **kwargs):
        super(PETRTransformerDecoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        if post_norm_cfg is not None:
            self.post_norm = build_norm_layer(post_norm_cfg,
                                              self.embed_dims)[1]
        else:
            self.post_norm = None

    def forward(self, query, *args, **kwargs):
        """Forward function for `TransformerDecoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_query, bs, embed_dims)`.

        Returns:
            Tensor: Results with shape [1, num_query, bs, embed_dims] when
            return_intermediate is `False`, otherwise it has shape
            [num_layers, num_query, bs, embed_dims].
        """
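        # When return_intermediate is False, only the final output is
        # returned (post-normed and given a leading layer dimension if
        # post_norm is configured). Otherwise each decoder layer's output
        # is collected (post-normed when available) and stacked.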
        if not self.return_intermediate:
            x = super().forward(query, *args, **kwargs)
            if self.post_norm:
                x = self.post_norm(x)[None]
            return x

        intermediate = []
        for layer in self.layers:
            query = layer(query, *args, **kwargs)
            if self.return_intermediate:
                if self.post_norm is not None:
                    intermediate.append(self.post_norm(query))
                else:
                    intermediate.append(query)
        return torch.stack(intermediate)