diff --git a/arctic_training/__init__.py b/arctic_training/__init__.py
index a425f61d..82f22d66 100644
--- a/arctic_training/__init__.py
+++ b/arctic_training/__init__.py
@@ -23,6 +23,7 @@
 from arctic_training.checkpoint.hf_engine import HFCheckpointEngine
 from arctic_training.config.checkpoint import CheckpointConfig
 from arctic_training.config.data import DataConfig
+from arctic_training.config.experiment_tracking import ExperimentTrackingConfig
 from arctic_training.config.logger import LoggerConfig
 from arctic_training.config.model import ModelConfig
 from arctic_training.config.optimizer import OptimizerConfig
@@ -39,6 +40,11 @@
 from arctic_training.data.sft_factory import SFTDataFactory
 from arctic_training.data.snowflake_source import SnowflakeDataSource
 from arctic_training.data.source import DataSource
+from arctic_training.experiment_tracking.snowflake_tracker import SnowflakeExperimentTrackingConfig
+from arctic_training.experiment_tracking.snowflake_tracker import SnowflakeExpTracker
+from arctic_training.experiment_tracking.tracker import ExperimentTracker
+from arctic_training.experiment_tracking.wandb_tracker import WandBExperimentTrackingConfig
+from arctic_training.experiment_tracking.wandb_tracker import WandBTracker
 from arctic_training.logging import logger
 from arctic_training.model.factory import ModelFactory
 from arctic_training.model.hf_factory import HFModelFactory
diff --git a/arctic_training/checkpoint/ds_engine.py b/arctic_training/checkpoint/ds_engine.py
index 8560ecb5..4909fc4c 100644
--- a/arctic_training/checkpoint/ds_engine.py
+++ b/arctic_training/checkpoint/ds_engine.py
@@ -79,7 +79,11 @@ def client_state(self) -> Dict[str, Any]:
             "np_random_state": np.random.get_state(),
             "python_random_state": random.getstate(),
             "global_step": self.trainer.global_step,
-            "wandb_run_id": self.trainer.wandb_run_id,
+            "experiment_tracker_state": (
+                self.trainer.experiment_tracker.get_resume_state()
+                if self.trainer.experiment_tracker is not None
+                else None
+            ),
         }
         if self.device != torch.device("cpu"):
             state["torch_cuda_random_state"] = torch.cuda.get_rng_state()
@@ -112,7 +116,11 @@ def load(self, model) -> None:
         if self.device != torch.device("cpu"):
             torch.cuda.set_rng_state(client_states["torch_cuda_random_state"])
 
-        self.trainer.wandb_run_id = client_states["wandb_run_id"]
+        tracker_state = client_states.get("experiment_tracker_state")
+        if tracker_state is None and client_states.get("wandb_run_id") is not None:
+            # Checkpoints written before the tracker abstraction only stored the W&B run id.
+            tracker_state = {"wandb_run_id": client_states["wandb_run_id"]}
+        self.trainer.experiment_tracker_state = tracker_state
 
         # Helpful ckpt resume debugging snippet
         # norm = model_norm(model)
diff --git a/arctic_training/config/experiment_tracking.py b/arctic_training/config/experiment_tracking.py
new file mode 100644
index 00000000..1bd23bd7
--- /dev/null
+++ b/arctic_training/config/experiment_tracking.py
@@ -0,0 +1,24 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from arctic_training.config.base import BaseConfig
+
+
+class ExperimentTrackingConfig(BaseConfig):
+    type: str = "wandb"
+    """ Experiment tracking backend type. """
+
+    enable: bool = False
+    """ Whether to enable experiment tracking. """
diff --git a/arctic_training/config/trainer.py b/arctic_training/config/trainer.py
index f501e93b..1a06b415 100644
--- a/arctic_training/config/trainer.py
+++ b/arctic_training/config/trainer.py
@@ -39,6 +39,7 @@
 from arctic_training.config.checkpoint import CheckpointConfig
 from arctic_training.config.data import DataConfig
 from arctic_training.config.enums import DType
+from arctic_training.config.experiment_tracking import ExperimentTrackingConfig
 from arctic_training.config.logger import LoggerConfig
 from arctic_training.config.model import ModelConfig
 from arctic_training.config.optimizer import OptimizerConfig
@@ -47,10 +48,10 @@
 from arctic_training.config.utils import HumanInt
 from arctic_training.config.utils import UniqueKeyLoader
 from arctic_training.config.utils import parse_human_val
-from arctic_training.config.wandb import WandBConfig
 from arctic_training.registry import _get_class_attr_type_hints
 from arctic_training.registry import get_registered_checkpoint_engine
 from arctic_training.registry import get_registered_data_factory
+from arctic_training.registry import get_registered_experiment_tracker
 from arctic_training.registry import get_registered_model_factory
 from arctic_training.registry import get_registered_optimizer_factory
 from arctic_training.registry import get_registered_scheduler_factory
@@ -88,8 +89,8 @@ class TrainerConfig(BaseConfig):
     logger: LoggerConfig = Field(default_factory=LoggerConfig)
     """ Logger configuration. """
 
-    wandb: WandBConfig = Field(default_factory=WandBConfig)
-    """ Weights and Biases configuration. """
+    experiment_tracking: ExperimentTrackingConfig = Field(default_factory=ExperimentTrackingConfig)
+    """ Experiment tracking configuration (e.g., wandb, snowflake). """
 
     scheduler: SchedulerConfig = Field(default_factory=SchedulerConfig)
     """ Scheduler configuration. """
@@ -336,6 +337,41 @@ def init_tokenizer_config(cls, v: Union[Dict, TokenizerConfig], info: Validation
             )
         return cast(TokenizerConfig, subconfig)
 
+    @model_validator(mode="before")
+    @classmethod
+    def migrate_wandb_config(cls, data: Any) -> Any:
+        """Accept legacy ``wandb:`` YAML key and convert it to ``experiment_tracking:``.
+
+        Works on a shallow copy so the caller's input is never mutated. An explicit
+        ``experiment_tracking`` entry takes precedence over a legacy ``wandb`` one.
+        """
+        if not isinstance(data, dict) or "wandb" not in data:
+            return data
+        migrated = {k: v for k, v in data.items() if k != "wandb"}
+        if "experiment_tracking" not in migrated:
+            wandb_config = data["wandb"]
+            if isinstance(wandb_config, dict):
+                # Default the backend to wandb without touching the input dict.
+                wandb_config = {"type": "wandb", **wandb_config}
+            migrated["experiment_tracking"] = wandb_config
+        return migrated
+
+    @field_validator("experiment_tracking", mode="before")
+    @classmethod
+    def init_experiment_tracking_config(
+        cls,
+        v: Union[Dict, ExperimentTrackingConfig],
+    ) -> ExperimentTrackingConfig:
+        """Instantiate the config class of the tracker selected by ``type``."""
+        if isinstance(v, ExperimentTrackingConfig):
+            return v
+        # Copy so the resolved ``type`` is not written back into the caller's dict.
+        config_dict = dict(v) if isinstance(v, dict) else {}
+        tracker_type = config_dict.setdefault("type", "wandb")
+        tracker_cls = get_registered_experiment_tracker(tracker_type)
+        config_cls = _get_class_attr_type_hints(tracker_cls, "config")[0]
+        return config_cls(**config_dict)
+
     @model_validator(mode="after")
     def validate_eval_interval(self) -> Self:
         if self.data.eval_sources or self.data.train_eval_split[1] > 0.0:
diff --git a/arctic_training/config/wandb.py b/arctic_training/config/wandb.py
index 3608f737..b85deb8b 100644
--- a/arctic_training/config/wandb.py
+++ b/arctic_training/config/wandb.py
@@ -13,20 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+# Backward compatibility alias — use ExperimentTrackingConfig / WandBExperimentTrackingConfig instead.
+from arctic_training.experiment_tracking.wandb_tracker import WandBExperimentTrackingConfig as WandBConfig
 
-from arctic_training.config.base import BaseConfig
-
-
-class WandBConfig(BaseConfig):
-    enable: bool = False
-    """ Whether to enable Weights and Biases logging. """
-
-    entity: Optional[str] = None
-    """ Weights and Biases entity name. """
-
-    project: Optional[str] = "arctic-training"
-    """ Weights and Biases project name. """
-
-    name: Optional[str] = None
-    """ Weights and Biases run name. """
+__all__ = ["WandBConfig"]
diff --git a/arctic_training/experiment_tracking/__init__.py b/arctic_training/experiment_tracking/__init__.py
new file mode 100644
index 00000000..3e86bce1
--- /dev/null
+++ b/arctic_training/experiment_tracking/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/arctic_training/experiment_tracking/snowflake_tracker.py b/arctic_training/experiment_tracking/snowflake_tracker.py
new file mode 100644
index 00000000..6d89cdd0
--- /dev/null
+++ b/arctic_training/experiment_tracking/snowflake_tracker.py
@@ -0,0 +1,146 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+from typing import Any
+from typing import Dict
+from typing import Optional
+
+from arctic_training.config.experiment_tracking import ExperimentTrackingConfig
+from arctic_training.experiment_tracking.tracker import ExperimentTracker
+from arctic_training.logging import logger
+
+if TYPE_CHECKING:
+    from arctic_training.trainer.trainer import Trainer
+
+
+class SnowflakeExperimentTrackingConfig(ExperimentTrackingConfig):
+    type: str = "snowflake"
+
+    account: str = ""
+    """ Snowflake account identifier. """
+
+    user: str = ""
+    """ Snowflake user name. """
+
+    password: str = ""
+    """ Snowflake PAT or password. """
+
+    role: str = ""
+    """ Snowflake role. """
+
+    warehouse: str = ""
+    """ Snowflake warehouse. """
+
+    database: str = ""
+    """ Snowflake database. """
+
+    schema_name: str = ""
+    """ Snowflake schema. """
+
+    experiment_name: str = ""
+    """ Name of the experiment in Snowflake experiment tracking. """
+
+    run_name: Optional[str] = None
+    """ Name of the run. If not set, one will be generated. """
+
+    @property
+    def connection_params(self):
+        """Keyword arguments handed to ``Session.builder.configs`` to open the session."""
+        return dict(
+            account=self.account,
+            user=self.user,
+            authentication="PAT",
+            password=self.password,
+            role=self.role,
+            warehouse=self.warehouse,
+            database=self.database,
+            schema=self.schema_name,
+        )
+
+
+class SnowflakeExpTracker(ExperimentTracker):
+    """Experiment tracker that logs runs to Snowflake ML experiment tracking."""
+
+    name: str = "snowflake"
+    config: SnowflakeExperimentTrackingConfig
+
+    # Keys whose values are masked before anything is sent to the tracking backend.
+    _SECRET_KEYS = frozenset({"password"})
+
+    def __init__(self, trainer: "Trainer", config: SnowflakeExperimentTrackingConfig) -> None:
+        super().__init__(trainer, config)
+        self._experiment: Any = None
+        self._run_name: Optional[str] = config.run_name
+
+    @classmethod
+    def _redact_secrets(cls, value: Any) -> Any:
+        """Return a copy of ``value`` with secret entries of nested dicts masked."""
+        if isinstance(value, dict):
+            return {k: "***" if k in cls._SECRET_KEYS else cls._redact_secrets(v) for k, v in value.items()}
+        if isinstance(value, (list, tuple)):
+            return [cls._redact_secrets(v) for v in value]
+        return value
+
+    def start(self, run_config: Dict[str, Any]) -> None:
+        """Open a Snowpark session, select the experiment, start the run and log ``run_config``.
+
+        Raises:
+            ImportError: If the Snowflake client packages are not installed.
+        """
+        try:
+            from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
+            from snowflake.snowpark import Session
+        except ImportError as err:
+            raise ImportError(
+                "Snowflake experiment tracking requires the snowflake-ml-python and "
+                "snowflake-snowpark-python packages. Install them with:\n"
+                "  pip install snowflake-ml-python snowflake-snowpark-python"
+            ) from err
+
+        session = Session.builder.configs(self.config.connection_params).create()
+        self._experiment = ExperimentTracking(session=session)
+        self._experiment.set_experiment(self.config.experiment_name)
+
+        self._experiment.start_run(self._run_name)
+        # The trainer passes its full dumped config, which includes this tracker's own
+        # connection settings, so mask the password before it is stored with the run.
+        self._experiment.log_params(self._redact_secrets(run_config))
+        logger.info(f"Snowflake experiment tracking started: {self.config.experiment_name}/{self._run_name}")
+
+    def log_metrics(self, metrics: Dict[str, Any], step: int) -> None:
+        if self._experiment is not None:
+            self._experiment.log_metrics(metrics, step=step)
+
+    def log_params(self, params: Dict[str, Any]) -> None:
+        if self._experiment is not None:
+            self._experiment.log_params(self._redact_secrets(params))
+
+    def finish(self) -> None:
+        if self._experiment is not None:
+            self._experiment.end_run()
+
+    def get_resume_state(self) -> Optional[Dict[str, Any]]:
+        if self._run_name is None:
+            return None
+        return {"run_name": self._run_name}
+
+    def set_resume_state(self, state: Dict[str, Any]) -> None:
+        """Restore the run name saved by ``get_resume_state``.
+
+        A state without ``run_name`` (e.g. one saved by a different tracker backend)
+        keeps the configured run name instead of resetting it to ``None``.
+        """
+        self._run_name = state.get("run_name") or self._run_name
diff --git a/arctic_training/experiment_tracking/tracker.py b/arctic_training/experiment_tracking/tracker.py
new file mode 100644
index 00000000..bd745cc1
--- /dev/null
+++ b/arctic_training/experiment_tracking/tracker.py
@@ -0,0 +1,95 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC
+from abc import abstractmethod
+from typing import TYPE_CHECKING
+from typing import Any
+from typing import Dict
+from typing import Optional
+
+from arctic_training.config.experiment_tracking import ExperimentTrackingConfig
+from arctic_training.registry import RegistryMeta
+from arctic_training.registry import _validate_class_attribute_set
+from arctic_training.registry import _validate_class_attribute_type
+from arctic_training.registry import _validate_class_method
+
+if TYPE_CHECKING:
+    from arctic_training.trainer.trainer import Trainer
+
+
+class ExperimentTracker(ABC, metaclass=RegistryMeta):
+    """Base class for experiment tracking backends."""
+
+    name: str
+    """
+    Name of the experiment tracker used for registration. This name
+    should be unique and is used in configs to select the tracker.
+    """
+
+    config: ExperimentTrackingConfig
+    """
+    The type of the config class that the tracker uses. This should be a
+    subclass of ExperimentTrackingConfig.
+    """
+
+    @classmethod
+    def _validate_subclass(cls) -> None:
+        _validate_class_attribute_set(cls, "name")
+        _validate_class_attribute_type(cls, "config", ExperimentTrackingConfig)
+        _validate_class_method(cls, "start", ["self", "run_config"])
+        _validate_class_method(cls, "log_metrics", ["self", "metrics", "step"])
+        _validate_class_method(cls, "log_params", ["self", "params"])
+        _validate_class_method(cls, "finish", ["self"])
+        _validate_class_method(cls, "get_resume_state", ["self"])
+        _validate_class_method(cls, "set_resume_state", ["self", "state"])
+
+    def __init__(self, trainer: "Trainer", config: ExperimentTrackingConfig) -> None:
+        self._trainer = trainer
+        self.config = config
+
+    @property
+    def trainer(self) -> "Trainer":
+        return self._trainer
+
+    @abstractmethod
+    def start(self, run_config: Dict[str, Any]) -> None:
+        """Initialize and start the tracking run."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def log_metrics(self, metrics: Dict[str, Any], step: int) -> None:
+        """Log metrics for the current step."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def log_params(self, params: Dict[str, Any]) -> None:
+        """Log parameters/config for the run."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def finish(self) -> None:
+        """Finalize and close the tracking run."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_resume_state(self) -> Optional[Dict[str, Any]]:
+        """Return state needed to resume this tracker across checkpoint restarts."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def set_resume_state(self, state: Dict[str, Any]) -> None:
+        """Restore tracker state from a previous checkpoint."""
+        raise NotImplementedError
diff --git a/arctic_training/experiment_tracking/wandb_tracker.py b/arctic_training/experiment_tracking/wandb_tracker.py
new file mode 100644
index 00000000..dc25a7ab
--- /dev/null
+++ b/arctic_training/experiment_tracking/wandb_tracker.py
@@ -0,0 +1,83 @@
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+from typing import Any
+from typing import Dict
+from typing import Optional
+
+import wandb
+
+from arctic_training.config.experiment_tracking import ExperimentTrackingConfig
+from arctic_training.experiment_tracking.tracker import ExperimentTracker
+
+if TYPE_CHECKING:
+    from arctic_training.trainer.trainer import Trainer
+
+
+class WandBExperimentTrackingConfig(ExperimentTrackingConfig):
+    type: str = "wandb"
+
+    entity: Optional[str] = None
+    """ Weights and Biases entity name. """
+
+    project: Optional[str] = "arctic-training"
+    """ Weights and Biases project name. """
+
+    name: Optional[str] = None
+    """ Weights and Biases run name. """
+
+
+class WandBTracker(ExperimentTracker):
+    name: str = "wandb"
+    config: WandBExperimentTrackingConfig
+
+    def __init__(self, trainer: "Trainer", config: WandBExperimentTrackingConfig) -> None:
+        super().__init__(trainer, config)
+        self._run_id: Optional[str] = None
+        self._run = None
+
+    def start(self, run_config: Dict[str, Any]) -> None:
+        if self._run_id is None:
+            self._run_id = wandb.util.generate_id()
+        self._run = wandb.init(
+            id=self._run_id,
+            resume="allow",  # re-attach to the checkpointed run id instead of overwriting it
+            entity=self.config.entity,
+            project=self.config.project,
+            name=self.config.name,
+            config=run_config,
+            dir=f"{self.trainer.config.logger.output_dir}/wandb",
+        )
+
+    def log_metrics(self, metrics: Dict[str, Any], step: int) -> None:
+        if self._run is not None:
+            self._run.log(metrics, step=step)
+
+    def log_params(self, params: Dict[str, Any]) -> None:
+        if self._run is not None:
+            self._run.config.update(params)
+
+    def finish(self) -> None:
+        if self._run is not None:
+            self._run.finish()
+
+    def get_resume_state(self) -> Optional[Dict[str, Any]]:
+        if self._run_id is None:
+            return None
+        return {"wandb_run_id": self._run_id}
+
+    def set_resume_state(self, state: Dict[str, Any]) -> None:
+        self._run_id = state.get("wandb_run_id")
diff --git a/arctic_training/registry.py b/arctic_training/registry.py
index 9f9ad879..14625077 100644
--- a/arctic_training/registry.py
+++ b/arctic_training/registry.py
@@ -36,6 +36,7 @@
     from arctic_training.checkpoint.engine import CheckpointEngine
     from arctic_training.data.factory import DataFactory
     from arctic_training.data.source import DataSource
+    from arctic_training.experiment_tracking.tracker import ExperimentTracker
     from arctic_training.model.factory import ModelFactory
     from arctic_training.optimizer.factory import OptimizerFactory
     from arctic_training.scheduler.factory import SchedulerFactory
@@ -129,6 +130,10 @@ def get_registered_checkpoint_engine(name: str) -> Type["CheckpointEngine"]:
     return get_registered_class(class_type="CheckpointEngine", name=name)
 
 
+def get_registered_experiment_tracker(name: str) -> Type["ExperimentTracker"]:
+    return get_registered_class(class_type="ExperimentTracker", name=name)
+
+
 def get_registered_data_factory(name: str) -> Type["DataFactory"]:
     return get_registered_class(class_type="DataFactory", name=name)
 
diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py
index 29686c14..95e6d54f 100644
--- a/arctic_training/trainer/trainer.py
+++ b/arctic_training/trainer/trainer.py
@@ -19,6 +19,7 @@
 from abc import ABC
 from abc import abstractmethod
 from functools import cached_property
+from typing import Any
 from typing import Callable
 from typing import Dict
 from typing import List
@@ -30,7 +31,6 @@
 import torch
 import torch.cuda
 import torch.distributed.nn
-import wandb
 from deepspeed.accelerator import get_accelerator
 from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF
 from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPDataLoaderAdapter
@@ -38,7 +38,6 @@
 from tqdm import tqdm
 from transformers import set_seed
 from transformers.integrations.deepspeed import HfDeepSpeedConfig
-from wandb.sdk.wandb_run import Run as WandbRun
 
 from arctic_training.callback.logging import post_loss_log_cb
 from arctic_training.callback.mixin import CallbackMixin
@@ -47,6 +46,7 @@
 from arctic_training.config.trainer import TrainerConfig
 from arctic_training.data.factory import DataFactory
 from arctic_training.data.utils import OverfitOneBatchDataLoader
+from arctic_training.experiment_tracking.tracker import ExperimentTracker
 from arctic_training.logging import logger
 from arctic_training.metrics import Metrics
 from arctic_training.model.factory import ModelFactory
@@ -159,9 +159,9 @@ def __init__(self, config: TrainerConfig, mode: str = "train") -> None:
         self.global_rank = config.global_rank
         self.epoch_finished = False
         self.training_finished = False
-        self.wandb_experiment: Optional[WandbRun] = None
+        self.experiment_tracker: Optional[ExperimentTracker] = None
         self.is_resume = False  # Track if we resumed from ckpt
-        self.wandb_run_id = None
+        self.experiment_tracker_state: Optional[Dict[str, Any]] = None
 
         self._set_seeds(self.config.seed)
 
@@ -280,22 +280,16 @@ def __init__(self, config: TrainerConfig, mode: str = "train") -> None:
 
         self.metrics = Metrics(self)
 
-        if self.global_rank == 0 and self.config.wandb.enable:
-
-            # in order for resume to continue the same wandb run we need to re-use a run_id from the previous run
-            if self.wandb_run_id is None:
-                self.wandb_run_id = wandb.util.generate_id()
-
-            # Note: wandb.init() is not type annotated so we need to use type: ignore
-            self.wandb_experiment = wandb.init(  # type: ignore
-                id=self.wandb_run_id,
-                entity=self.config.wandb.entity,
-                project=self.config.wandb.project,
-                name=self.config.wandb.name,
-                config=self.config.model_dump(),
-                # do not put `wandb` in the root of the repo as it conflicts with wandb package
-                dir=f"{self.config.logger.output_dir}/wandb",
-            )
+        if self.global_rank == 0 and self.config.experiment_tracking.enable:
+            from arctic_training.registry import get_registered_experiment_tracker
+
+            tracker_cls = get_registered_experiment_tracker(self.config.experiment_tracking.type)
+            self.experiment_tracker = tracker_cls(self, self.config.experiment_tracking)
+
+            if self.experiment_tracker_state is not None:
+                self.experiment_tracker.set_resume_state(self.experiment_tracker_state)
+
+            self.experiment_tracker.start(self.config.model_dump())
 
     def _set_seeds(self, seed: int) -> None:
         logger.info(f"Setting random seeds to {seed}")
@@ -501,11 +495,11 @@ def epoch(self) -> None:
 
                     append_json_file(self.config.train_log_metrics_path, metrics)
 
-                    # do not log the first train iteration to wandb, since it's a massive outlier
+                    # do not log the first train iteration, since it's a massive outlier
                     # on all performance metrics, which messes up the scale of the report
-                    if self.wandb_experiment is not None and self.global_step > 1:
+                    if self.experiment_tracker is not None and self.global_step > 1:
                         metrics = {k: v for k, v in metrics.items() if k not in ["iter"]}
-                        self.wandb_experiment.log(metrics, step=self.global_step)
+                        self.experiment_tracker.log_metrics(metrics, step=self.global_step)
 
                 if self.config.eval_interval != 0 and self.global_step % self.config.eval_interval == 0:
                     self.evaluate()
@@ -513,9 +507,9 @@ def epoch(self) -> None:
                     if self.is_eval_log_iter():
                         self.metrics.print_summary(prefix="eval")
 
-                        if self.wandb_experiment is not None:
+                        if self.experiment_tracker is not None:
                             metrics = {k: self.metrics.summary_dict[k] for k in ["loss/eval"]}
-                            self.wandb_experiment.log(metrics, step=self.global_step)
+                            self.experiment_tracker.log_metrics(metrics, step=self.global_step)
 
         self.metrics.stop_timer("iter")
         self.epoch_finished = True
@@ -551,8 +545,8 @@ def train(self) -> None:
             if self.config.mem_profiler is not None:
                 torch.cuda.memory._dump_snapshot(self.config.mem_profiler_dir / f"{self.global_rank}.pickle")
 
-            if self.wandb_experiment is not None:
-                self.wandb_experiment.finish()
+            if self.experiment_tracker is not None:
+                self.experiment_tracker.finish()
 
     @callback_wrapper("evaluate")
     def evaluate(self) -> None:
diff --git a/projects/arctic_embed/examples/finetune_models/finetune_e5_base_unsupervised.py b/projects/arctic_embed/examples/finetune_models/finetune_e5_base_unsupervised.py
index 52bdb1d4..f76c0c5d 100644
--- a/projects/arctic_embed/examples/finetune_models/finetune_e5_base_unsupervised.py
+++ b/projects/arctic_embed/examples/finetune_models/finetune_e5_base_unsupervised.py
@@ -39,7 +39,7 @@
 from arctic_training.config.checkpoint import CheckpointConfig
 from arctic_training.config.logger import LoggerConfig
 from arctic_training.config.optimizer import OptimizerConfig
-from arctic_training.config.wandb import WandBConfig
+from arctic_training.experiment_tracking.wandb_tracker import WandBExperimentTrackingConfig
 from arctic_training.scheduler.wsd_factory import WSDSchedulerConfig
 
LEARNING_RATE = 3e-5 @@ -75,7 +75,7 @@ def now_timestamp_str() -> str: sconf = WSDSchedulerConfig(num_warmup_steps=500, num_decay_steps=1_000, learning_rate=LEARNING_RATE) oconf = OptimizerConfig(weight_decay=0.01, learning_rate=LEARNING_RATE) lconf = LoggerConfig(level="INFO") -wconf = WandBConfig( +wconf = WandBExperimentTrackingConfig( enable=True, project="arctic-training-arctic-embed-testbed", name=f"e5-base-unsupervised-finetune-{ts}", @@ -132,7 +132,7 @@ def configure_non_distributed_distributed_training_if_needed() -> None: optimizer=oconf, logger=lconf, checkpoint=cconf, - wandb=wconf, + experiment_tracking=wconf, deepspeed=dsconf, loss_log_interval=0, eval_interval=100, diff --git a/projects/arctic_embed/src/arctic_embed/trainer.py b/projects/arctic_embed/src/arctic_embed/trainer.py index 333f57b7..901e47fc 100644 --- a/projects/arctic_embed/src/arctic_embed/trainer.py +++ b/projects/arctic_embed/src/arctic_embed/trainer.py @@ -157,17 +157,18 @@ class BiencoderTrainer(Trainer): @property def is_wandb_logger(self) -> bool: - return self.global_rank == 0 and self.config.wandb.enable + return self.global_rank == 0 and self.config.experiment_tracking.enable def pre_train_callback(self) -> None: # Turn on weights and biases on the master worker. if self.is_wandb_logger: import wandb + et = self.config.experiment_tracking wandb.init( - project=self.config.wandb.project, + project=getattr(et, "project", None), config=self.config.model_dump(), - name=self.config.wandb.name, + name=getattr(et, "name", None), dir="/tmp/wandb", save_code=False, )