diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 6b497b47..21f927b2 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -400,10 +400,17 @@ def epoch(self) -> None: # deal correctly with packed samples under FA2, by calculating each seqlen tflos separately sample_seqlens = batch.pop("packed_sample_seqlens") else: - sample_seqlens = [ - [len(batch["input_ids"][idx]) * self.config.sequence_parallel_size] - for idx in range(len(batch["input_ids"])) - ] + if "input_ids" not in batch: + # batch is a ContrastiveLearningBatch + sample_seqlens = [ + [len(batch.query_tokens[idx]) * self.config.sequence_parallel_size] + for idx in range(batch.query_tokens.shape[0]) + ] + else: + sample_seqlens = [ + [len(batch["input_ids"][idx]) * self.config.sequence_parallel_size] + for idx in range(len(batch["input_ids"])) + ] self.metrics.seqlens = sample_seqlens self.metrics.start_timer("step") diff --git a/projects/arctic_embed/examples/finetune_models/README.md b/projects/arctic_embed/examples/finetune_models/README.md index a936b82d..425ee068 100644 --- a/projects/arctic_embed/examples/finetune_models/README.md +++ b/projects/arctic_embed/examples/finetune_models/README.md @@ -17,7 +17,7 @@ GIT_LFS_SKIP_SMUDGE=1 git clone https://hf.co/datasets/Snowflake/arctic-embed-ft mv ./data.gitignore ./data/.gitignore # Ensure we have all the files you need for training downloaded from LFS. -cd arctic-embed-ft-v1/ +cd data/ git lfs pull --include="combined/pretokenized/example_dot95/,eval/" # Optional: Download more large files (e.g. everything but the very large precomputed embeddings). diff --git a/projects/arctic_embed/examples/finetune_models/finetune_e5_base_unsupervised.py b/projects/arctic_embed/examples/finetune_models/finetune_e5_base_unsupervised.py index ff246a43..f335f309 100644 --- a/projects/arctic_embed/examples/finetune_models/finetune_e5_base_unsupervised.py +++ b/projects/arctic_embed/examples/finetune_models/finetune_e5_base_unsupervised.py @@ -72,7 +72,7 @@ def now_timestamp_str() -> str: eval_max_seq_length_doc=512, eval_max_seq_length_query=512, ) -sconf = WSDSchedulerConfig(num_warmup_steps=500, num_decay_steps=1_000, learning_rate=LEARNING_RATE) +sconf = WSDSchedulerConfig(num_warmup_steps=500, num_decay_steps=1_000) oconf = OptimizerConfig(weight_decay=0.01, learning_rate=LEARNING_RATE) lconf = LoggerConfig(level="INFO") wconf = WandBConfig( diff --git a/projects/arctic_embed/examples/finetune_models/pretrain_bge_retromae.py b/projects/arctic_embed/examples/finetune_models/pretrain_bge_retromae.py new file mode 100644 index 00000000..3fb9418e --- /dev/null +++ b/projects/arctic_embed/examples/finetune_models/pretrain_bge_retromae.py @@ -0,0 +1,195 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This example shows how to use the Arctic Embed codebase to finetune +the venerable E5-base-v2 model (released in May 2023) on a version of MSMARCO +training data which has been hard-negative-mined using a more modern technique. + +The code needed to recreate the training data can be found in the sibling directory +`data_prep` within the `hard_negative_mining` subdirectory. + +Original model paper: https://arxiv.org/abs/2212.03533 +Model page: https://huggingface.co/intfloat/e5-base-v2 +Better negative mining paper: https://arxiv.org/abs/2407.15831 +""" +import sys +from datetime import datetime +from datetime import timezone +from pathlib import Path + +from arctic_embed.biencoder_model_factory import BiencoderModelConfig +from arctic_embed.contrastive_dataloader import ContrastivePretokenizedDataConfig +from arctic_embed.core.cuda_allocator_config import CUDA_ALLOCATOR_CONFIG_FOR_DYNAMICALLY_SIZED_DATA +from arctic_embed.trainer import BiencoderTrainer +from arctic_embed.trainer import BiencoderTrainerConfig + +from arctic_training.config.checkpoint import CheckpointConfig +from arctic_training.config.logger import LoggerConfig +from arctic_training.config.optimizer import OptimizerConfig +from arctic_training.config.wandb import WandBConfig +from arctic_training.scheduler.wsd_factory import WSDSchedulerConfig + +LEARNING_RATE = 3e-5 +GRADIENT_CLIPPING = 10.0 +# DATA_PATH = str(Path(__file__).parent / "data" / "pretrain_amazonqa" / "batched_16384") +DATA_PATH = ( + "s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/combined_all_16384" +) +# EVAL_DATA_PATHS = [str(path) for path in (Path(__file__).parent / "data" / "eval").iterdir() if path.is_dir()] # fix this +datasets = [ + "amazon_qa", + "ccnews_de_v1", + "ccnews_en_v1", + "ccnews_es_v1", + "ccnews_fr_v1", + "ccnews_it_v1", + "ccnews_pl_v1", + "ccnews_pt_v1", + "faq", + "mc4_de_v1", + "mc4_en_v1", + "mc4_es_v1", + "mc4_fr_v1", + "mc4_it_v1", + "mc4_pl_v1", + "mc4_pt_v1", + "mwiki_de_v1", + "mwiki_en_v1", + "mwiki_es_v1", + "mwiki_fr_v1", + "mwiki_it_v1", + "mwiki_pl_v1", + "mwiki_pt_v1", + "paq", + "pes2o", + "red_pajama", + "red_pajamas_1t_stackexchange", + "s2orc_title_abstracts", + "snippets4", + "techrepo", + "top_stories", + "trivia_qa", + "wikipedia", +] +EVAL_DATA_PATHS = [ + f"s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/combined_all_16384_eval/{dataset}" + for dataset in datasets +] +# from transformers import AutoTokenizer +# tok = AutoTokenizer.from_pretrained("BAAI/bge-m3-retromae") +# tok.pad_token_id --> 1 +PAD_VALUE = 1 +LEFT_PAD = False + + +def now_timestamp_str() -> str: + """Get the current ISO 8601 UTC timestamp.""" + return datetime.now(timezone.utc).strftime(r"%Y%m%dT%H%M%SZ") + + +ts = now_timestamp_str() +checkpoint_dir = Path(__file__).parent / "checkpoints" / "pretrain_bge_retromae" / ts +mconf = BiencoderModelConfig( + name_or_path="BAAI/bge-m3-retromae", pooling="first_token", kwargs={"trust_remote_code": True} +) +dconf = ContrastivePretokenizedDataConfig( + filesystem="s3", + root_directory=DATA_PATH, + # filesystem="local", + # root_directory=DATA_PATH, + # Depending on how much GPU memory you have, you may need to split each + # batch into a number of smaller sub-batches by setting the split_factor. + # If you do so, you will probably want to decrease the learning rate accordingly. + # split_factor=4, + max_seq_length_query=32, + max_seq_length_doc=256, + eval_root_directories=EVAL_DATA_PATHS, + eval_max_seq_length_doc=32, + eval_max_seq_length_query=256, + pad_value=PAD_VALUE, + left_pad=LEFT_PAD, +) +sconf = WSDSchedulerConfig(num_warmup_steps=2000, num_decay_steps=2000) +oconf = OptimizerConfig(weight_decay=0.01, learning_rate=LEARNING_RATE) +lconf = LoggerConfig(level="INFO") +wconf = WandBConfig( + enable=True, + project="arctic-training-arctic-embed-testbed", + name=f"bge-m3-retromae-pretrain-{ts}", +) +# Reference: https://www.deepspeed.ai/training/#gradient-clipping +dsconf = { + "gradient_clipping": GRADIENT_CLIPPING, + "zero_optimization": {"stage": 1}, + # NOTE: The underlying DeepSpeed engine scales gradients down by a factor of + # `1/world_size`` in the backwards pass, so we pre-scale the loss up by a factor + # of `world_size`. Given these scalings, there is a potential for increased + # numerical imprecision when using low-precision floating point representation, + # so we set communication to fp32 in the backwards all-reduce to somewhat mitigate + # this risk. + "communication_data_type": "fp32", +} +cconf = CheckpointConfig( + output_dir=checkpoint_dir, + type="biencoder", + save_every_n_steps=300, + save_end_of_training=True, +) + + +def configure_non_distributed_distributed_training_if_needed() -> None: + """Detect if we need to manually initialize distributed training environment + and do so if needed. + + NOTE: We have to do this step because Arctic Training doesn't have a default + 1-GPU launching mode and will instead fall back to trying to auto-discover + distributed training configuration (e.g. via MPI). + """ + num_cli_args = len(sys.argv) - 1 + if num_cli_args == 0: + print("***No CLI args detected, configuring for single-GPU training.***") + from os import environ + + from torch import distributed as dist + + environ["MASTER_ADDR"] = "localhost" + environ["MASTER_PORT"] = "12335" + environ["LOCAL_RANK"] = "0" + dist.init_process_group(backend="nccl", world_size=1, rank=0) + + +if __name__ == "__main__": + CUDA_ALLOCATOR_CONFIG_FOR_DYNAMICALLY_SIZED_DATA.set_env() + configure_non_distributed_distributed_training_if_needed() + tconf = BiencoderTrainerConfig( + type="biencoder", + model=mconf, + data=dconf, + scheduler=sconf, + optimizer=oconf, + logger=lconf, + checkpoint=cconf, + wandb=wconf, + deepspeed=dsconf, + loss_log_interval=0, + eval_frequency=10, + use_in_batch_negatives=True, + loss_temperature=0.02, + overfit_first_batch=False, + mrl_dim=256, + ) + trainer = BiencoderTrainer(config=tconf) + trainer.train() diff --git a/projects/arctic_embed/examples/finetune_models/pretrain_splade.py b/projects/arctic_embed/examples/finetune_models/pretrain_splade.py new file mode 100644 index 00000000..4d159db7 --- /dev/null +++ b/projects/arctic_embed/examples/finetune_models/pretrain_splade.py @@ -0,0 +1,188 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This example shows how to use the Arctic Embed codebase to finetune +the venerable E5-base-v2 model (released in May 2023) on a version of MSMARCO +training data which has been hard-negative-mined using a more modern technique. + +The code needed to recreate the training data can be found in the sibling directory +`data_prep` within the `hard_negative_mining` subdirectory. + +Original model paper: https://arxiv.org/abs/2212.03533 +Model page: https://huggingface.co/intfloat/e5-base-v2 +Better negative mining paper: https://arxiv.org/abs/2407.15831 +""" +import os +import sys +from datetime import datetime +from datetime import timezone +from pathlib import Path + +from arctic_embed.biencoder_model_factory import BiencoderModelConfig +from arctic_embed.contrastive_dataloader import ContrastivePretokenizedDataConfig +from arctic_embed.core.cuda_allocator_config import CUDA_ALLOCATOR_CONFIG_FOR_DYNAMICALLY_SIZED_DATA +from arctic_embed.trainer import BiencoderTrainer +from arctic_embed.trainer import BiencoderTrainerConfig +from transformers import AutoTokenizer + +from arctic_training.config.checkpoint import CheckpointConfig +from arctic_training.config.logger import LoggerConfig +from arctic_training.config.optimizer import OptimizerConfig +from arctic_training.config.wandb import WandBConfig +from arctic_training.scheduler.wsd_factory import WSDSchedulerConfig + +LEARNING_RATE = 2e-5 +GRADIENT_CLIPPING = 10.0 +DATA_PATH = "s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/answerdotai_ModernBERT_base/combined_all_32768" +datasets = [ + "amazon_qa", + "ccnews_en_v1", + "faq", + "mc4_en_v1", + "mwiki_en_v1", + "paq", + "pes2o", + "red_pajama", + "red_pajamas_1t_stackexchange", + "s2orc_title_abstracts", + "snippets4", + "techrepo", + "top_stories", + "trivia_qa", + "wikipedia", +] +EVAL_DATA_PATHS = [ + f"s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/answerdotai_ModernBERT_base/combined_all_32768_eval/{dataset}" + for dataset in datasets +] + + +MODEL_NAME = "answerdotai/ModernBERT-base" +LEFT_PAD = False + + +def now_timestamp_str() -> str: + """Get the current ISO 8601 UTC timestamp.""" + return datetime.now(timezone.utc).strftime(r"%Y%m%dT%H%M%SZ") + + +ts = now_timestamp_str() +checkpoint_dir = Path(__file__).parent / "checkpoints" / "pretrain_mgte" / ts +mconf = BiencoderModelConfig( + name_or_path=MODEL_NAME, + pooling="splade", + kwargs={ + "trust_remote_code": True, + "unpad_inputs": True, + "use_memory_efficient_attention": True, + }, + dtype="bf16", + attn_implementation="flash_attention_2", +) +dconf = ContrastivePretokenizedDataConfig( + filesystem="s3", + root_directory=DATA_PATH, + # filesystem="local", + # root_directory=DATA_PATH, + # Depending on how much GPU memory you have, you may need to split each + # batch into a number of smaller sub-batches by setting the split_factor. + # If you do so, you will probably want to decrease the learning rate accordingly. + split_factor=8, + eval_split_factor=8, + max_seq_length_query=32, + max_seq_length_doc=256, + eval_root_directories=EVAL_DATA_PATHS, + eval_max_seq_length_doc=32, + eval_max_seq_length_query=256, + pad_value=AutoTokenizer.from_pretrained(MODEL_NAME).pad_token_id, + left_pad=LEFT_PAD, +) +sconf = WSDSchedulerConfig(num_warmup_steps=5000, num_decay_steps=5000) +oconf = OptimizerConfig(weight_decay=0.01, learning_rate=LEARNING_RATE) +lconf = LoggerConfig(level="INFO") +wconf = WandBConfig( + enable=True, + project="arctic-training-arctic-embed-testbed", + name=f"modernbert-pretrain-splade-{ts}", +) +# Reference: https://www.deepspeed.ai/training/#gradient-clipping +dsconf = { + "gradient_clipping": GRADIENT_CLIPPING, + "zero_optimization": {"stage": 1}, + # NOTE: The underlying DeepSpeed engine scales gradients down by a factor of + # `1/world_size`` in the backwards pass, so we pre-scale the loss up by a factor + # of `world_size`. Given these scalings, there is a potential for increased + # numerical imprecision when using low-precision floating point representation, + # so we set communication to fp32 in the backwards all-reduce to somewhat mitigate + # this risk. + "communication_data_type": "fp32", +} +cconf = CheckpointConfig( + output_dir=checkpoint_dir, + type="biencoder", + save_every_n_steps=300, + save_end_of_training=True, +) + + +def configure_non_distributed_distributed_training_if_needed() -> None: + """Detect if we need to manually initialize distributed training environment + and do so if needed. + + NOTE: We have to do this step because Arctic Training doesn't have a default + 1-GPU launching mode and will instead fall back to trying to auto-discover + distributed training configuration (e.g. via MPI). + """ + num_cli_args = len(sys.argv) - 1 + if num_cli_args == 0: + print("***No CLI args detected, configuring for single-GPU training.***") + from os import environ + + from torch import distributed as dist + + environ["MASTER_ADDR"] = "localhost" + environ["MASTER_PORT"] = "12335" + environ["LOCAL_RANK"] = "0" + dist.init_process_group(backend="nccl", world_size=1, rank=0) + + +if __name__ == "__main__": + CUDA_ALLOCATOR_CONFIG_FOR_DYNAMICALLY_SIZED_DATA.set_env() + configure_non_distributed_distributed_training_if_needed() + tconf = BiencoderTrainerConfig( + type="biencoder", + model=mconf, + data=dconf, + scheduler=sconf, + optimizer=oconf, + logger=lconf, + checkpoint=cconf, + wandb=wconf, + deepspeed=dsconf, + loss_log_interval=0, + eval_frequency=300, + use_in_batch_negatives=True, + loss_temperature=0.05, + overfit_first_batch=False, + mrl_dim=None, + splade_reg_weight=float(os.getenv("SPLADE_REG_WEIGHT", "0.0")), + # Per-side SPLADE v2 FLOPs regularizers. + splade_flops_weight_query=float(os.getenv("SPLADE_FLOPS_WEIGHT_QUERY", "1e-1")), + splade_flops_weight_doc=float(os.getenv("SPLADE_FLOPS_WEIGHT_DOC", "1e-4")), + splade_nnz_threshold=float(os.getenv("SPLADE_NNZ_THRESHOLD", "0")), + ) + trainer = BiencoderTrainer(config=tconf) + trainer.train() diff --git a/projects/arctic_embed/src/arctic_embed/biencoder_model_factory.py b/projects/arctic_embed/src/arctic_embed/biencoder_model_factory.py index e5bff672..338207c3 100644 --- a/projects/arctic_embed/src/arctic_embed/biencoder_model_factory.py +++ b/projects/arctic_embed/src/arctic_embed/biencoder_model_factory.py @@ -21,6 +21,7 @@ from peft.config import PeftConfig from transformers import AutoConfig from transformers import AutoModel +from transformers import AutoModelForMaskedLM from arctic_training.config.model import ModelConfig from arctic_training.model.factory import ModelFactory @@ -55,13 +56,26 @@ def create_model(self, model_config: AutoConfig) -> Biencoder: arctic_training_model_config = self.config assert isinstance(arctic_training_model_config, BiencoderModelConfig) trust_remote_code = arctic_training_model_config.kwargs.get("trust_remote_code", None) - encoder = AutoModel.from_pretrained( - self.config.name_or_path, - config=model_config, - attn_implementation=self.config.attn_implementation, - torch_dtype=self.config.dtype, - trust_remote_code=trust_remote_code, - ) + if arctic_training_model_config.pooling == "splade": + print(f"🤨 model_config: {model_config}") + print(f"🤨 self.config.attn_implementation: {self.config.attn_implementation}") + print(f"🤨 self.config.dtype.value: {self.config.dtype.value}") + print(f"🤨 trust_remote_code: {trust_remote_code}") + encoder = AutoModelForMaskedLM.from_pretrained( + self.config.name_or_path, + config=model_config, + attn_implementation=self.config.attn_implementation, + torch_dtype=self.config.dtype.value, + trust_remote_code=trust_remote_code, + ) + else: + encoder = AutoModel.from_pretrained( + self.config.name_or_path, + config=model_config, + attn_implementation=self.config.attn_implementation, + torch_dtype=self.config.dtype.value, + trust_remote_code=trust_remote_code, + ) return Biencoder(encoder, pooling=arctic_training_model_config.pooling) def post_create_model_callback(self, model: Biencoder): diff --git a/projects/arctic_embed/src/arctic_embed/contrastive_dataloader.py b/projects/arctic_embed/src/arctic_embed/contrastive_dataloader.py index ebb58784..7a2a4148 100644 --- a/projects/arctic_embed/src/arctic_embed/contrastive_dataloader.py +++ b/projects/arctic_embed/src/arctic_embed/contrastive_dataloader.py @@ -39,6 +39,8 @@ class ContrastivePretokenizedDataConfig(DataConfig): filesystem: FilesystemOption root_directory: str split_factor: int = 1 + pad_value: int = 0 + left_pad: bool = False sources: List[DataSourceConfig] = [] max_seq_length_query: Optional[int] = None max_seq_length_doc: Optional[int] = None @@ -62,6 +64,8 @@ def __call__(self) -> Tuple[DataLoader, Optional[Dict[str, DataLoader]]]: split_factor=self.config.split_factor, shard_id=self.global_rank, world_size=self.world_size, + pad_value=self.config.pad_value, + left_pad=self.config.left_pad, max_seq_len_query=self.config.max_seq_length_query, max_seq_len_doc=self.config.max_seq_length_doc, device=self.trainer.device, @@ -79,6 +83,8 @@ def __call__(self) -> Tuple[DataLoader, Optional[Dict[str, DataLoader]]]: split_factor=self.config.eval_split_factor, shard_id=self.global_rank, world_size=self.world_size, + pad_value=self.config.pad_value, + left_pad=self.config.left_pad, max_seq_len_query=self.config.eval_max_seq_length_query, max_seq_len_doc=self.config.eval_max_seq_length_doc, device=self.trainer.device, diff --git a/projects/arctic_embed/src/arctic_embed/core/biencoder_model.py b/projects/arctic_embed/src/arctic_embed/core/biencoder_model.py index 6ef38f81..450b7f53 100644 --- a/projects/arctic_embed/src/arctic_embed/core/biencoder_model.py +++ b/projects/arctic_embed/src/arctic_embed/core/biencoder_model.py @@ -29,7 +29,7 @@ from torch import nn from transformers import PreTrainedModel -PoolingOption = Literal["first_token", "last_token", "mean"] +PoolingOption = Literal["first_token", "last_token", "mean", "splade"] logger = logging.getLogger(__name__) @@ -41,9 +41,24 @@ def __init__(self, encoder: PreTrainedModel, pooling: PoolingOption = "first_tok super().__init__() self.encoder = encoder self.pooling = pooling + self.config = encoder.config + # Caches for SPLADE pooled weights from the most recent forward. + self._cached_query_pooled: Optional[Tensor] = None + self._cached_document_pooled: Optional[Tensor] = None def encode(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: out = self.encoder(input_ids=input_ids, attention_mask=attention_mask) + # SPLADE-style sparse pooling branches on logits instead of hidden states. + if self.pooling == "splade": + if not hasattr(out, "logits"): + raise ValueError( + f"Encoder of class {self.encoder.__class__} must output `logits` for SPLADE pooling." + ) + logits = out.logits # (batch, seq_len, vocab_size) + pooled = splade_pool(logits, attention_mask) + # NOTE: For SPLADE we avoid L2-normalization to preserve magnitude and sparsity. + return pooled.contiguous() + if not hasattr(out, "last_hidden_state"): raise ValueError( f"Encoder of class {self.encoder.__class__} is missing the " @@ -71,10 +86,56 @@ def forward( document_input_ids: Tensor, document_attention_mask: Tensor, ) -> Tuple[Tensor, Tensor]: - query_vectors = self.encode(query_input_ids, query_attention_mask) - document_vectors = self.encode(document_input_ids, document_attention_mask) + # Clear caches from any previous call + self._cached_query_pooled = None + self._cached_document_pooled = None + + if self.pooling == "splade": + # Run encoder to obtain logits for query and cache. + q_out = self.encoder(input_ids=query_input_ids, attention_mask=query_attention_mask) + if not hasattr(q_out, "logits"): + raise ValueError( + f"Encoder of class {self.encoder.__class__} must output `logits` for SPLADE pooling." + ) + self._cached_query_pooled = splade_pool(q_out.logits, query_attention_mask).contiguous() + query_vectors = self._cached_query_pooled + + # Run encoder to obtain logits for document and cache. + d_out = self.encoder(input_ids=document_input_ids, attention_mask=document_attention_mask) + if not hasattr(d_out, "logits"): + raise ValueError( + f"Encoder of class {self.encoder.__class__} must output `logits` for SPLADE pooling." + ) + self._cached_document_pooled = splade_pool(d_out.logits, document_attention_mask).contiguous() + document_vectors = self._cached_document_pooled + else: + query_vectors = self.encode(query_input_ids, query_attention_mask) + document_vectors = self.encode(document_input_ids, document_attention_mask) return query_vectors, document_vectors + def compute_splade_flops_doc_cached(self) -> Tensor: + """SPLADE v2 FLOPs proxy using cached pooled weights (doc side), Eq. (4). + + Let w_j(d_i) be the SPLADE weight of term j for sample i (after activation+pooling). + With N in-batch samples, Eq. (4): ℓ_FLOPS = Σ_j ( (1/N) Σ_i w_j(d_i) )^2. + """ + if self.pooling != "splade": + raise ValueError("FLOPs regularizer requested but pooling is not 'splade'.") + if self._cached_document_pooled is None: + raise RuntimeError("Document pooled cache is empty; call forward first.") + return _splade_flops_batch_mean_squared(self._cached_document_pooled) + + def compute_splade_flops_query_cached(self) -> Tensor: + """SPLADE v2 FLOPs proxy using cached pooled weights (query side), Eq. (4). + + Same formula as doc side; we apply a different scalar weight in the trainer. + """ + if self.pooling != "splade": + raise ValueError("FLOPs regularizer requested but pooling is not 'splade'.") + if self._cached_query_pooled is None: + raise RuntimeError("Query pooled cache is empty; call forward first.") + return _splade_flops_batch_mean_squared(self._cached_query_pooled) + def save_pretrained( self, save_directory: Union[str, os.PathLike], @@ -132,3 +193,34 @@ def last_token_pool(out: Tensor, attention_mask: Tensor) -> Tensor: row = torch.arange(batch_size, device=out.device) col = attention_mask.sum(dim=1) - 1 # position of the last non-padding token return out[row, col, ...] + + +def splade_pool(logits: Tensor, attention_mask: Tensor) -> Tensor: + """SPLADE-style sparse pooling over vocabulary logits. + + - Expects `logits` with shape (batch, token, vocab_size) + - Applies activation: log(1 + relu(logits)) + - Aggregates over tokens with masked max + """ + assert logits.ndim == 3 + assert attention_mask.ndim == 2 + # Activation as in SPLADE + activated = torch.log1p(F.relu(logits)) + # Mask out padding positions so they do not affect max + mask = ~attention_mask[..., None].bool() + activated = activated.masked_fill(mask, float("-inf")) + # Max over sequence dimension + pooled = torch.amax(activated, dim=1) + # Replace -inf (for fully-masked sequences) with zeros to avoid nans downstream + pooled = torch.where(torch.isinf(pooled), torch.zeros_like(pooled), pooled) + return pooled + +def _splade_flops_batch_mean_squared(pooled_weights: Tensor) -> Tensor: + """Compute Eq. (4): || (1/N) Σ_i w(d_i) ||_2^2, where w(d_i) are pooled SPLADE weights. + + pooled_weights: (batch, vocab_size) + returns: scalar Tensor (batch-mean vector squared L2 norm) + """ + assert pooled_weights.ndim == 2 + mean_vec = pooled_weights.mean(dim=0) + return (mean_vec.pow(2).sum()) diff --git a/projects/arctic_embed/src/arctic_embed/trainer.py b/projects/arctic_embed/src/arctic_embed/trainer.py index 6542df12..dae96668 100644 --- a/projects/arctic_embed/src/arctic_embed/trainer.py +++ b/projects/arctic_embed/src/arctic_embed/trainer.py @@ -15,6 +15,7 @@ from __future__ import annotations +import os from typing import Callable from typing import Dict from typing import List @@ -57,6 +58,11 @@ class BiencoderTrainerConfig(TrainerConfig): data: ContrastivePretokenizedDataConfig mrl_dim: Optional[int] = None eval_interval: Optional[int] = None + # SPLADE regularization (v1 L1 is deprecated; use FLOPs per side instead) + splade_reg_weight: float = 0.0 # deprecated; kept for backward-compat log + splade_flops_weight_query: float = 0.0 + splade_flops_weight_doc: float = 0.0 + splade_nnz_threshold: float = 1e-3 class FakeTokenizer: @@ -138,6 +144,25 @@ def eval_and_log_cb(self: BiencoderTrainer) -> None: logger.info(f"Global Step: {self.global_step}/{self.training_horizon} Eval: {metrics}") +def splade_flops_warmup_cb(self: BiencoderTrainer) -> None: + """Pre-step callback to linearly warm up SPLADE FLOPs weights.""" + if not hasattr(self.config, "_splade_flops_warmup_steps"): + # Initialize warmup config on first call + self.config._splade_flops_warmup_steps = int(os.environ.get("SPLADE_FLOPS_WARMUP_STEPS", "2000")) + self.config._splade_flops_weight_query_target = self.config.splade_flops_weight_query + self.config._splade_flops_weight_doc_target = self.config.splade_flops_weight_doc + + if self.global_step < self.config._splade_flops_warmup_steps: + # Linear warmup from 0 to target + warmup_factor = self.global_step / self.config._splade_flops_warmup_steps + self.config.splade_flops_weight_query = warmup_factor * self.config._splade_flops_weight_query_target + self.config.splade_flops_weight_doc = warmup_factor * self.config._splade_flops_weight_doc_target + else: + # Restore target weights after warmup + self.config.splade_flops_weight_query = self.config._splade_flops_weight_query_target + self.config.splade_flops_weight_doc = self.config._splade_flops_weight_doc_target + + class BiencoderTrainer(Trainer): name = "biencoder" config: BiencoderTrainerConfig @@ -150,6 +175,7 @@ class BiencoderTrainer(Trainer): count_total_queries_seen: int = 0 count_total_documents_seen: int = 0 callbacks: List[Tuple[str, Callable]] = [ + ("pre-step", splade_flops_warmup_cb), ("post-backward", rescale_grad_cb), ("post-step", log_grad_norm_cb), ("post-step", eval_and_log_cb), @@ -208,11 +234,24 @@ def forward_and_gather(self, batch: ContrastiveLearningBatch) -> Tuple[Tensor, T @torch.no_grad() def eval(self, batch: ContrastiveLearningBatch) -> Dict[str, float]: query_embeddings, document_embeddings, relations = self.forward_and_gather(batch) - q_emb = F.normalize(query_embeddings, dim=1) - d_emb = F.normalize(document_embeddings, dim=1) - scores = torch.matmul(q_emb, d_emb.transpose(0, 1)) + # Access underlying Biencoder if wrapped by DeepSpeed + model_core = getattr(self.model, "module", self.model) + if self.config.use_in_batch_negatives: + relations[relations == 0] = -1 + if isinstance(model_core, Biencoder) and model_core.pooling == "splade": + scores = torch.matmul(query_embeddings, document_embeddings.transpose(0, 1)) + else: + q_emb = F.normalize(query_embeddings, dim=1) + d_emb = F.normalize(document_embeddings, dim=1) + scores = torch.matmul(q_emb, d_emb.transpose(0, 1)) loss_infonce = info_nce_loss(scores, relations=relations, temperature=self.config.loss_temperature).item() - return {"infoNCE": loss_infonce} + metrics = {"infoNCE": loss_infonce} + if isinstance(model_core, Biencoder) and model_core.pooling == "splade": + thr = self.config.splade_nnz_threshold + avg_terms_q = (query_embeddings > thr).to(torch.float32).sum(dim=1).mean().item() + avg_terms_d = (document_embeddings > thr).to(torch.float32).sum(dim=1).mean().item() + metrics.update({"avg_terms_query": avg_terms_q, "avg_terms_doc": avg_terms_d}) + return metrics def loss(self, batch: ContrastiveLearningBatch) -> Tensor: # Count the number of queries and documents seen in this batch. @@ -224,16 +263,45 @@ def loss(self, batch: ContrastiveLearningBatch) -> Tensor: # Forward pass, gathering embeddings so each GPU has the full picture. query_embeddings, document_embeddings, relations = self.forward_and_gather(batch) - # InfoNCE loss with Matryoshka Representation Learning (MRL). + # InfoNCE loss (SPLADE uses raw dot product; dense uses MRL path). + model_core = getattr(self.model, "module", self.model) if self.config.use_in_batch_negatives: relations[relations == 0] = -1 - loss, loss_base, loss_truncated = one_size_truncated_mrl_info_nce_loss( - query_embeddings=query_embeddings, - document_embeddings=document_embeddings, - relations=relations, - truncated_dimension=self.config.mrl_dim, - temperature=self.config.loss_temperature, - ) + if isinstance(model_core, Biencoder) and model_core.pooling == "splade": + scores = torch.matmul(query_embeddings, document_embeddings.transpose(0, 1)) + loss = info_nce_loss(scores, relations=relations, temperature=self.config.loss_temperature) + loss_base = loss + loss_truncated = None + else: + loss, loss_base, loss_truncated = one_size_truncated_mrl_info_nce_loss( + query_embeddings=query_embeddings, + document_embeddings=document_embeddings, + relations=relations, + truncated_dimension=self.config.mrl_dim, + temperature=self.config.loss_temperature, + ) + + # Remove L1 regularizer in favor of SPLADE v2 FLOPs regularizers per side. + # Keep code path off unless legacy weight is set. + if ( + isinstance(model_core, Biencoder) + and model_core.pooling == "splade" + and self.config.splade_reg_weight > 0 + ): + # No-op by default; legacy users can still get old behavior if they set this. + reg_q = query_embeddings.abs().sum(dim=1).mean() + reg_d = document_embeddings.abs().sum(dim=1).mean() + loss_reg = self.config.splade_reg_weight * (reg_q + reg_d) + loss = loss + loss_reg + + # SPLADE v2 FLOPs regularization with separate query and document terms. + if isinstance(model_core, Biencoder) and model_core.pooling == "splade": + if self.config.splade_flops_weight_query > 0: + flops_q = model_core.compute_splade_flops_query_cached() + loss = loss + self.config.splade_flops_weight_query * flops_q + if self.config.splade_flops_weight_doc > 0: + flops_d = model_core.compute_splade_flops_doc_cached() + loss = loss + self.config.splade_flops_weight_doc * flops_d # Weights and Biases logging. # NOTE: We log more than is feasible to do in a callback, so we do it here @@ -250,9 +318,37 @@ def loss(self, batch: ContrastiveLearningBatch) -> Tensor: "train/batch_size_doc": global_batch_size_doc, "train/loss_no_truncate": loss_base.item(), } + # SPLADE average active terms per query/doc for this step. + if isinstance(model_core, Biencoder) and model_core.pooling == "splade": + thr = self.config.splade_nnz_threshold + metrics["train/avg_terms_query"] = ( + (query_embeddings > thr).to(torch.float32).sum(dim=1).mean().item() + ) + metrics["train/avg_terms_doc"] = ( + (document_embeddings > thr).to(torch.float32).sum(dim=1).mean().item() + ) if loss_truncated is not None: truncated_loss_name = f"train/loss_truncate_{self.config.mrl_dim}" metrics[truncated_loss_name] = loss_truncated.item() + if isinstance(model_core, Biencoder) and model_core.pooling == "splade": + if self.config.splade_reg_weight > 0: + metrics["train/splade_l1_loss"] = loss_reg.item() + metrics["train/splade_l1_loss_weight"] = self.config.splade_reg_weight + else: + metrics["train/splade_l1_loss"] = 0.0 + metrics["train/splade_l1_loss_weight"] = 0.0 + if self.config.splade_flops_weight_query > 0: + metrics["train/splade_query_flops_loss"] = flops_q.item() + metrics["train/splade_query_flops_loss_weight"] = self.config.splade_flops_weight_query + else: + metrics["train/splade_query_flops_loss"] = 0.0 + metrics["train/splade_query_flops_loss_weight"] = 0.0 + if self.config.splade_flops_weight_doc > 0: + metrics["train/splade_doc_flops_loss"] = flops_d.item() + metrics["train/splade_doc_flops_loss_weight"] = self.config.splade_flops_weight_doc + else: + metrics["train/splade_doc_flops_loss"] = 0.0 + metrics["train/splade_doc_flops_loss_weight"] = 0.0 wandb.log(metrics, step=self.global_step) return loss