Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 108 additions & 39 deletions src/transformers/models/distilbert/modeling_distilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,12 @@ def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):


def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
position_enc = np.array(
[
[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
for pos in range(n_pos)
]
)
out.requires_grad = False
out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
Expand All @@ -83,13 +88,19 @@ def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
class Embeddings(nn.Module):
def __init__(self, config: PreTrainedConfig):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
self.word_embeddings = nn.Embedding(
config.vocab_size, config.dim, padding_idx=config.pad_token_id
)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.dim
)

self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
self.dropout = nn.Dropout(config.dropout)
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
"position_ids",
torch.arange(config.max_position_embeddings).expand((1, -1)),
persistent=False,
)

@deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
Expand All @@ -111,10 +122,16 @@ def forward(
if hasattr(self, "position_ids"):
position_ids = self.position_ids[:, :seq_length]
else:
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length)
position_ids = torch.arange(
seq_length, dtype=torch.long, device=input_ids.device
) # (max_seq_length)
position_ids = position_ids.unsqueeze(0).expand_as(
input_ids
) # (bs, max_seq_length)

position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim)
position_embeddings = self.position_embeddings(
position_ids
) # (bs, max_seq_length, dim)

embeddings = inputs_embeds + position_embeddings # (bs, max_seq_length, dim)
embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim)
Expand Down Expand Up @@ -143,7 +160,9 @@ def eager_attention_forward(
attn_weights = attn_weights + attention_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_weights = nn.functional.dropout(
attn_weights, p=dropout, training=module.training
)

attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
Expand All @@ -164,7 +183,9 @@ def __init__(self, config: PreTrainedConfig):
# The hidden dimension must be evenly divisible by the number of attention heads
if self.dim % self.n_heads != 0:
# Reject configurations where n_heads does not divide dim evenly
raise ValueError(f"self.n_heads: {self.n_heads} must divide self.dim: {self.dim} evenly")
raise ValueError(
f"self.n_heads: {self.n_heads} must divide self.dim: {self.dim} evenly"
)

self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
Expand Down Expand Up @@ -218,7 +239,20 @@ def __init__(self, config: PreTrainedConfig):
self.activation = get_activation(config.activation)

def forward(self, input: torch.Tensor) -> torch.Tensor:
return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)
"""
Forward pass of the FFN layer.

Args:
input (`torch.Tensor` of shape `(batch_size, sequence_length, dim)`):
Input tensor to the feed-forward network.

Returns:
`torch.Tensor` of shape `(batch_size, sequence_length, dim)`:
Output tensor after feed-forward transformation.
"""
return apply_chunking_to_forward(
self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input
)

def ff_chunk(self, input: torch.Tensor) -> torch.Tensor:
x = self.lin1(input)
Expand All @@ -234,7 +268,9 @@ def __init__(self, config: PreTrainedConfig):

# The hidden dimension must be evenly divisible by the number of attention heads
if config.dim % config.n_heads != 0:
raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly")
raise ValueError(
f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly"
)

self.attention = DistilBertSelfAttention(config)
self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
Expand Down Expand Up @@ -267,7 +303,9 @@ class Transformer(nn.Module):
def __init__(self, config: PreTrainedConfig):
super().__init__()
self.n_layers = config.n_layers
self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
self.layer = nn.ModuleList(
[TransformerBlock(config) for _ in range(config.n_layers)]
)
self.gradient_checkpointing = False

def forward(
Expand Down Expand Up @@ -315,7 +353,10 @@ def _init_weights(self, module: nn.Module):
torch.empty_like(module.position_embeddings.weight),
),
)
init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
init.copy_(
module.position_ids,
torch.arange(module.position_ids.shape[-1]).expand((1, -1)),
)


@auto_docstring
Expand Down Expand Up @@ -347,29 +388,39 @@ def resize_position_embeddings(self, new_num_position_embeddings: int):
size will add correct vectors at the end following the position encoding algorithm, whereas reducing
the size will remove vectors from the end.
"""
num_position_embeds_diff = new_num_position_embeddings - self.config.max_position_embeddings
num_position_embeds_diff = (
new_num_position_embeddings - self.config.max_position_embeddings
)

# no resizing needs to be done if the length stays the same
if num_position_embeds_diff == 0:
return

logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
logger.info(
f"Setting `config.max_position_embeddings={new_num_position_embeddings}`..."
)
self.config.max_position_embeddings = new_num_position_embeddings

old_position_embeddings_weight = self.embeddings.position_embeddings.weight.clone()
old_position_embeddings_weight = (
self.embeddings.position_embeddings.weight.clone()
)

self.embeddings.position_embeddings = nn.Embedding(self.config.max_position_embeddings, self.config.dim)
self.embeddings.position_embeddings = nn.Embedding(
self.config.max_position_embeddings, self.config.dim
)

if self.config.sinusoidal_pos_embds:
create_sinusoidal_embeddings(
n_pos=self.config.max_position_embeddings, dim=self.config.dim, out=self.position_embeddings.weight
n_pos=self.config.max_position_embeddings,
dim=self.config.dim,
out=self.position_embeddings.weight,
)
else:
with torch.no_grad():
if num_position_embeds_diff > 0:
self.embeddings.position_embeddings.weight[:-num_position_embeds_diff] = nn.Parameter(
old_position_embeddings_weight
)
self.embeddings.position_embeddings.weight[
:-num_position_embeds_diff
] = nn.Parameter(old_position_embeddings_weight)
else:
self.embeddings.position_embeddings.weight = nn.Parameter(
old_position_embeddings_weight[:num_position_embeds_diff]
Expand Down Expand Up @@ -408,7 +459,9 @@ def forward(
model's internal embedding lookup matrix.
"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)

embeddings = self.embeddings(input_ids, inputs_embeds, position_ids)

Expand All @@ -425,13 +478,13 @@ def forward(
)


@auto_docstring(
custom_intro="""
@auto_docstring(custom_intro="""
DistilBert Model with a `masked language modeling` head on top.
"""
)
""")
class DistilBertForMaskedLM(DistilBertPreTrainedModel):
_tied_weights_keys = {"vocab_projector.weight": "distilbert.embeddings.word_embeddings.weight"}
_tied_weights_keys = {
"vocab_projector.weight": "distilbert.embeddings.word_embeddings.weight"
}

def __init__(self, config: PreTrainedConfig):
super().__init__(config)
Expand Down Expand Up @@ -513,12 +566,18 @@ def forward(
hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
prediction_logits = self.activation(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_projector(prediction_logits) # (bs, seq_length, vocab_size)
prediction_logits = self.vocab_layer_norm(
prediction_logits
) # (bs, seq_length, dim)
prediction_logits = self.vocab_projector(
prediction_logits
) # (bs, seq_length, vocab_size)

mlm_loss = None
if labels is not None:
mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
mlm_loss = self.mlm_loss_fct(
prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1)
)

return MaskedLMOutput(
loss=mlm_loss,
Expand All @@ -528,12 +587,10 @@ def forward(
)


@auto_docstring(
custom_intro="""
@auto_docstring(custom_intro="""
DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
pooled output) e.g. for GLUE tasks.
"""
)
""")
class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
def __init__(self, config: PreTrainedConfig):
super().__init__(config)
Expand Down Expand Up @@ -605,7 +662,9 @@ def forward(
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
elif self.num_labels > 1 and (
labels.dtype == torch.long or labels.dtype == torch.int
):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
Expand Down Expand Up @@ -639,7 +698,9 @@ def __init__(self, config: PreTrainedConfig):
self.distilbert = DistilBertModel(config)
self.qa_outputs = nn.Linear(config.dim, config.num_labels)
if config.num_labels != 2:
raise ValueError(f"config.num_labels should be 2, but it is {config.num_labels}")
raise ValueError(
f"config.num_labels should be 2, but it is {config.num_labels}"
)

self.dropout = nn.Dropout(config.qa_dropout)

Expand Down Expand Up @@ -890,10 +951,18 @@ def forward(
>>> loss = outputs.loss
>>> logits = outputs.logits
```"""
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
num_choices = (
input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
)

input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
input_ids = (
input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
)
attention_mask = (
attention_mask.view(-1, attention_mask.size(-1))
if attention_mask is not None
else None
)
inputs_embeds = (
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
if inputs_embeds is not None
Expand Down
Loading