From cb70bb02b5bb2a70dc8bf2c90763e57b6254e9b2 Mon Sep 17 00:00:00 2001 From: Puxuan Yu Date: Mon, 30 Jun 2025 14:44:30 -0700 Subject: [PATCH 01/11] patch arctic embed finetuning examples --- .../examples/finetune_models/finetune_e5_base_unsupervised.py | 2 +- .../arctic_embed/src/arctic_embed/biencoder_model_factory.py | 2 +- projects/arctic_embed/src/arctic_embed/core/biencoder_model.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/projects/arctic_embed/examples/finetune_models/finetune_e5_base_unsupervised.py b/projects/arctic_embed/examples/finetune_models/finetune_e5_base_unsupervised.py index ff246a43..f335f309 100644 --- a/projects/arctic_embed/examples/finetune_models/finetune_e5_base_unsupervised.py +++ b/projects/arctic_embed/examples/finetune_models/finetune_e5_base_unsupervised.py @@ -72,7 +72,7 @@ def now_timestamp_str() -> str: eval_max_seq_length_doc=512, eval_max_seq_length_query=512, ) -sconf = WSDSchedulerConfig(num_warmup_steps=500, num_decay_steps=1_000, learning_rate=LEARNING_RATE) +sconf = WSDSchedulerConfig(num_warmup_steps=500, num_decay_steps=1_000) oconf = OptimizerConfig(weight_decay=0.01, learning_rate=LEARNING_RATE) lconf = LoggerConfig(level="INFO") wconf = WandBConfig( diff --git a/projects/arctic_embed/src/arctic_embed/biencoder_model_factory.py b/projects/arctic_embed/src/arctic_embed/biencoder_model_factory.py index e5bff672..80c2db87 100644 --- a/projects/arctic_embed/src/arctic_embed/biencoder_model_factory.py +++ b/projects/arctic_embed/src/arctic_embed/biencoder_model_factory.py @@ -59,7 +59,7 @@ def create_model(self, model_config: AutoConfig) -> Biencoder: self.config.name_or_path, config=model_config, attn_implementation=self.config.attn_implementation, - torch_dtype=self.config.dtype, + torch_dtype=self.config.dtype.value, trust_remote_code=trust_remote_code, ) return Biencoder(encoder, pooling=arctic_training_model_config.pooling) diff --git a/projects/arctic_embed/src/arctic_embed/core/biencoder_model.py b/projects/arctic_embed/src/arctic_embed/core/biencoder_model.py index 6ef38f81..a0464a86 100644 --- a/projects/arctic_embed/src/arctic_embed/core/biencoder_model.py +++ b/projects/arctic_embed/src/arctic_embed/core/biencoder_model.py @@ -41,6 +41,7 @@ def __init__(self, encoder: PreTrainedModel, pooling: PoolingOption = "first_tok super().__init__() self.encoder = encoder self.pooling = pooling + self.config = encoder.config def encode(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: out = self.encoder(input_ids=input_ids, attention_mask=attention_mask) From 694799a9e47722e352eec6f1c3c053154474d318 Mon Sep 17 00:00:00 2001 From: Puxuan Yu Date: Mon, 30 Jun 2025 15:12:46 -0700 Subject: [PATCH 02/11] update instructions on data downloading --- projects/arctic_embed/examples/finetune_models/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/arctic_embed/examples/finetune_models/README.md b/projects/arctic_embed/examples/finetune_models/README.md index a936b82d..425ee068 100644 --- a/projects/arctic_embed/examples/finetune_models/README.md +++ b/projects/arctic_embed/examples/finetune_models/README.md @@ -17,7 +17,7 @@ GIT_LFS_SKIP_SMUDGE=1 git clone https://hf.co/datasets/Snowflake/arctic-embed-ft mv ./data.gitignore ./data/.gitignore # Ensure we have all the files you need for training downloaded from LFS. -cd arctic-embed-ft-v1/ +cd data/ git lfs pull --include="combined/pretokenized/example_dot95/,eval/" # Optional: Download more large files (e.g. everything but the very large precomputed embeddings). From 8583ebfcef393321aaca01cd83576a9ac6286299 Mon Sep 17 00:00:00 2001 From: Puxuan Yu Date: Mon, 7 Jul 2025 13:33:05 -0700 Subject: [PATCH 03/11] make metrics calculation compatible with arctic embed training --- arctic_training/trainer/trainer.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 6b497b47..7007e18a 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -400,10 +400,20 @@ def epoch(self) -> None: # deal correctly with packed samples under FA2, by calculating each seqlen tflos separately sample_seqlens = batch.pop("packed_sample_seqlens") else: - sample_seqlens = [ - [len(batch["input_ids"][idx]) * self.config.sequence_parallel_size] - for idx in range(len(batch["input_ids"])) - ] + if "input_ids" not in batch: + # batch is a ContrastiveLearningBatch + sample_seqlens = [ + [ + (len(batch.query_tokens[idx]) + len(batch.document_tokens[idx])) + * self.config.sequence_parallel_size + ] + for idx in range(len(batch.query_tokens)) + ] + else: + sample_seqlens = [ + [len(batch["input_ids"][idx]) * self.config.sequence_parallel_size] + for idx in range(len(batch["input_ids"])) + ] self.metrics.seqlens = sample_seqlens self.metrics.start_timer("step") From 9ba3575b648a9ea0ed0003f54e1c5c04ba83024f Mon Sep 17 00:00:00 2001 From: Puxuan Yu Date: Wed, 23 Jul 2025 10:56:56 -0700 Subject: [PATCH 04/11] fix pretraining errors --- .../finetune_models/pretrain_bge_retromae.py | 181 ++++++++++++++++++ .../arctic_embed/src/arctic_embed/trainer.py | 4 + 2 files changed, 185 insertions(+) create mode 100644 projects/arctic_embed/examples/finetune_models/pretrain_bge_retromae.py diff --git a/projects/arctic_embed/examples/finetune_models/pretrain_bge_retromae.py b/projects/arctic_embed/examples/finetune_models/pretrain_bge_retromae.py new file mode 100644 index 00000000..07b09bbc --- /dev/null +++ b/projects/arctic_embed/examples/finetune_models/pretrain_bge_retromae.py @@ -0,0 +1,181 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This example shows how to use the Arctic Embed codebase to finetune +the venerable E5-base-v2 model (released in May 2023) on a version of MSMARCO +training data which has been hard-negative-mined using a more modern technique. + +The code needed to recreate the training data can be found in the sibling directory +`data_prep` within the `hard_negative_mining` subdirectory. + +Original model paper: https://arxiv.org/abs/2212.03533 +Model page: https://huggingface.co/intfloat/e5-base-v2 +Better negative mining paper: https://arxiv.org/abs/2407.15831 +""" +import sys +from datetime import datetime +from datetime import timezone +from pathlib import Path + +from arctic_embed.biencoder_model_factory import BiencoderModelConfig +from arctic_embed.contrastive_dataloader import ContrastivePretokenizedDataConfig +from arctic_embed.core.cuda_allocator_config import CUDA_ALLOCATOR_CONFIG_FOR_DYNAMICALLY_SIZED_DATA +from arctic_embed.trainer import BiencoderTrainer +from arctic_embed.trainer import BiencoderTrainerConfig + +from arctic_training.config.checkpoint import CheckpointConfig +from arctic_training.config.logger import LoggerConfig +from arctic_training.config.optimizer import OptimizerConfig +from arctic_training.config.wandb import WandBConfig +from arctic_training.scheduler.wsd_factory import WSDSchedulerConfig + +LEARNING_RATE = 3e-5 +GRADIENT_CLIPPING = 10.0 +# DATA_PATH = str(Path(__file__).parent / "data" / "pretrain_amazonqa" / "batched_16384") +DATA_PATH = "s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/combined_all_16384" +# EVAL_DATA_PATHS = [str(path) for path in (Path(__file__).parent / "data" / "eval").iterdir() if path.is_dir()] # fix this +datasets = [ + "amazon_qa", + "ccnews_de_v1", + "ccnews_en_v1", + "ccnews_es_v1", + "ccnews_fr_v1", + "ccnews_it_v1", + "ccnews_pl_v1", + "ccnews_pt_v1", + "faq", + "mc4_de_v1", + "mc4_en_v1", + "mc4_es_v1", + "mc4_fr_v1", + "mc4_it_v1", + "mc4_pl_v1", + "mc4_pt_v1", + "mwiki_de_v1", + "mwiki_en_v1", + "mwiki_es_v1", + "mwiki_fr_v1", + "mwiki_it_v1", + "mwiki_pl_v1", + "mwiki_pt_v1", + "paq", + "pes2o", + "red_pajama", + "red_pajamas_1t_stackexchange", + "s2orc_title_abstracts", + "snippets4", + "techrepo", + "top_stories", + "trivia_qa", + "wikipedia", +] +EVAL_DATA_PATHS = [f"s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/combined_all_16384_eval/{dataset}" for dataset in datasets] + + +def now_timestamp_str() -> str: + """Get the current ISO 8601 UTC timestamp.""" + return datetime.now(timezone.utc).strftime(r"%Y%m%dT%H%M%SZ") + + +ts = now_timestamp_str() +checkpoint_dir = Path(__file__).parent / "checkpoints" / "pretrain_bge_retromae" / ts +mconf = BiencoderModelConfig(name_or_path="BAAI/bge-m3-retromae", pooling="first_token", kwargs={"trust_remote_code": True}) +dconf = ContrastivePretokenizedDataConfig( + filesystem="s3", + root_directory=DATA_PATH, + # filesystem="local", + # root_directory=DATA_PATH, + # Depending on how much GPU memory you have, you may need to split each + # batch into a number of smaller sub-batches by setting the split_factor. + # If you do so, you will probably want to decrease the learning rate accordingly. + # split_factor=4, + max_seq_length_query=32, + max_seq_length_doc=256, + eval_root_directories=EVAL_DATA_PATHS, + eval_max_seq_length_doc=32, + eval_max_seq_length_query=256, +) +sconf = WSDSchedulerConfig(num_warmup_steps=2000, num_decay_steps=2000) +oconf = OptimizerConfig(weight_decay=0.01, learning_rate=LEARNING_RATE) +lconf = LoggerConfig(level="INFO") +wconf = WandBConfig( + enable=True, + project="arctic-training-arctic-embed-testbed", + name=f"bge-m3-retromae-pretrain-{ts}", +) +# Reference: https://www.deepspeed.ai/training/#gradient-clipping +dsconf = { + "gradient_clipping": GRADIENT_CLIPPING, + "zero_optimization": {"stage": 1}, + # NOTE: The underlying DeepSpeed engine scales gradients down by a factor of + # `1/world_size`` in the backwards pass, so we pre-scale the loss up by a factor + # of `world_size`. Given these scalings, there is a potential for increased + # numerical imprecision when using low-precision floating point representation, + # so we set communication to fp32 in the backwards all-reduce to somewhat mitigate + # this risk. + "communication_data_type": "fp32", +} +cconf = CheckpointConfig( + output_dir=checkpoint_dir, + type="biencoder", + save_every_n_steps=300, + save_end_of_training=True, +) + + +def configure_non_distributed_distributed_training_if_needed() -> None: + """Detect if we need to manually initialize distributed training environment + and do so if needed. + + NOTE: We have to do this step because Arctic Training doesn't have a default + 1-GPU launching mode and will instead fall back to trying to auto-discover + distributed training configuration (e.g. via MPI). + """ + num_cli_args = len(sys.argv) - 1 + if num_cli_args == 0: + print("***No CLI args detected, configuring for single-GPU training.***") + from os import environ + + from torch import distributed as dist + + environ["MASTER_ADDR"] = "localhost" + environ["MASTER_PORT"] = "12335" + environ["LOCAL_RANK"] = "0" + dist.init_process_group(backend="nccl", world_size=1, rank=0) + + +if __name__ == "__main__": + CUDA_ALLOCATOR_CONFIG_FOR_DYNAMICALLY_SIZED_DATA.set_env() + configure_non_distributed_distributed_training_if_needed() + tconf = BiencoderTrainerConfig( + type="biencoder", + model=mconf, + data=dconf, + scheduler=sconf, + optimizer=oconf, + logger=lconf, + checkpoint=cconf, + wandb=wconf, + deepspeed=dsconf, + loss_log_interval=0, + eval_frequency=10, + use_in_batch_negatives=True, + loss_temperature=0.02, + overfit_first_batch=False, + mrl_dim=256 + ) + trainer = BiencoderTrainer(config=tconf) + trainer.train() diff --git a/projects/arctic_embed/src/arctic_embed/trainer.py b/projects/arctic_embed/src/arctic_embed/trainer.py index 6542df12..beac66ff 100644 --- a/projects/arctic_embed/src/arctic_embed/trainer.py +++ b/projects/arctic_embed/src/arctic_embed/trainer.py @@ -127,7 +127,9 @@ def eval_and_log_cb(self: BiencoderTrainer) -> None: for eval_batch in tqdm(eval_loader, desc=f"eval/{eval_name}", unit="batch"): eval_metrics = self.eval(eval_batch) em_list.append(eval_metrics) + print(f"inside eval_and_log_cb, em_list: {em_list}") avg_metrics = {f"eval/{eval_name}/{k}": sum(em[k] for em in em_list) / len(em_list) for k in eval_metrics} + print(f"inside eval_and_log_cb, avg_metrics: {avg_metrics}") metrics.update(avg_metrics) finally: self.model.train(mode=initial_train_mode) @@ -208,6 +210,8 @@ def forward_and_gather(self, batch: ContrastiveLearningBatch) -> Tuple[Tensor, T @torch.no_grad() def eval(self, batch: ContrastiveLearningBatch) -> Dict[str, float]: query_embeddings, document_embeddings, relations = self.forward_and_gather(batch) + if self.config.use_in_batch_negatives: + relations[relations == 0] = -1 q_emb = F.normalize(query_embeddings, dim=1) d_emb = F.normalize(document_embeddings, dim=1) scores = torch.matmul(q_emb, d_emb.transpose(0, 1)) From 91cc631abdb64d4c5d509d7d385e7219d0dc50b4 Mon Sep 17 00:00:00 2001 From: Puxuan Yu Date: Mon, 4 Aug 2025 13:34:11 -0700 Subject: [PATCH 05/11] better customization for different base models --- arctic_training/trainer/trainer.py | 7 ++----- .../finetune_models/pretrain_bge_retromae.py | 20 +++++++++++++++---- .../arctic_embed/contrastive_dataloader.py | 6 ++++++ .../arctic_embed/src/arctic_embed/trainer.py | 2 -- 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 7007e18a..21f927b2 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -403,11 +403,8 @@ def epoch(self) -> None: if "input_ids" not in batch: # batch is a ContrastiveLearningBatch sample_seqlens = [ - [ - (len(batch.query_tokens[idx]) + len(batch.document_tokens[idx])) - * self.config.sequence_parallel_size - ] - for idx in range(len(batch.query_tokens)) + [len(batch.query_tokens[idx]) * self.config.sequence_parallel_size] + for idx in range(batch.query_tokens.shape[0]) ] else: sample_seqlens = [ diff --git a/projects/arctic_embed/examples/finetune_models/pretrain_bge_retromae.py b/projects/arctic_embed/examples/finetune_models/pretrain_bge_retromae.py index 07b09bbc..0e8d113e 100644 --- a/projects/arctic_embed/examples/finetune_models/pretrain_bge_retromae.py +++ b/projects/arctic_embed/examples/finetune_models/pretrain_bge_retromae.py @@ -45,7 +45,9 @@ LEARNING_RATE = 3e-5 GRADIENT_CLIPPING = 10.0 # DATA_PATH = str(Path(__file__).parent / "data" / "pretrain_amazonqa" / "batched_16384") -DATA_PATH = "s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/combined_all_16384" +DATA_PATH = ( + "s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/combined_all_16384" +) # EVAL_DATA_PATHS = [str(path) for path in (Path(__file__).parent / "data" / "eval").iterdir() if path.is_dir()] # fix this datasets = [ "amazon_qa", @@ -82,7 +84,15 @@ "trivia_qa", "wikipedia", ] -EVAL_DATA_PATHS = [f"s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/combined_all_16384_eval/{dataset}" for dataset in datasets] +EVAL_DATA_PATHS = [ + f"s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/combined_all_16384_eval/{dataset}" + for dataset in datasets +] +# from transformers import AutoTokenizer +# tok = AutoTokenizer.from_pretrained("BAAI/bge-m3-retromae") +# tok.pad_token_id --> 1 +PAD_VALUE = 1 +LEFT_PAD = False def now_timestamp_str() -> str: @@ -92,7 +102,9 @@ def now_timestamp_str() -> str: ts = now_timestamp_str() checkpoint_dir = Path(__file__).parent / "checkpoints" / "pretrain_bge_retromae" / ts -mconf = BiencoderModelConfig(name_or_path="BAAI/bge-m3-retromae", pooling="first_token", kwargs={"trust_remote_code": True}) +mconf = BiencoderModelConfig( + name_or_path="BAAI/bge-m3-retromae", pooling="first_token", kwargs={"trust_remote_code": True} +) dconf = ContrastivePretokenizedDataConfig( filesystem="s3", root_directory=DATA_PATH, @@ -175,7 +187,7 @@ def configure_non_distributed_distributed_training_if_needed() -> None: use_in_batch_negatives=True, loss_temperature=0.02, overfit_first_batch=False, - mrl_dim=256 + mrl_dim=256, ) trainer = BiencoderTrainer(config=tconf) trainer.train() diff --git a/projects/arctic_embed/src/arctic_embed/contrastive_dataloader.py b/projects/arctic_embed/src/arctic_embed/contrastive_dataloader.py index ebb58784..7a2a4148 100644 --- a/projects/arctic_embed/src/arctic_embed/contrastive_dataloader.py +++ b/projects/arctic_embed/src/arctic_embed/contrastive_dataloader.py @@ -39,6 +39,8 @@ class ContrastivePretokenizedDataConfig(DataConfig): filesystem: FilesystemOption root_directory: str split_factor: int = 1 + pad_value: int = 0 + left_pad: bool = False sources: List[DataSourceConfig] = [] max_seq_length_query: Optional[int] = None max_seq_length_doc: Optional[int] = None @@ -62,6 +64,8 @@ def __call__(self) -> Tuple[DataLoader, Optional[Dict[str, DataLoader]]]: split_factor=self.config.split_factor, shard_id=self.global_rank, world_size=self.world_size, + pad_value=self.config.pad_value, + left_pad=self.config.left_pad, max_seq_len_query=self.config.max_seq_length_query, max_seq_len_doc=self.config.max_seq_length_doc, device=self.trainer.device, @@ -79,6 +83,8 @@ def __call__(self) -> Tuple[DataLoader, Optional[Dict[str, DataLoader]]]: split_factor=self.config.eval_split_factor, shard_id=self.global_rank, world_size=self.world_size, + pad_value=self.config.pad_value, + left_pad=self.config.left_pad, max_seq_len_query=self.config.eval_max_seq_length_query, max_seq_len_doc=self.config.eval_max_seq_length_doc, device=self.trainer.device, diff --git a/projects/arctic_embed/src/arctic_embed/trainer.py b/projects/arctic_embed/src/arctic_embed/trainer.py index beac66ff..54a825e2 100644 --- a/projects/arctic_embed/src/arctic_embed/trainer.py +++ b/projects/arctic_embed/src/arctic_embed/trainer.py @@ -127,9 +127,7 @@ def eval_and_log_cb(self: BiencoderTrainer) -> None: for eval_batch in tqdm(eval_loader, desc=f"eval/{eval_name}", unit="batch"): eval_metrics = self.eval(eval_batch) em_list.append(eval_metrics) - print(f"inside eval_and_log_cb, em_list: {em_list}") avg_metrics = {f"eval/{eval_name}/{k}": sum(em[k] for em in em_list) / len(em_list) for k in eval_metrics} - print(f"inside eval_and_log_cb, avg_metrics: {avg_metrics}") metrics.update(avg_metrics) finally: self.model.train(mode=initial_train_mode) From d2094d3728253d78663d9be0714c9bdf4db42b9a Mon Sep 17 00:00:00 2001 From: Puxuan Yu Date: Mon, 4 Aug 2025 13:37:51 -0700 Subject: [PATCH 06/11] better customization for different base models --- .../examples/finetune_models/pretrain_bge_retromae.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/projects/arctic_embed/examples/finetune_models/pretrain_bge_retromae.py b/projects/arctic_embed/examples/finetune_models/pretrain_bge_retromae.py index 0e8d113e..3fb9418e 100644 --- a/projects/arctic_embed/examples/finetune_models/pretrain_bge_retromae.py +++ b/projects/arctic_embed/examples/finetune_models/pretrain_bge_retromae.py @@ -119,6 +119,8 @@ def now_timestamp_str() -> str: eval_root_directories=EVAL_DATA_PATHS, eval_max_seq_length_doc=32, eval_max_seq_length_query=256, + pad_value=PAD_VALUE, + left_pad=LEFT_PAD, ) sconf = WSDSchedulerConfig(num_warmup_steps=2000, num_decay_steps=2000) oconf = OptimizerConfig(weight_decay=0.01, learning_rate=LEARNING_RATE) From 128fd8d33a10511788551f57577dcae0e6af5547 Mon Sep 17 00:00:00 2001 From: Puxuan Yu Date: Mon, 25 Aug 2025 14:24:25 -0700 Subject: [PATCH 07/11] wip --- .../examples/finetune_models/pretrain_mgte.py | 205 ++++++++++++++++++ .../arctic_embed/biencoder_model_factory.py | 24 +- .../src/arctic_embed/core/biencoder_model.py | 57 ++++- .../arctic_embed/src/arctic_embed/trainer.py | 69 +++++- 4 files changed, 335 insertions(+), 20 deletions(-) create mode 100644 projects/arctic_embed/examples/finetune_models/pretrain_mgte.py diff --git a/projects/arctic_embed/examples/finetune_models/pretrain_mgte.py b/projects/arctic_embed/examples/finetune_models/pretrain_mgte.py new file mode 100644 index 00000000..58723f1c --- /dev/null +++ b/projects/arctic_embed/examples/finetune_models/pretrain_mgte.py @@ -0,0 +1,205 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This example shows how to use the Arctic Embed codebase to finetune +the venerable E5-base-v2 model (released in May 2023) on a version of MSMARCO +training data which has been hard-negative-mined using a more modern technique. + +The code needed to recreate the training data can be found in the sibling directory +`data_prep` within the `hard_negative_mining` subdirectory. + +Original model paper: https://arxiv.org/abs/2212.03533 +Model page: https://huggingface.co/intfloat/e5-base-v2 +Better negative mining paper: https://arxiv.org/abs/2407.15831 +""" +import sys +from datetime import datetime +from datetime import timezone +from pathlib import Path +import os + +from arctic_embed.biencoder_model_factory import BiencoderModelConfig +from arctic_embed.contrastive_dataloader import ContrastivePretokenizedDataConfig +from arctic_embed.core.cuda_allocator_config import CUDA_ALLOCATOR_CONFIG_FOR_DYNAMICALLY_SIZED_DATA +from arctic_embed.trainer import BiencoderTrainer +from arctic_embed.trainer import BiencoderTrainerConfig + +from arctic_training.config.checkpoint import CheckpointConfig +from arctic_training.config.logger import LoggerConfig +from arctic_training.config.optimizer import OptimizerConfig +from arctic_training.config.wandb import WandBConfig +from arctic_training.scheduler.wsd_factory import WSDSchedulerConfig + +LEARNING_RATE = 1e-4 +GRADIENT_CLIPPING = 10.0 +# DATA_PATH = str(Path(__file__).parent / "data" / "pretrain_amazonqa" / "batched_16384") +DATA_PATH = ( + "s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/Alibaba_NLP_gte_multilingual_base/combined_all_32768" +) +# EVAL_DATA_PATHS = [str(path) for path in (Path(__file__).parent / "data" / "eval").iterdir() if path.is_dir()] # fix this +datasets = [ + "amazon_qa", + "ccnews_de_v1", + "ccnews_en_v1", + "ccnews_es_v1", + "ccnews_fr_v1", + "ccnews_it_v1", + "ccnews_pl_v1", + "ccnews_pt_v1", + "faq", + "mc4_de_v1", + "mc4_en_v1", + "mc4_es_v1", + "mc4_fr_v1", + "mc4_it_v1", + "mc4_pl_v1", + "mc4_pt_v1", + "mwiki_de_v1", + "mwiki_en_v1", + "mwiki_es_v1", + "mwiki_fr_v1", + "mwiki_it_v1", + "mwiki_pl_v1", + "mwiki_pt_v1", + "paq", + "pes2o", + "red_pajama", + "red_pajamas_1t_stackexchange", + "s2orc_title_abstracts", + "snippets4", + "techrepo", + "top_stories", + "trivia_qa", + "wikipedia", +] +EVAL_DATA_PATHS = [ + f"s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/Alibaba_NLP_gte_multilingual_base/combined_all_32768_eval/{dataset}" + for dataset in datasets +] +# from transformers import AutoTokenizer +# tok = AutoTokenizer.from_pretrained("BAAI/bge-m3-retromae") +# tok.pad_token_id --> 1 +PAD_VALUE = 1 +LEFT_PAD = False + + +def now_timestamp_str() -> str: + """Get the current ISO 8601 UTC timestamp.""" + return datetime.now(timezone.utc).strftime(r"%Y%m%dT%H%M%SZ") + + +ts = now_timestamp_str() +checkpoint_dir = Path(__file__).parent / "checkpoints" / "pretrain_mgte" / ts +mconf = BiencoderModelConfig( + name_or_path="Alibaba-NLP/gte-multilingual-mlm-base", + pooling=os.getenv("BIENCODER_POOLING", "first_token"), + kwargs={ + "trust_remote_code": True, + "unpad_inputs": True, + "use_memory_efficient_attention": True, + }, +) +dconf = ContrastivePretokenizedDataConfig( + filesystem="s3", + root_directory=DATA_PATH, + # filesystem="local", + # root_directory=DATA_PATH, + # Depending on how much GPU memory you have, you may need to split each + # batch into a number of smaller sub-batches by setting the split_factor. + # If you do so, you will probably want to decrease the learning rate accordingly. + # split_factor=4, + max_seq_length_query=32, + max_seq_length_doc=256, + eval_root_directories=EVAL_DATA_PATHS, + eval_max_seq_length_doc=32, + eval_max_seq_length_query=256, + pad_value=PAD_VALUE, + left_pad=LEFT_PAD, +) +sconf = WSDSchedulerConfig(num_warmup_steps=2000, num_decay_steps=2000) +oconf = OptimizerConfig(weight_decay=0.01, learning_rate=LEARNING_RATE) +lconf = LoggerConfig(level="INFO") +wconf = WandBConfig( + enable=True, + project="arctic-training-arctic-embed-testbed", + name=f"mgte-pretrain-{ts}", +) +# Reference: https://www.deepspeed.ai/training/#gradient-clipping +dsconf = { + "gradient_clipping": GRADIENT_CLIPPING, + "zero_optimization": {"stage": 1}, + # NOTE: The underlying DeepSpeed engine scales gradients down by a factor of + # `1/world_size`` in the backwards pass, so we pre-scale the loss up by a factor + # of `world_size`. Given these scalings, there is a potential for increased + # numerical imprecision when using low-precision floating point representation, + # so we set communication to fp32 in the backwards all-reduce to somewhat mitigate + # this risk. + "communication_data_type": "fp32", +} +cconf = CheckpointConfig( + output_dir=checkpoint_dir, + type="biencoder", + save_every_n_steps=300, + save_end_of_training=True, +) + + +def configure_non_distributed_distributed_training_if_needed() -> None: + """Detect if we need to manually initialize distributed training environment + and do so if needed. + + NOTE: We have to do this step because Arctic Training doesn't have a default + 1-GPU launching mode and will instead fall back to trying to auto-discover + distributed training configuration (e.g. via MPI). + """ + num_cli_args = len(sys.argv) - 1 + if num_cli_args == 0: + print("***No CLI args detected, configuring for single-GPU training.***") + from os import environ + + from torch import distributed as dist + + environ["MASTER_ADDR"] = "localhost" + environ["MASTER_PORT"] = "12335" + environ["LOCAL_RANK"] = "0" + dist.init_process_group(backend="nccl", world_size=1, rank=0) + + +if __name__ == "__main__": + CUDA_ALLOCATOR_CONFIG_FOR_DYNAMICALLY_SIZED_DATA.set_env() + configure_non_distributed_distributed_training_if_needed() + tconf = BiencoderTrainerConfig( + type="biencoder", + model=mconf, + data=dconf, + scheduler=sconf, + optimizer=oconf, + logger=lconf, + checkpoint=cconf, + wandb=wconf, + deepspeed=dsconf, + loss_log_interval=0, + eval_frequency=300, + use_in_batch_negatives=True, + loss_temperature=0.02, + overfit_first_batch=False, + mrl_dim=256, + splade_reg_weight=float(os.getenv("SPLADE_REG_WEIGHT", "0.0")), + splade_flops_weight=float(os.getenv("SPLADE_FLOPS_WEIGHT", "0.0")), + splade_nnz_threshold=float(os.getenv("SPLADE_NNZ_THRESHOLD", "1e-3")), + ) + trainer = BiencoderTrainer(config=tconf) + trainer.train() diff --git a/projects/arctic_embed/src/arctic_embed/biencoder_model_factory.py b/projects/arctic_embed/src/arctic_embed/biencoder_model_factory.py index 80c2db87..98094422 100644 --- a/projects/arctic_embed/src/arctic_embed/biencoder_model_factory.py +++ b/projects/arctic_embed/src/arctic_embed/biencoder_model_factory.py @@ -21,6 +21,7 @@ from peft.config import PeftConfig from transformers import AutoConfig from transformers import AutoModel +from transformers import AutoModelForMaskedLM from arctic_training.config.model import ModelConfig from arctic_training.model.factory import ModelFactory @@ -55,13 +56,22 @@ def create_model(self, model_config: AutoConfig) -> Biencoder: arctic_training_model_config = self.config assert isinstance(arctic_training_model_config, BiencoderModelConfig) trust_remote_code = arctic_training_model_config.kwargs.get("trust_remote_code", None) - encoder = AutoModel.from_pretrained( - self.config.name_or_path, - config=model_config, - attn_implementation=self.config.attn_implementation, - torch_dtype=self.config.dtype.value, - trust_remote_code=trust_remote_code, - ) + if arctic_training_model_config.pooling == "splade": + encoder = AutoModelForMaskedLM.from_pretrained( + self.config.name_or_path, + config=model_config, + attn_implementation=self.config.attn_implementation, + torch_dtype=self.config.dtype.value, + trust_remote_code=trust_remote_code, + ) + else: + encoder = AutoModel.from_pretrained( + self.config.name_or_path, + config=model_config, + attn_implementation=self.config.attn_implementation, + torch_dtype=self.config.dtype.value, + trust_remote_code=trust_remote_code, + ) return Biencoder(encoder, pooling=arctic_training_model_config.pooling) def post_create_model_callback(self, model: Biencoder): diff --git a/projects/arctic_embed/src/arctic_embed/core/biencoder_model.py b/projects/arctic_embed/src/arctic_embed/core/biencoder_model.py index a0464a86..bb6868bb 100644 --- a/projects/arctic_embed/src/arctic_embed/core/biencoder_model.py +++ b/projects/arctic_embed/src/arctic_embed/core/biencoder_model.py @@ -29,7 +29,7 @@ from torch import nn from transformers import PreTrainedModel -PoolingOption = Literal["first_token", "last_token", "mean"] +PoolingOption = Literal["first_token", "last_token", "mean", "splade"] logger = logging.getLogger(__name__) @@ -45,6 +45,17 @@ def __init__(self, encoder: PreTrainedModel, pooling: PoolingOption = "first_tok def encode(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: out = self.encoder(input_ids=input_ids, attention_mask=attention_mask) + # SPLADE-style sparse pooling branches on logits instead of hidden states. + if self.pooling == "splade": + if not hasattr(out, "logits"): + raise ValueError( + f"Encoder of class {self.encoder.__class__} must output `logits` for SPLADE pooling." + ) + logits = out.logits # (batch, seq_len, vocab_size) + pooled = splade_pool(logits, attention_mask) + # NOTE: For SPLADE we avoid L2-normalization to preserve magnitude and sparsity. + return pooled.contiguous() + if not hasattr(out, "last_hidden_state"): raise ValueError( f"Encoder of class {self.encoder.__class__} is missing the " @@ -76,6 +87,29 @@ def forward( document_vectors = self.encode(document_input_ids, document_attention_mask) return query_vectors, document_vectors + def compute_splade_flops(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: + """Compute SPLADE FLOPs proxy regularizer for a batch. + + FLOPs proxy: sum_v log(1 + sum_t ReLU(logits_{t,v})) averaged over batch. + """ + if self.pooling != "splade": + raise ValueError("FLOPs regularizer requested but pooling is not 'splade'.") + out = self.encoder(input_ids=input_ids, attention_mask=attention_mask) + if not hasattr(out, "logits"): + raise ValueError( + f"Encoder of class {self.encoder.__class__} must output `logits` for SPLADE FLOPs computation." + ) + logits = out.logits # (batch, seq_len, vocab_size) + relu_logits = F.relu(logits) + if attention_mask.dtype != relu_logits.dtype: + attn = attention_mask.to(relu_logits.dtype) + else: + attn = attention_mask + relu_logits = relu_logits * attn[..., None] + token_sums = relu_logits.sum(dim=1) # (batch, vocab_size) + flops_per_example = torch.log1p(token_sums).sum(dim=1) # (batch) + return flops_per_example.mean() + def save_pretrained( self, save_directory: Union[str, os.PathLike], @@ -133,3 +167,24 @@ def last_token_pool(out: Tensor, attention_mask: Tensor) -> Tensor: row = torch.arange(batch_size, device=out.device) col = attention_mask.sum(dim=1) - 1 # position of the last non-padding token return out[row, col, ...] + + +def splade_pool(logits: Tensor, attention_mask: Tensor) -> Tensor: + """SPLADE-style sparse pooling over vocabulary logits. + + - Expects `logits` with shape (batch, token, vocab_size) + - Applies activation: log(1 + relu(logits)) + - Aggregates over tokens with masked max + """ + assert logits.ndim == 3 + assert attention_mask.ndim == 2 + # Activation as in SPLADE + activated = torch.log1p(F.relu(logits)) + # Mask out padding positions so they do not affect max + mask = ~attention_mask[..., None].bool() + activated = activated.masked_fill(mask, float("-inf")) + # Max over sequence dimension + pooled = torch.amax(activated, dim=1) + # Replace -inf (for fully-masked sequences) with zeros to avoid nans downstream + pooled = torch.where(torch.isinf(pooled), torch.zeros_like(pooled), pooled) + return pooled diff --git a/projects/arctic_embed/src/arctic_embed/trainer.py b/projects/arctic_embed/src/arctic_embed/trainer.py index 54a825e2..46cc6f2c 100644 --- a/projects/arctic_embed/src/arctic_embed/trainer.py +++ b/projects/arctic_embed/src/arctic_embed/trainer.py @@ -57,6 +57,9 @@ class BiencoderTrainerConfig(TrainerConfig): data: ContrastivePretokenizedDataConfig mrl_dim: Optional[int] = None eval_interval: Optional[int] = None + splade_reg_weight: float = 0.0 + splade_flops_weight: float = 0.0 + splade_nnz_threshold: float = 1e-3 class FakeTokenizer: @@ -210,11 +213,20 @@ def eval(self, batch: ContrastiveLearningBatch) -> Dict[str, float]: query_embeddings, document_embeddings, relations = self.forward_and_gather(batch) if self.config.use_in_batch_negatives: relations[relations == 0] = -1 - q_emb = F.normalize(query_embeddings, dim=1) - d_emb = F.normalize(document_embeddings, dim=1) - scores = torch.matmul(q_emb, d_emb.transpose(0, 1)) + if isinstance(self.model, Biencoder) and self.model.pooling == "splade": + scores = torch.matmul(query_embeddings, document_embeddings.transpose(0, 1)) + else: + q_emb = F.normalize(query_embeddings, dim=1) + d_emb = F.normalize(document_embeddings, dim=1) + scores = torch.matmul(q_emb, d_emb.transpose(0, 1)) loss_infonce = info_nce_loss(scores, relations=relations, temperature=self.config.loss_temperature).item() - return {"infoNCE": loss_infonce} + metrics = {"infoNCE": loss_infonce} + if isinstance(self.model, Biencoder) and self.model.pooling == "splade": + thr = self.config.splade_nnz_threshold + avg_terms_q = (query_embeddings > thr).to(torch.float32).sum(dim=1).mean().item() + avg_terms_d = (document_embeddings > thr).to(torch.float32).sum(dim=1).mean().item() + metrics.update({"avg_terms_query": avg_terms_q, "avg_terms_doc": avg_terms_d}) + return metrics def loss(self, batch: ContrastiveLearningBatch) -> Tensor: # Count the number of queries and documents seen in this batch. @@ -226,16 +238,36 @@ def loss(self, batch: ContrastiveLearningBatch) -> Tensor: # Forward pass, gathering embeddings so each GPU has the full picture. query_embeddings, document_embeddings, relations = self.forward_and_gather(batch) - # InfoNCE loss with Matryoshka Representation Learning (MRL). + # InfoNCE loss (SPLADE uses raw dot product; dense uses MRL path). if self.config.use_in_batch_negatives: relations[relations == 0] = -1 - loss, loss_base, loss_truncated = one_size_truncated_mrl_info_nce_loss( - query_embeddings=query_embeddings, - document_embeddings=document_embeddings, - relations=relations, - truncated_dimension=self.config.mrl_dim, - temperature=self.config.loss_temperature, - ) + if isinstance(self.model, Biencoder) and self.model.pooling == "splade": + scores = torch.matmul(query_embeddings, document_embeddings.transpose(0, 1)) + loss = info_nce_loss(scores, relations=relations, temperature=self.config.loss_temperature) + loss_base = loss + loss_truncated = None + else: + loss, loss_base, loss_truncated = one_size_truncated_mrl_info_nce_loss( + query_embeddings=query_embeddings, + document_embeddings=document_embeddings, + relations=relations, + truncated_dimension=self.config.mrl_dim, + temperature=self.config.loss_temperature, + ) + + # Optional SPLADE-style L1 sparsity regularization on pooled embeddings. + if isinstance(self.model, Biencoder) and self.model.pooling == "splade" and self.config.splade_reg_weight > 0: + reg_q = query_embeddings.abs().sum(dim=1).mean() + reg_d = document_embeddings.abs().sum(dim=1).mean() + reg = self.config.splade_reg_weight * (reg_q + reg_d) + loss = loss + reg + + # Optional SPLADE-style FLOPs regularization (proxy for number of active terms). + if isinstance(self.model, Biencoder) and self.model.pooling == "splade" and self.config.splade_flops_weight > 0: + flops_q = self.model.compute_splade_flops(batch.query_tokens, batch.query_attention_mask) + flops_d = self.model.compute_splade_flops(batch.document_tokens, batch.document_attention_mask) + flops_reg = self.config.splade_flops_weight * (flops_q + flops_d) + loss = loss + flops_reg # Weights and Biases logging. # NOTE: We log more than is feasible to do in a callback, so we do it here @@ -252,9 +284,22 @@ def loss(self, batch: ContrastiveLearningBatch) -> Tensor: "train/batch_size_doc": global_batch_size_doc, "train/loss_no_truncate": loss_base.item(), } + # SPLADE average active terms per query/doc for this step. + if isinstance(self.model, Biencoder) and self.model.pooling == "splade": + thr = self.config.splade_nnz_threshold + metrics["train/avg_terms_query"] = ( + (query_embeddings > thr).to(torch.float32).sum(dim=1).mean().item() + ) + metrics["train/avg_terms_doc"] = ( + (document_embeddings > thr).to(torch.float32).sum(dim=1).mean().item() + ) if loss_truncated is not None: truncated_loss_name = f"train/loss_truncate_{self.config.mrl_dim}" metrics[truncated_loss_name] = loss_truncated.item() + if isinstance(self.model, Biencoder) and self.model.pooling == "splade" and self.config.splade_reg_weight > 0: + metrics["train/splade_reg_weight"] = self.config.splade_reg_weight + if isinstance(self.model, Biencoder) and self.model.pooling == "splade" and self.config.splade_flops_weight > 0: + metrics["train/splade_flops_weight"] = self.config.splade_flops_weight wandb.log(metrics, step=self.global_step) return loss From 21f0204caad7bbbd339c75d2b73a597345bc86ad Mon Sep 17 00:00:00 2001 From: Puxuan Yu Date: Wed, 10 Sep 2025 14:38:01 -0700 Subject: [PATCH 08/11] wip --- .../{pretrain_mgte.py => pretrain_splade.py} | 111 +++++++++--------- .../src/arctic_embed/core/biencoder_model.py | 76 ++++++++---- .../arctic_embed/src/arctic_embed/trainer.py | 91 ++++++++++---- 3 files changed, 183 insertions(+), 95 deletions(-) rename projects/arctic_embed/examples/finetune_models/{pretrain_mgte.py => pretrain_splade.py} (76%) diff --git a/projects/arctic_embed/examples/finetune_models/pretrain_mgte.py b/projects/arctic_embed/examples/finetune_models/pretrain_splade.py similarity index 76% rename from projects/arctic_embed/examples/finetune_models/pretrain_mgte.py rename to projects/arctic_embed/examples/finetune_models/pretrain_splade.py index 58723f1c..6fdaaeea 100644 --- a/projects/arctic_embed/examples/finetune_models/pretrain_mgte.py +++ b/projects/arctic_embed/examples/finetune_models/pretrain_splade.py @@ -45,55 +45,55 @@ LEARNING_RATE = 1e-4 GRADIENT_CLIPPING = 10.0 +# TODO: need to find a proper English only model, and tokenizer & batch only English pretraining data # DATA_PATH = str(Path(__file__).parent / "data" / "pretrain_amazonqa" / "batched_16384") -DATA_PATH = ( - "s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/Alibaba_NLP_gte_multilingual_base/combined_all_32768" -) -# EVAL_DATA_PATHS = [str(path) for path in (Path(__file__).parent / "data" / "eval").iterdir() if path.is_dir()] # fix this -datasets = [ - "amazon_qa", - "ccnews_de_v1", - "ccnews_en_v1", - "ccnews_es_v1", - "ccnews_fr_v1", - "ccnews_it_v1", - "ccnews_pl_v1", - "ccnews_pt_v1", - "faq", - "mc4_de_v1", - "mc4_en_v1", - "mc4_es_v1", - "mc4_fr_v1", - "mc4_it_v1", - "mc4_pl_v1", - "mc4_pt_v1", - "mwiki_de_v1", - "mwiki_en_v1", - "mwiki_es_v1", - "mwiki_fr_v1", - "mwiki_it_v1", - "mwiki_pl_v1", - "mwiki_pt_v1", - "paq", - "pes2o", - "red_pajama", - "red_pajamas_1t_stackexchange", - "s2orc_title_abstracts", - "snippets4", - "techrepo", - "top_stories", - "trivia_qa", - "wikipedia", -] -EVAL_DATA_PATHS = [ - f"s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/Alibaba_NLP_gte_multilingual_base/combined_all_32768_eval/{dataset}" - for dataset in datasets -] -# from transformers import AutoTokenizer -# tok = AutoTokenizer.from_pretrained("BAAI/bge-m3-retromae") -# tok.pad_token_id --> 1 -PAD_VALUE = 1 -LEFT_PAD = False +# DATA_PATH = ( +# "s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/Alibaba_NLP_gte_multilingual_base/combined_all_16384" +# ) +# datasets = [ +# "amazon_qa", +# # "ccnews_de_v1", +# "ccnews_en_v1", +# # "ccnews_es_v1", +# # "ccnews_fr_v1", +# # "ccnews_it_v1", +# # "ccnews_pl_v1", +# # "ccnews_pt_v1", +# "faq", +# # "mc4_de_v1", +# "mc4_en_v1", +# # "mc4_es_v1", +# # "mc4_fr_v1", +# # "mc4_it_v1", +# # "mc4_pl_v1", +# # "mc4_pt_v1", +# # "mwiki_de_v1", +# "mwiki_en_v1", +# # "mwiki_es_v1", +# # "mwiki_fr_v1", +# # "mwiki_it_v1", +# # "mwiki_pl_v1", +# # "mwiki_pt_v1", +# "paq", +# "pes2o", +# "red_pajama", +# "red_pajamas_1t_stackexchange", +# "s2orc_title_abstracts", +# "snippets4", +# "techrepo", +# "top_stories", +# "trivia_qa", +# "wikipedia", +# ] +# EVAL_DATA_PATHS = [ +# f"s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/Alibaba_NLP_gte_multilingual_base/combined_all_16384_eval/{dataset}" +# for dataset in datasets +# ] +# # from transformers import AutoTokenizer +# # tok = AutoTokenizer.from_pretrained("BAAI/bge-m3-retromae") +# # tok.pad_token_id --> 1 +# PAD_VALUE = 1 +# LEFT_PAD = False def now_timestamp_str() -> str: @@ -105,7 +105,7 @@ def now_timestamp_str() -> str: checkpoint_dir = Path(__file__).parent / "checkpoints" / "pretrain_mgte" / ts mconf = BiencoderModelConfig( name_or_path="Alibaba-NLP/gte-multilingual-mlm-base", - pooling=os.getenv("BIENCODER_POOLING", "first_token"), + pooling="splade", kwargs={ "trust_remote_code": True, "unpad_inputs": True, @@ -120,7 +120,8 @@ def now_timestamp_str() -> str: # Depending on how much GPU memory you have, you may need to split each # batch into a number of smaller sub-batches by setting the split_factor. # If you do so, you will probably want to decrease the learning rate accordingly. - # split_factor=4, + split_factor=64, + eval_split_factor=64, max_seq_length_query=32, max_seq_length_doc=256, eval_root_directories=EVAL_DATA_PATHS, @@ -135,7 +136,7 @@ def now_timestamp_str() -> str: wconf = WandBConfig( enable=True, project="arctic-training-arctic-embed-testbed", - name=f"mgte-pretrain-{ts}", + name=f"mgte-pretrain-splade-{ts}", ) # Reference: https://www.deepspeed.ai/training/#gradient-clipping dsconf = { @@ -196,10 +197,12 @@ def configure_non_distributed_distributed_training_if_needed() -> None: use_in_batch_negatives=True, loss_temperature=0.02, overfit_first_batch=False, - mrl_dim=256, + mrl_dim=None, splade_reg_weight=float(os.getenv("SPLADE_REG_WEIGHT", "0.0")), - splade_flops_weight=float(os.getenv("SPLADE_FLOPS_WEIGHT", "0.0")), - splade_nnz_threshold=float(os.getenv("SPLADE_NNZ_THRESHOLD", "1e-3")), + # Per-side SPLADE v2 FLOPs regularizers. + splade_flops_weight_query=float(os.getenv("SPLADE_FLOPS_WEIGHT_QUERY", "1e-1")), + splade_flops_weight_doc=float(os.getenv("SPLADE_FLOPS_WEIGHT_DOC", "1e-4")), + splade_nnz_threshold=float(os.getenv("SPLADE_NNZ_THRESHOLD", "0")), ) trainer = BiencoderTrainer(config=tconf) trainer.train() diff --git a/projects/arctic_embed/src/arctic_embed/core/biencoder_model.py b/projects/arctic_embed/src/arctic_embed/core/biencoder_model.py index bb6868bb..450b7f53 100644 --- a/projects/arctic_embed/src/arctic_embed/core/biencoder_model.py +++ b/projects/arctic_embed/src/arctic_embed/core/biencoder_model.py @@ -42,6 +42,9 @@ def __init__(self, encoder: PreTrainedModel, pooling: PoolingOption = "first_tok self.encoder = encoder self.pooling = pooling self.config = encoder.config + # Caches for SPLADE pooled weights from the most recent forward. + self._cached_query_pooled: Optional[Tensor] = None + self._cached_document_pooled: Optional[Tensor] = None def encode(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: out = self.encoder(input_ids=input_ids, attention_mask=attention_mask) @@ -83,32 +86,55 @@ def forward( document_input_ids: Tensor, document_attention_mask: Tensor, ) -> Tuple[Tensor, Tensor]: - query_vectors = self.encode(query_input_ids, query_attention_mask) - document_vectors = self.encode(document_input_ids, document_attention_mask) + # Clear caches from any previous call + self._cached_query_pooled = None + self._cached_document_pooled = None + + if self.pooling == "splade": + # Run encoder to obtain logits for query and cache. + q_out = self.encoder(input_ids=query_input_ids, attention_mask=query_attention_mask) + if not hasattr(q_out, "logits"): + raise ValueError( + f"Encoder of class {self.encoder.__class__} must output `logits` for SPLADE pooling." + ) + self._cached_query_pooled = splade_pool(q_out.logits, query_attention_mask).contiguous() + query_vectors = self._cached_query_pooled + + # Run encoder to obtain logits for document and cache. + d_out = self.encoder(input_ids=document_input_ids, attention_mask=document_attention_mask) + if not hasattr(d_out, "logits"): + raise ValueError( + f"Encoder of class {self.encoder.__class__} must output `logits` for SPLADE pooling." + ) + self._cached_document_pooled = splade_pool(d_out.logits, document_attention_mask).contiguous() + document_vectors = self._cached_document_pooled + else: + query_vectors = self.encode(query_input_ids, query_attention_mask) + document_vectors = self.encode(document_input_ids, document_attention_mask) return query_vectors, document_vectors - def compute_splade_flops(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: - """Compute SPLADE FLOPs proxy regularizer for a batch. + def compute_splade_flops_doc_cached(self) -> Tensor: + """SPLADE v2 FLOPs proxy using cached pooled weights (doc side), Eq. (4). - FLOPs proxy: sum_v log(1 + sum_t ReLU(logits_{t,v})) averaged over batch. + Let w_j(d_i) be the SPLADE weight of term j for sample i (after activation+pooling). + With N in-batch samples, Eq. (4): ℓ_FLOPS = Σ_j ( (1/N) Σ_i w_j(d_i) )^2. """ if self.pooling != "splade": raise ValueError("FLOPs regularizer requested but pooling is not 'splade'.") - out = self.encoder(input_ids=input_ids, attention_mask=attention_mask) - if not hasattr(out, "logits"): - raise ValueError( - f"Encoder of class {self.encoder.__class__} must output `logits` for SPLADE FLOPs computation." - ) - logits = out.logits # (batch, seq_len, vocab_size) - relu_logits = F.relu(logits) - if attention_mask.dtype != relu_logits.dtype: - attn = attention_mask.to(relu_logits.dtype) - else: - attn = attention_mask - relu_logits = relu_logits * attn[..., None] - token_sums = relu_logits.sum(dim=1) # (batch, vocab_size) - flops_per_example = torch.log1p(token_sums).sum(dim=1) # (batch) - return flops_per_example.mean() + if self._cached_document_pooled is None: + raise RuntimeError("Document pooled cache is empty; call forward first.") + return _splade_flops_batch_mean_squared(self._cached_document_pooled) + + def compute_splade_flops_query_cached(self) -> Tensor: + """SPLADE v2 FLOPs proxy using cached pooled weights (query side), Eq. (4). + + Same formula as doc side; we apply a different scalar weight in the trainer. + """ + if self.pooling != "splade": + raise ValueError("FLOPs regularizer requested but pooling is not 'splade'.") + if self._cached_query_pooled is None: + raise RuntimeError("Query pooled cache is empty; call forward first.") + return _splade_flops_batch_mean_squared(self._cached_query_pooled) def save_pretrained( self, @@ -188,3 +214,13 @@ def splade_pool(logits: Tensor, attention_mask: Tensor) -> Tensor: # Replace -inf (for fully-masked sequences) with zeros to avoid nans downstream pooled = torch.where(torch.isinf(pooled), torch.zeros_like(pooled), pooled) return pooled + +def _splade_flops_batch_mean_squared(pooled_weights: Tensor) -> Tensor: + """Compute Eq. (4): || (1/N) Σ_i w(d_i) ||_2^2, where w(d_i) are pooled SPLADE weights. + + pooled_weights: (batch, vocab_size) + returns: scalar Tensor (batch-mean vector squared L2 norm) + """ + assert pooled_weights.ndim == 2 + mean_vec = pooled_weights.mean(dim=0) + return (mean_vec.pow(2).sum()) diff --git a/projects/arctic_embed/src/arctic_embed/trainer.py b/projects/arctic_embed/src/arctic_embed/trainer.py index 46cc6f2c..dae96668 100644 --- a/projects/arctic_embed/src/arctic_embed/trainer.py +++ b/projects/arctic_embed/src/arctic_embed/trainer.py @@ -15,6 +15,7 @@ from __future__ import annotations +import os from typing import Callable from typing import Dict from typing import List @@ -57,8 +58,10 @@ class BiencoderTrainerConfig(TrainerConfig): data: ContrastivePretokenizedDataConfig mrl_dim: Optional[int] = None eval_interval: Optional[int] = None - splade_reg_weight: float = 0.0 - splade_flops_weight: float = 0.0 + # SPLADE regularization (v1 L1 is deprecated; use FLOPs per side instead) + splade_reg_weight: float = 0.0 # deprecated; kept for backward-compat log + splade_flops_weight_query: float = 0.0 + splade_flops_weight_doc: float = 0.0 splade_nnz_threshold: float = 1e-3 @@ -141,6 +144,25 @@ def eval_and_log_cb(self: BiencoderTrainer) -> None: logger.info(f"Global Step: {self.global_step}/{self.training_horizon} Eval: {metrics}") +def splade_flops_warmup_cb(self: BiencoderTrainer) -> None: + """Pre-step callback to linearly warm up SPLADE FLOPs weights.""" + if not hasattr(self.config, "_splade_flops_warmup_steps"): + # Initialize warmup config on first call + self.config._splade_flops_warmup_steps = int(os.environ.get("SPLADE_FLOPS_WARMUP_STEPS", "2000")) + self.config._splade_flops_weight_query_target = self.config.splade_flops_weight_query + self.config._splade_flops_weight_doc_target = self.config.splade_flops_weight_doc + + if self.global_step < self.config._splade_flops_warmup_steps: + # Linear warmup from 0 to target + warmup_factor = self.global_step / self.config._splade_flops_warmup_steps + self.config.splade_flops_weight_query = warmup_factor * self.config._splade_flops_weight_query_target + self.config.splade_flops_weight_doc = warmup_factor * self.config._splade_flops_weight_doc_target + else: + # Restore target weights after warmup + self.config.splade_flops_weight_query = self.config._splade_flops_weight_query_target + self.config.splade_flops_weight_doc = self.config._splade_flops_weight_doc_target + + class BiencoderTrainer(Trainer): name = "biencoder" config: BiencoderTrainerConfig @@ -153,6 +175,7 @@ class BiencoderTrainer(Trainer): count_total_queries_seen: int = 0 count_total_documents_seen: int = 0 callbacks: List[Tuple[str, Callable]] = [ + ("pre-step", splade_flops_warmup_cb), ("post-backward", rescale_grad_cb), ("post-step", log_grad_norm_cb), ("post-step", eval_and_log_cb), @@ -211,9 +234,11 @@ def forward_and_gather(self, batch: ContrastiveLearningBatch) -> Tuple[Tensor, T @torch.no_grad() def eval(self, batch: ContrastiveLearningBatch) -> Dict[str, float]: query_embeddings, document_embeddings, relations = self.forward_and_gather(batch) + # Access underlying Biencoder if wrapped by DeepSpeed + model_core = getattr(self.model, "module", self.model) if self.config.use_in_batch_negatives: relations[relations == 0] = -1 - if isinstance(self.model, Biencoder) and self.model.pooling == "splade": + if isinstance(model_core, Biencoder) and model_core.pooling == "splade": scores = torch.matmul(query_embeddings, document_embeddings.transpose(0, 1)) else: q_emb = F.normalize(query_embeddings, dim=1) @@ -221,7 +246,7 @@ def eval(self, batch: ContrastiveLearningBatch) -> Dict[str, float]: scores = torch.matmul(q_emb, d_emb.transpose(0, 1)) loss_infonce = info_nce_loss(scores, relations=relations, temperature=self.config.loss_temperature).item() metrics = {"infoNCE": loss_infonce} - if isinstance(self.model, Biencoder) and self.model.pooling == "splade": + if isinstance(model_core, Biencoder) and model_core.pooling == "splade": thr = self.config.splade_nnz_threshold avg_terms_q = (query_embeddings > thr).to(torch.float32).sum(dim=1).mean().item() avg_terms_d = (document_embeddings > thr).to(torch.float32).sum(dim=1).mean().item() @@ -239,9 +264,10 @@ def loss(self, batch: ContrastiveLearningBatch) -> Tensor: query_embeddings, document_embeddings, relations = self.forward_and_gather(batch) # InfoNCE loss (SPLADE uses raw dot product; dense uses MRL path). + model_core = getattr(self.model, "module", self.model) if self.config.use_in_batch_negatives: relations[relations == 0] = -1 - if isinstance(self.model, Biencoder) and self.model.pooling == "splade": + if isinstance(model_core, Biencoder) and model_core.pooling == "splade": scores = torch.matmul(query_embeddings, document_embeddings.transpose(0, 1)) loss = info_nce_loss(scores, relations=relations, temperature=self.config.loss_temperature) loss_base = loss @@ -255,19 +281,27 @@ def loss(self, batch: ContrastiveLearningBatch) -> Tensor: temperature=self.config.loss_temperature, ) - # Optional SPLADE-style L1 sparsity regularization on pooled embeddings. - if isinstance(self.model, Biencoder) and self.model.pooling == "splade" and self.config.splade_reg_weight > 0: + # Remove L1 regularizer in favor of SPLADE v2 FLOPs regularizers per side. + # Keep code path off unless legacy weight is set. + if ( + isinstance(model_core, Biencoder) + and model_core.pooling == "splade" + and self.config.splade_reg_weight > 0 + ): + # No-op by default; legacy users can still get old behavior if they set this. reg_q = query_embeddings.abs().sum(dim=1).mean() reg_d = document_embeddings.abs().sum(dim=1).mean() - reg = self.config.splade_reg_weight * (reg_q + reg_d) - loss = loss + reg - - # Optional SPLADE-style FLOPs regularization (proxy for number of active terms). - if isinstance(self.model, Biencoder) and self.model.pooling == "splade" and self.config.splade_flops_weight > 0: - flops_q = self.model.compute_splade_flops(batch.query_tokens, batch.query_attention_mask) - flops_d = self.model.compute_splade_flops(batch.document_tokens, batch.document_attention_mask) - flops_reg = self.config.splade_flops_weight * (flops_q + flops_d) - loss = loss + flops_reg + loss_reg = self.config.splade_reg_weight * (reg_q + reg_d) + loss = loss + loss_reg + + # SPLADE v2 FLOPs regularization with separate query and document terms. + if isinstance(model_core, Biencoder) and model_core.pooling == "splade": + if self.config.splade_flops_weight_query > 0: + flops_q = model_core.compute_splade_flops_query_cached() + loss = loss + self.config.splade_flops_weight_query * flops_q + if self.config.splade_flops_weight_doc > 0: + flops_d = model_core.compute_splade_flops_doc_cached() + loss = loss + self.config.splade_flops_weight_doc * flops_d # Weights and Biases logging. # NOTE: We log more than is feasible to do in a callback, so we do it here @@ -285,7 +319,7 @@ def loss(self, batch: ContrastiveLearningBatch) -> Tensor: "train/loss_no_truncate": loss_base.item(), } # SPLADE average active terms per query/doc for this step. - if isinstance(self.model, Biencoder) and self.model.pooling == "splade": + if isinstance(model_core, Biencoder) and model_core.pooling == "splade": thr = self.config.splade_nnz_threshold metrics["train/avg_terms_query"] = ( (query_embeddings > thr).to(torch.float32).sum(dim=1).mean().item() @@ -296,10 +330,25 @@ def loss(self, batch: ContrastiveLearningBatch) -> Tensor: if loss_truncated is not None: truncated_loss_name = f"train/loss_truncate_{self.config.mrl_dim}" metrics[truncated_loss_name] = loss_truncated.item() - if isinstance(self.model, Biencoder) and self.model.pooling == "splade" and self.config.splade_reg_weight > 0: - metrics["train/splade_reg_weight"] = self.config.splade_reg_weight - if isinstance(self.model, Biencoder) and self.model.pooling == "splade" and self.config.splade_flops_weight > 0: - metrics["train/splade_flops_weight"] = self.config.splade_flops_weight + if isinstance(model_core, Biencoder) and model_core.pooling == "splade": + if self.config.splade_reg_weight > 0: + metrics["train/splade_l1_loss"] = loss_reg.item() + metrics["train/splade_l1_loss_weight"] = self.config.splade_reg_weight + else: + metrics["train/splade_l1_loss"] = 0.0 + metrics["train/splade_l1_loss_weight"] = 0.0 + if self.config.splade_flops_weight_query > 0: + metrics["train/splade_query_flops_loss"] = flops_q.item() + metrics["train/splade_query_flops_loss_weight"] = self.config.splade_flops_weight_query + else: + metrics["train/splade_query_flops_loss"] = 0.0 + metrics["train/splade_query_flops_loss_weight"] = 0.0 + if self.config.splade_flops_weight_doc > 0: + metrics["train/splade_doc_flops_loss"] = flops_d.item() + metrics["train/splade_doc_flops_loss_weight"] = self.config.splade_flops_weight_doc + else: + metrics["train/splade_doc_flops_loss"] = 0.0 + metrics["train/splade_doc_flops_loss_weight"] = 0.0 wandb.log(metrics, step=self.global_step) return loss From 56a999c25ef04ad960f51929a0c60ba15b555541 Mon Sep 17 00:00:00 2001 From: Puxuan Yu Date: Fri, 19 Sep 2025 15:02:28 -0700 Subject: [PATCH 09/11] splade based on modernbert --- .../finetune_models/pretrain_splade.py | 89 +++++++------------ 1 file changed, 34 insertions(+), 55 deletions(-) diff --git a/projects/arctic_embed/examples/finetune_models/pretrain_splade.py b/projects/arctic_embed/examples/finetune_models/pretrain_splade.py index 6fdaaeea..ceaf3ce8 100644 --- a/projects/arctic_embed/examples/finetune_models/pretrain_splade.py +++ b/projects/arctic_embed/examples/finetune_models/pretrain_splade.py @@ -25,17 +25,18 @@ Model page: https://huggingface.co/intfloat/e5-base-v2 Better negative mining paper: https://arxiv.org/abs/2407.15831 """ +import os import sys from datetime import datetime from datetime import timezone from pathlib import Path -import os from arctic_embed.biencoder_model_factory import BiencoderModelConfig from arctic_embed.contrastive_dataloader import ContrastivePretokenizedDataConfig from arctic_embed.core.cuda_allocator_config import CUDA_ALLOCATOR_CONFIG_FOR_DYNAMICALLY_SIZED_DATA from arctic_embed.trainer import BiencoderTrainer from arctic_embed.trainer import BiencoderTrainerConfig +from transformers import AutoTokenizer from arctic_training.config.checkpoint import CheckpointConfig from arctic_training.config.logger import LoggerConfig @@ -43,57 +44,35 @@ from arctic_training.config.wandb import WandBConfig from arctic_training.scheduler.wsd_factory import WSDSchedulerConfig -LEARNING_RATE = 1e-4 +LEARNING_RATE = 2e-5 GRADIENT_CLIPPING = 10.0 # TODO: need to find a proper English only model, and tokenizer & batch only English pretraining data -# DATA_PATH = str(Path(__file__).parent / "data" / "pretrain_amazonqa" / "batched_16384") -# DATA_PATH = ( -# "s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/Alibaba_NLP_gte_multilingual_base/combined_all_16384" -# ) -# datasets = [ -# "amazon_qa", -# # "ccnews_de_v1", -# "ccnews_en_v1", -# # "ccnews_es_v1", -# # "ccnews_fr_v1", -# # "ccnews_it_v1", -# # "ccnews_pl_v1", -# # "ccnews_pt_v1", -# "faq", -# # "mc4_de_v1", -# "mc4_en_v1", -# # "mc4_es_v1", -# # "mc4_fr_v1", -# # "mc4_it_v1", -# # "mc4_pl_v1", -# # "mc4_pt_v1", -# # "mwiki_de_v1", -# "mwiki_en_v1", -# # "mwiki_es_v1", -# # "mwiki_fr_v1", -# # "mwiki_it_v1", -# # "mwiki_pl_v1", -# # "mwiki_pt_v1", -# "paq", -# "pes2o", -# "red_pajama", -# "red_pajamas_1t_stackexchange", -# "s2orc_title_abstracts", -# "snippets4", -# "techrepo", -# "top_stories", -# "trivia_qa", -# "wikipedia", -# ] -# EVAL_DATA_PATHS = [ -# f"s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/Alibaba_NLP_gte_multilingual_base/combined_all_16384_eval/{dataset}" -# for dataset in datasets -# ] -# # from transformers import AutoTokenizer -# # tok = AutoTokenizer.from_pretrained("BAAI/bge-m3-retromae") -# # tok.pad_token_id --> 1 -# PAD_VALUE = 1 -# LEFT_PAD = False +DATA_PATH = "s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/answerdotai_ModernBERT_base/combined_all_32768" +datasets = [ + "amazon_qa", + "ccnews_en_v1", + "faq", + "mc4_en_v1", + "mwiki_en_v1", + "paq", + "pes2o", + "red_pajama", + "red_pajamas_1t_stackexchange", + "s2orc_title_abstracts", + "snippets4", + "techrepo", + "top_stories", + "trivia_qa", + "wikipedia", +] +EVAL_DATA_PATHS = [ + f"s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/answerdotai_ModernBERT_base/combined_all_32768_eval/{dataset}" + for dataset in datasets +] + + +MODEL_NAME = "answerdotai/ModernBERT-base" +LEFT_PAD = False def now_timestamp_str() -> str: @@ -104,7 +83,7 @@ def now_timestamp_str() -> str: ts = now_timestamp_str() checkpoint_dir = Path(__file__).parent / "checkpoints" / "pretrain_mgte" / ts mconf = BiencoderModelConfig( - name_or_path="Alibaba-NLP/gte-multilingual-mlm-base", + name_or_path=MODEL_NAME, pooling="splade", kwargs={ "trust_remote_code": True, @@ -120,14 +99,14 @@ def now_timestamp_str() -> str: # Depending on how much GPU memory you have, you may need to split each # batch into a number of smaller sub-batches by setting the split_factor. # If you do so, you will probably want to decrease the learning rate accordingly. - split_factor=64, - eval_split_factor=64, + split_factor=16, + eval_split_factor=16, max_seq_length_query=32, max_seq_length_doc=256, eval_root_directories=EVAL_DATA_PATHS, eval_max_seq_length_doc=32, eval_max_seq_length_query=256, - pad_value=PAD_VALUE, + pad_value=AutoTokenizer.from_pretrained(MODEL_NAME).pad_token_id, left_pad=LEFT_PAD, ) sconf = WSDSchedulerConfig(num_warmup_steps=2000, num_decay_steps=2000) @@ -136,7 +115,7 @@ def now_timestamp_str() -> str: wconf = WandBConfig( enable=True, project="arctic-training-arctic-embed-testbed", - name=f"mgte-pretrain-splade-{ts}", + name=f"modernbert-pretrain-splade-{ts}", ) # Reference: https://www.deepspeed.ai/training/#gradient-clipping dsconf = { From 1d8599d02e8832213f0a6cdbff94564555ffd850 Mon Sep 17 00:00:00 2001 From: Puxuan Yu Date: Fri, 19 Sep 2025 16:03:15 -0700 Subject: [PATCH 10/11] fix nan issue (fp32 only) --- .../examples/finetune_models/pretrain_splade.py | 11 ++++++----- .../src/arctic_embed/biencoder_model_factory.py | 4 ++++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/projects/arctic_embed/examples/finetune_models/pretrain_splade.py b/projects/arctic_embed/examples/finetune_models/pretrain_splade.py index ceaf3ce8..59ef55cf 100644 --- a/projects/arctic_embed/examples/finetune_models/pretrain_splade.py +++ b/projects/arctic_embed/examples/finetune_models/pretrain_splade.py @@ -46,7 +46,6 @@ LEARNING_RATE = 2e-5 GRADIENT_CLIPPING = 10.0 -# TODO: need to find a proper English only model, and tokenizer & batch only English pretraining data DATA_PATH = "s3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/biencoder/pretrain_data_arctic_training_format/answerdotai_ModernBERT_base/combined_all_32768" datasets = [ "amazon_qa", @@ -90,6 +89,8 @@ def now_timestamp_str() -> str: "unpad_inputs": True, "use_memory_efficient_attention": True, }, + dtype="fp32", + attn_implementation="flash_attention_2", ) dconf = ContrastivePretokenizedDataConfig( filesystem="s3", @@ -109,7 +110,7 @@ def now_timestamp_str() -> str: pad_value=AutoTokenizer.from_pretrained(MODEL_NAME).pad_token_id, left_pad=LEFT_PAD, ) -sconf = WSDSchedulerConfig(num_warmup_steps=2000, num_decay_steps=2000) +sconf = WSDSchedulerConfig(num_warmup_steps=5000, num_decay_steps=5000) oconf = OptimizerConfig(weight_decay=0.01, learning_rate=LEARNING_RATE) lconf = LoggerConfig(level="INFO") wconf = WandBConfig( @@ -174,13 +175,13 @@ def configure_non_distributed_distributed_training_if_needed() -> None: loss_log_interval=0, eval_frequency=300, use_in_batch_negatives=True, - loss_temperature=0.02, + loss_temperature=0.05, overfit_first_batch=False, mrl_dim=None, splade_reg_weight=float(os.getenv("SPLADE_REG_WEIGHT", "0.0")), # Per-side SPLADE v2 FLOPs regularizers. - splade_flops_weight_query=float(os.getenv("SPLADE_FLOPS_WEIGHT_QUERY", "1e-1")), - splade_flops_weight_doc=float(os.getenv("SPLADE_FLOPS_WEIGHT_DOC", "1e-4")), + splade_flops_weight_query=float(os.getenv("SPLADE_FLOPS_WEIGHT_QUERY", "1e-2")), + splade_flops_weight_doc=float(os.getenv("SPLADE_FLOPS_WEIGHT_DOC", "1e-5")), splade_nnz_threshold=float(os.getenv("SPLADE_NNZ_THRESHOLD", "0")), ) trainer = BiencoderTrainer(config=tconf) diff --git a/projects/arctic_embed/src/arctic_embed/biencoder_model_factory.py b/projects/arctic_embed/src/arctic_embed/biencoder_model_factory.py index 98094422..338207c3 100644 --- a/projects/arctic_embed/src/arctic_embed/biencoder_model_factory.py +++ b/projects/arctic_embed/src/arctic_embed/biencoder_model_factory.py @@ -57,6 +57,10 @@ def create_model(self, model_config: AutoConfig) -> Biencoder: assert isinstance(arctic_training_model_config, BiencoderModelConfig) trust_remote_code = arctic_training_model_config.kwargs.get("trust_remote_code", None) if arctic_training_model_config.pooling == "splade": + print(f"🤨 model_config: {model_config}") + print(f"🤨 self.config.attn_implementation: {self.config.attn_implementation}") + print(f"🤨 self.config.dtype.value: {self.config.dtype.value}") + print(f"🤨 trust_remote_code: {trust_remote_code}") encoder = AutoModelForMaskedLM.from_pretrained( self.config.name_or_path, config=model_config, From 966467060d93ed8a34e174c0b0ed4055c1b48049 Mon Sep 17 00:00:00 2001 From: Puxuan Yu Date: Wed, 1 Oct 2025 17:45:15 -0700 Subject: [PATCH 11/11] wip --- .../examples/finetune_models/pretrain_splade.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/projects/arctic_embed/examples/finetune_models/pretrain_splade.py b/projects/arctic_embed/examples/finetune_models/pretrain_splade.py index 59ef55cf..4d159db7 100644 --- a/projects/arctic_embed/examples/finetune_models/pretrain_splade.py +++ b/projects/arctic_embed/examples/finetune_models/pretrain_splade.py @@ -89,7 +89,7 @@ def now_timestamp_str() -> str: "unpad_inputs": True, "use_memory_efficient_attention": True, }, - dtype="fp32", + dtype="bf16", attn_implementation="flash_attention_2", ) dconf = ContrastivePretokenizedDataConfig( @@ -100,8 +100,8 @@ def now_timestamp_str() -> str: # Depending on how much GPU memory you have, you may need to split each # batch into a number of smaller sub-batches by setting the split_factor. # If you do so, you will probably want to decrease the learning rate accordingly. - split_factor=16, - eval_split_factor=16, + split_factor=8, + eval_split_factor=8, max_seq_length_query=32, max_seq_length_doc=256, eval_root_directories=EVAL_DATA_PATHS, @@ -180,8 +180,8 @@ def configure_non_distributed_distributed_training_if_needed() -> None: mrl_dim=None, splade_reg_weight=float(os.getenv("SPLADE_REG_WEIGHT", "0.0")), # Per-side SPLADE v2 FLOPs regularizers. - splade_flops_weight_query=float(os.getenv("SPLADE_FLOPS_WEIGHT_QUERY", "1e-2")), - splade_flops_weight_doc=float(os.getenv("SPLADE_FLOPS_WEIGHT_DOC", "1e-5")), + splade_flops_weight_query=float(os.getenv("SPLADE_FLOPS_WEIGHT_QUERY", "1e-1")), + splade_flops_weight_doc=float(os.getenv("SPLADE_FLOPS_WEIGHT_DOC", "1e-4")), splade_nnz_threshold=float(os.getenv("SPLADE_NNZ_THRESHOLD", "0")), ) trainer = BiencoderTrainer(config=tconf)