Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions arctic_training/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from arctic_training.checkpoint.hf_engine import HFCheckpointEngine
from arctic_training.config.checkpoint import CheckpointConfig
from arctic_training.config.data import DataConfig
from arctic_training.config.experiment_tracking import ExperimentTrackingConfig
from arctic_training.config.logger import LoggerConfig
from arctic_training.config.model import ModelConfig
from arctic_training.config.optimizer import OptimizerConfig
Expand All @@ -39,6 +40,11 @@
from arctic_training.data.sft_factory import SFTDataFactory
from arctic_training.data.snowflake_source import SnowflakeDataSource
from arctic_training.data.source import DataSource
from arctic_training.experiment_tracking.snowflake_tracker import SnowflakeExperimentTrackingConfig
from arctic_training.experiment_tracking.snowflake_tracker import SnowflakeExpTracker
from arctic_training.experiment_tracking.tracker import ExperimentTracker
from arctic_training.experiment_tracking.wandb_tracker import WandBExperimentTrackingConfig
from arctic_training.experiment_tracking.wandb_tracker import WandBTracker
from arctic_training.logging import logger
from arctic_training.model.factory import ModelFactory
from arctic_training.model.hf_factory import HFModelFactory
Expand Down
8 changes: 6 additions & 2 deletions arctic_training/checkpoint/ds_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@ def client_state(self) -> Dict[str, Any]:
"np_random_state": np.random.get_state(),
"python_random_state": random.getstate(),
"global_step": self.trainer.global_step,
"wandb_run_id": self.trainer.wandb_run_id,
"experiment_tracker_state": (
self.trainer.experiment_tracker.get_resume_state()
if self.trainer.experiment_tracker is not None
else None
),
}
if self.device != torch.device("cpu"):
state["torch_cuda_random_state"] = torch.cuda.get_rng_state()
Expand Down Expand Up @@ -112,7 +116,7 @@ def load(self, model) -> None:
if self.device != torch.device("cpu"):
torch.cuda.set_rng_state(client_states["torch_cuda_random_state"])

self.trainer.wandb_run_id = client_states["wandb_run_id"]
self.trainer.experiment_tracker_state = client_states.get("experiment_tracker_state")

# Helpful ckpt resume debugging snippet
# norm = model_norm(model)
Expand Down
24 changes: 24 additions & 0 deletions arctic_training/config/experiment_tracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from arctic_training.config.base import BaseConfig


# Settings common to every experiment tracking backend: which backend to use
# (``type``) and whether tracking is switched on at all (``enable``).
class ExperimentTrackingConfig(BaseConfig):
    type: str = "wandb"
    """ Experiment tracking backend type. """

    enable: bool = False
    """ Whether to enable experiment tracking. """
34 changes: 31 additions & 3 deletions arctic_training/config/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from arctic_training.config.checkpoint import CheckpointConfig
from arctic_training.config.data import DataConfig
from arctic_training.config.enums import DType
from arctic_training.config.experiment_tracking import ExperimentTrackingConfig
from arctic_training.config.logger import LoggerConfig
from arctic_training.config.model import ModelConfig
from arctic_training.config.optimizer import OptimizerConfig
Expand All @@ -47,10 +48,10 @@
from arctic_training.config.utils import HumanInt
from arctic_training.config.utils import UniqueKeyLoader
from arctic_training.config.utils import parse_human_val
from arctic_training.config.wandb import WandBConfig
from arctic_training.registry import _get_class_attr_type_hints
from arctic_training.registry import get_registered_checkpoint_engine
from arctic_training.registry import get_registered_data_factory
from arctic_training.registry import get_registered_experiment_tracker
from arctic_training.registry import get_registered_model_factory
from arctic_training.registry import get_registered_optimizer_factory
from arctic_training.registry import get_registered_scheduler_factory
Expand Down Expand Up @@ -88,8 +89,8 @@ class TrainerConfig(BaseConfig):
logger: LoggerConfig = Field(default_factory=LoggerConfig)
""" Logger configuration. """

wandb: WandBConfig = Field(default_factory=WandBConfig)
""" Weights and Biases configuration. """
experiment_tracking: ExperimentTrackingConfig = Field(default_factory=ExperimentTrackingConfig)
""" Experiment tracking configuration (e.g., wandb, snowflake). """

scheduler: SchedulerConfig = Field(default_factory=SchedulerConfig)
""" Scheduler configuration. """
Expand Down Expand Up @@ -336,6 +337,33 @@ def init_tokenizer_config(cls, v: Union[Dict, TokenizerConfig], info: Validation
)
return cast(TokenizerConfig, subconfig)

@model_validator(mode="before")
@classmethod
def migrate_wandb_config(cls, data: Any) -> Any:
    """Accept legacy ``wandb:`` YAML key and convert it to ``experiment_tracking:``.

    Args:
        data: Raw input handed to the model before field validation. Only
            ``dict`` inputs containing a ``"wandb"`` key are rewritten; any
            other value is returned untouched.

    Returns:
        The input itself when no migration is needed, otherwise a new dict
        in which ``"wandb"`` has been removed and, unless an explicit
        ``"experiment_tracking"`` entry already exists, its value has been
        stored under ``"experiment_tracking"``.
    """
    if not isinstance(data, dict) or "wandb" not in data:
        return data

    # Work on a shallow copy so the caller's dict (e.g. a parsed YAML document
    # that may be reused or logged afterwards) is never modified in place.
    migrated = dict(data)
    legacy_config = migrated.pop("wandb")

    # An explicit ``experiment_tracking`` entry wins; the legacy key is dropped.
    if "experiment_tracking" not in migrated:
        if isinstance(legacy_config, dict):
            # Copy before adding the default so the nested input dict is left alone.
            legacy_config = dict(legacy_config)
            legacy_config.setdefault("type", "wandb")
        migrated["experiment_tracking"] = legacy_config
    return migrated

@field_validator("experiment_tracking", mode="before")
@classmethod
def init_experiment_tracking_config(
    cls,
    v: Union[Dict, ExperimentTrackingConfig],
) -> ExperimentTrackingConfig:
    """Build the tracker-specific config object for ``experiment_tracking``.

    Args:
        v: Either an already constructed ``ExperimentTrackingConfig`` (returned
            as is) or a dict of raw settings. Any other value is treated as an
            empty dict, i.e. all defaults.

    Returns:
        An instance of the config class declared by the registered tracker
        whose name matches ``type`` (``"wandb"`` when no type is given).
    """
    if isinstance(v, ExperimentTrackingConfig):
        return v

    # Copy so that stamping "type" below never leaks into the caller's dict.
    config_dict = dict(v) if isinstance(v, dict) else {}
    tracker_type = config_dict.get("type", "wandb")

    # The tracker class declares its config type via the ``config`` annotation;
    # the first type hint found there is the class to instantiate.
    tracker_cls = get_registered_experiment_tracker(tracker_type)
    config_cls = _get_class_attr_type_hints(tracker_cls, "config")[0]

    config_dict["type"] = tracker_type
    return config_cls(**config_dict)

@model_validator(mode="after")
def validate_eval_interval(self) -> Self:
if self.data.eval_sources or self.data.train_eval_split[1] > 0.0:
Expand Down
19 changes: 3 additions & 16 deletions arctic_training/config/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
# Backward compatibility alias — use ExperimentTrackingConfig / WandBExperimentTrackingConfig instead.
from arctic_training.experiment_tracking.wandb_tracker import WandBExperimentTrackingConfig as WandBConfig

from arctic_training.config.base import BaseConfig


# Weights and Biases logging settings: an on/off switch plus the entity,
# project and run name to log under.
class WandBConfig(BaseConfig):
    enable: bool = False
    """ Whether to enable Weights and Biases logging. """

    entity: Optional[str] = None
    """ Weights and Biases entity name. """

    project: Optional[str] = "arctic-training"
    """ Weights and Biases project name. """

    name: Optional[str] = None
    """ Weights and Biases run name. """
__all__ = ["WandBConfig"]
14 changes: 14 additions & 0 deletions arctic_training/experiment_tracking/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
119 changes: 119 additions & 0 deletions arctic_training/experiment_tracking/snowflake_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import Optional

from arctic_training.config.experiment_tracking import ExperimentTrackingConfig
from arctic_training.experiment_tracking.tracker import ExperimentTracker
from arctic_training.logging import logger

if TYPE_CHECKING:
from arctic_training.trainer.trainer import Trainer


class SnowflakeExperimentTrackingConfig(ExperimentTrackingConfig):
type: str = "snowflake"

account: str = ""
""" Snowflake account identifier. """

user: str = ""
""" Snowflake user name. """

password: str = ""
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On the managed API path, SPCS will be responsible for setting the config, and an SPCS job only has an OAuth token available. Could we also support OAuth authentication, i.e. token = <oauth_token> with authenticator = 'oauth'?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, let me update for different auth types and I will let you test!

""" Snowflake PAT or password. """

role: str = ""
""" Snowflake role. """

warehouse: str = ""
""" Snowflake warehouse. """

database: str = ""
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, are db and schema mandatory here?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% sure. There might be a default - I will verify.

""" Snowflake database. """

schema_name: str = ""
""" Snowflake schema. """

experiment_name: str = ""
""" Name of the experiment in Snowflake experiment tracking. """

run_name: Optional[str] = None
""" Name of the run. If not set, one will be generated. """

@property
def connection_params(self):
    """Return the connection settings of this config as a plain dict.

    The keys are ``account``, ``user``, ``authentication``, ``password``,
    ``role``, ``warehouse``, ``database`` and ``schema``; ``schema`` is filled
    from the ``schema_name`` field and ``authentication`` is always ``"PAT"``.
    """
    # NOTE(review): the Snowflake connector documents an ``authenticator``
    # parameter; confirm that ``authentication`` is actually honoured.
    return {
        "account": self.account,
        "user": self.user,
        "authentication": "PAT",
        "password": self.password,
        "role": self.role,
        "warehouse": self.warehouse,
        "database": self.database,
        "schema": self.schema_name,
    }


class SnowflakeExpTracker(ExperimentTracker):
    """Experiment tracker that reports to Snowflake experiment tracking.

    A Snowpark session is opened in :meth:`start` from the config's
    ``connection_params`` and closed again in :meth:`finish`. Logging calls
    made before ``start`` or after ``finish`` are silently ignored.
    """

    name: str = "snowflake"
    config: SnowflakeExperimentTrackingConfig

    def __init__(self, trainer: "Trainer", config: SnowflakeExperimentTrackingConfig) -> None:
        """Store the config; no connection is made until :meth:`start`."""
        super().__init__(trainer, config)
        self._experiment: Any = None
        # Kept so that finish() can close the session this tracker opened.
        self._session: Any = None
        self._run_name: Optional[str] = config.run_name

    def start(self, run_config: Dict[str, Any]) -> None:
        """Open a session, select the experiment, start the run and log ``run_config``.

        Raises:
            ImportError: If the Snowflake ML / Snowpark packages are not installed.
        """
        # Imported lazily so the packages are only required when this tracker is used.
        try:
            from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
            from snowflake.snowpark import Session
        except ImportError as err:
            # Chain the original error so the actually missing module stays visible.
            raise ImportError(
                "Snowflake experiment tracking requires the snowflake-ml-python and "
                "snowflake-snowpark-python packages. Install them with:\n"
                "  pip install snowflake-ml-python snowflake-snowpark-python"
            ) from err

        # Remember the session right away so finish() can release it even if a
        # later step in this method raises.
        self._session = Session.builder.configs(self.config.connection_params).create()
        self._experiment = ExperimentTracking(session=self._session)
        self._experiment.set_experiment(self.config.experiment_name)

        self._experiment.start_run(self._run_name)
        self._experiment.log_params(run_config)
        logger.info(f"Snowflake experiment tracking started: {self.config.experiment_name}/{self._run_name}")

    def log_metrics(self, metrics: Dict[str, Any], step: int) -> None:
        """Forward ``metrics`` for ``step`` to the active run, if there is one."""
        if self._experiment is not None:
            self._experiment.log_metrics(metrics, step=step)

    def log_params(self, params: Dict[str, Any]) -> None:
        """Forward ``params`` to the active run, if there is one."""
        if self._experiment is not None:
            self._experiment.log_params(params)

    def finish(self) -> None:
        """End the run and close the session opened by :meth:`start`.

        Safe to call more than once and safe to call when ``start`` never ran
        or failed part-way through.
        """
        try:
            if self._experiment is not None:
                self._experiment.end_run()
        finally:
            # Drop the handles first so a failing close() cannot leave stale state.
            self._experiment = None
            session, self._session = self._session, None
            if session is not None:
                session.close()

    def get_resume_state(self) -> Optional[Dict[str, Any]]:
        """Return ``{"run_name": ...}`` for checkpointing, or ``None`` if no run name is set."""
        if self._run_name is None:
            return None
        return {"run_name": self._run_name}

    def set_resume_state(self, state: Dict[str, Any]) -> None:
        """Restore the run name from a checkpointed ``state``.

        A missing state (``None`` or empty, e.g. from a checkpoint written
        without tracker state) or a state without a run name leaves the
        currently configured run name untouched.
        """
        if not state:
            return
        run_name = state.get("run_name")
        if run_name is not None:
            self._run_name = run_name
95 changes: 95 additions & 0 deletions arctic_training/experiment_tracking/tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import Optional

from arctic_training.config.experiment_tracking import ExperimentTrackingConfig
from arctic_training.registry import RegistryMeta
from arctic_training.registry import _validate_class_attribute_set
from arctic_training.registry import _validate_class_attribute_type
from arctic_training.registry import _validate_class_method

if TYPE_CHECKING:
from arctic_training.trainer.trainer import Trainer


class ExperimentTracker(ABC, metaclass=RegistryMeta):
    """Abstract interface implemented by every experiment tracking backend."""

    name: str
    """
    Registration name of the tracker. It should be unique, since configs
    refer to the tracker by this name.
    """

    config: ExperimentTrackingConfig
    """
    Config class used by the tracker; expected to be a subclass of
    ExperimentTrackingConfig.
    """

    @classmethod
    def _validate_subclass(cls) -> None:
        """Check that a subclass declares the required attributes and method signatures."""
        _validate_class_attribute_set(cls, "name")
        _validate_class_attribute_type(cls, "config", ExperimentTrackingConfig)

        # Method name -> parameter names the subclass implementation must expose.
        required_signatures = (
            ("start", ["self", "run_config"]),
            ("log_metrics", ["self", "metrics", "step"]),
            ("log_params", ["self", "params"]),
            ("finish", ["self"]),
            ("get_resume_state", ["self"]),
            ("set_resume_state", ["self", "state"]),
        )
        for method_name, param_names in required_signatures:
            _validate_class_method(cls, method_name, param_names)

    def __init__(self, trainer: "Trainer", config: ExperimentTrackingConfig) -> None:
        """Remember the owning trainer and the tracker configuration."""
        self._trainer = trainer
        self.config = config

    @property
    def trainer(self) -> "Trainer":
        """The trainer this tracker was created with."""
        return self._trainer

    @abstractmethod
    def start(self, run_config: Dict[str, Any]) -> None:
        """Set up the backend and begin a tracking run for ``run_config``."""
        raise NotImplementedError

    @abstractmethod
    def log_metrics(self, metrics: Dict[str, Any], step: int) -> None:
        """Record ``metrics`` against the given ``step``."""
        raise NotImplementedError

    @abstractmethod
    def log_params(self, params: Dict[str, Any]) -> None:
        """Record run parameters / configuration values."""
        raise NotImplementedError

    @abstractmethod
    def finish(self) -> None:
        """Close out the tracking run."""
        raise NotImplementedError

    @abstractmethod
    def get_resume_state(self) -> Optional[Dict[str, Any]]:
        """Return whatever state is needed to resume this tracker after a checkpoint restart."""
        raise NotImplementedError

    @abstractmethod
    def set_resume_state(self, state: Dict[str, Any]) -> None:
        """Restore tracker state previously saved in a checkpoint."""
        raise NotImplementedError
Loading
Loading