Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
351 changes: 287 additions & 64 deletions library/README.md

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions library/docs/source/guide/get_started/api_tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ If you want to use other models offered by Geti Library, you can get a list of a

.. code-block:: python

from getitune.backend.lightning.cli.utils import list_models
from getitune.utils import list_models

model_lists = list_models(task="DETECTION")
print(model_lists)
Expand Down Expand Up @@ -113,7 +113,7 @@ If you want to use other models offered by Geti Library, you can get a list of a

.. code-block:: python

from getitune.backend.lightning.cli.utils import list_models
from getitune.utils import list_models

model_lists = list_models(task="DETECTION", print_table=True)

Expand All @@ -132,7 +132,7 @@ If you want to use other models offered by Geti Library, you can get a list of a

.. code-block:: python

from getitune.backend.lightning.cli.utils import list_models
from getitune.utils import list_models

model_lists = list_models(task="DETECTION", pattern="tile")
print(model_lists)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ The list of supported recipes for classification is available with the command l

.. code-block:: python

from getitune.backend.lightning.cli.utils import list_models
from getitune.utils import list_models

model_lists = list_models(task="MULTI_CLASS_CLS", pattern="*efficient")
print(model_lists)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ The list of supported recipes for object detection is available with the command

.. code-block:: python

from getitune.backend.lightning.cli.utils import list_models
from getitune.utils import list_models

model_lists = list_models(task="DETECTION", pattern="atss")
print(model_lists)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ The list of supported recipes for instance segmentation is available with the co

.. code-block:: python

from getitune.backend.lightning.cli.utils import list_models
from getitune.utils import list_models

model_lists = list_models(task="INSTANCE_SEGMENTATION")
print(model_lists)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ The list of supported recipes for semantic segmentation is available with the co

.. code-block:: python

from getitune.backend.lightning.cli.utils import list_models
from getitune.utils import list_models

model_lists = list_models(task="SEMANTIC_SEGMENTATION")
print(model_lists)
Expand Down
11 changes: 0 additions & 11 deletions library/src/getitune/backend/lightning/cli/__init__.py

This file was deleted.

118 changes: 60 additions & 58 deletions library/src/getitune/backend/lightning/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,84 +709,86 @@ def dummy_infer(model: LightningModel, batch_size: int = 1) -> float:
def from_config(
cls,
config_path: PathLike,
data_root: PathLike | None = None,
data: DataModule | PathLike | None = None,
work_dir: PathLike | None = None,
device: str | None = None,
checkpoint: str | None = None,
task: str | None = None,
**kwargs,
) -> LightningEngine:
"""Builds the engine from a configuration file.

Args:
config_path (PathLike): The configuration file path.
data_root (PathLike | None): Root directory for the data.
Defaults to None. If data_root is None, use the data_root from the configuration file.
data (DataModule | PathLike | None): Either a pre-built
:class:`~getitune.data.module.DataModule` or a root directory
path for the data. When *None*, the data root is read from
the configuration file.
work_dir (PathLike | None, optional): Working directory for the engine.
Defaults to None. If work_dir is None, use the work_dir from the configuration file.
kwargs: Arguments that can override the engine's arguments.
device (str | None, optional): Device to use (e.g., ``"auto"``, ``"xpu"``, ``"cpu"``, ``"gpu"``).
Defaults to None.
checkpoint (str | None, optional): Path to a checkpoint for pretrained or warm-start weights.
Defaults to None.
task (str | None, optional): Task type for disambiguation. Defaults to None.
kwargs: Backend-specific keyword arguments forwarded to the engine constructor.

Returns:
Engine: An instance of the Engine class.
LightningEngine: An instance of the LightningEngine class.

Example:
>>> engine = LightningEngine.from_config(
... config="config.yaml",
... config_path="config.yaml",
... data="/path/to/dataset",
... )
... engine.train()
"""
from getitune.cli.utils.jsonargparse import get_instantiated_classes

# For the Engine argument, prepend 'engine.' for CLI parser
filter_kwargs = ["device", "checkpoint", "task"]
for key in filter_kwargs:
if key in kwargs:
kwargs[f"engine.{key}"] = kwargs.pop(key)
# Separate a pre-built DataModule from a plain data-root path so that
# the CLI parser only receives a file-system path (or None).
provided_datamodule: DataModule | None = None
if isinstance(data, DataModule):
provided_datamodule = data
# Give the CLI parser the DataModule's own data_root so it can
# still instantiate the model with the correct label_info from
# the dataset. The config-parsed DataModule is discarded below.
data_root: PathLike | None = getattr(data, "data_root", None)
else:
data_root = data # PathLike | None

# Route common args explicitly under the process's "engine" namespace.
process_kwargs: dict[str, object] = dict(kwargs)
if device is not None:
process_kwargs["engine.device"] = device
if checkpoint is not None:
process_kwargs["engine.checkpoint"] = checkpoint
if task is not None:
process_kwargs["engine.task"] = task
Comment thread
kprokofi marked this conversation as resolved.
Outdated
instantiated_config, train_kwargs = get_instantiated_classes(
config=config_path,
data_root=data_root,
work_dir=work_dir,
**kwargs,
**process_kwargs,
)
engine_kwargs = {**instantiated_config.get("engine", {}), **train_kwargs}

# Remove any input that is not currently available in Engine and print a warning message.
set_valid_args = TrainerArgumentsCache.get_trainer_constructor_args().union(
set(inspect.signature(LightningEngine.__init__).parameters.keys()),
valid_keys = TrainerArgumentsCache.get_trainer_constructor_args() | set(
inspect.signature(LightningEngine.__init__).parameters,
)
# Keys that are legitimate train/test/export method kwargs. They come
# from OTXCLI.prepare_subcommand_kwargs("train") being merged into the
# engine kwargs, and should be silently dropped here -- warning about
# them is noise, because they are consumed at method-call time by the
# caller, not by the engine ctor.
known_method_kwargs = {
"resume",
"adaptive_bs",
"max_epochs",
"run_hpo",
"hpo_config",
"checkpoint",
"export_demo_package",
"export_format",
"explain",
"dump_options",
"precision",
}
Comment thread
kprokofi marked this conversation as resolved.
removed_args = []
for engine_key in list(engine_kwargs.keys()):
if engine_key not in set_valid_args:
engine_kwargs.pop(engine_key)
if engine_key not in known_method_kwargs:
removed_args.append(engine_key)
if removed_args:
msg = (
f"Warning: {removed_args} -> not available in Engine constructor. "
"It will be ignored. Use what need in the right places."
)
warn(msg, stacklevel=1)
merged = {**instantiated_config.get("engine", {}), **train_kwargs}
engine_kwargs = {k: v for k, v in merged.items() if k in valid_keys}

if (datamodule := instantiated_config.get("data")) is None:
msg = "Cannot instantiate datamodule from config."
raise ValueError(msg)
if not isinstance(datamodule, DataModule):
raise TypeError(datamodule)
# Use the caller-supplied DataModule when provided; otherwise build
# one from the parsed configuration.
if provided_datamodule is not None:
datamodule: DataModule = provided_datamodule
else:
raw_data: DataModule | Any = instantiated_config.get("data")
if raw_data is None:
msg = "Cannot instantiate datamodule from config."
raise ValueError(msg)
if not isinstance(raw_data, DataModule):
raise TypeError(raw_data)
datamodule = raw_data

if (model := instantiated_config.get("model")) is None:
msg = "Cannot instantiate model from config."
Expand Down Expand Up @@ -814,7 +816,7 @@ def from_model_name(
cls,
model_name: str,
task: TaskType,
data_root: PathLike | None = None,
data: PathLike | None = None,
work_dir: PathLike | None = None,
**kwargs,
) -> LightningEngine:
Expand All @@ -823,8 +825,8 @@ def from_model_name(
Args:
model_name (str): The model name.
task (TaskType): The type of getitune task.
data_root (PathLike | None): Root directory for the data.
Defaults to None. If data_root is None, use the data_root from the configuration file.
data (PathLike | None): Root directory for the data.
Defaults to None. If data is None, use the data_root from the configuration file.
work_dir (PathLike | None, optional): Working directory for the engine.
Defaults to None. If work_dir is None, use the work_dir from the configuration file.
kwargs: Arguments that can override the engine's arguments.
Expand All @@ -836,7 +838,7 @@ def from_model_name(
>>> engine = LightningEngine.from_model_name(
... model_name="atss_mobilenetv2",
... task="DETECTION",
... data_root=<dataset/path>,
... data=<dataset/path>,
... )
... engine.train()

Expand All @@ -848,7 +850,7 @@ def from_model_name(
>>> engine = LightningEngine(
... model_name="atss_mobilenetv2",
... task="DETECTION",
... data_root=<dataset/path>,
... data=<dataset/path>,
... **overriding,
... )
"""
Expand All @@ -866,7 +868,7 @@ def from_model_name(

return cls.from_config(
config_path=config,
data_root=data_root,
data=data,
work_dir=work_dir,
**kwargs,
)
Expand Down
35 changes: 35 additions & 0 deletions library/src/getitune/backend/openvino/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,41 @@ def is_supported(model: MODEL, data: DATA) -> bool:

return check_model and check_data

@classmethod
def from_config(
cls,
config_path: PathLike,
data: DataModule | PathLike | None = None,
work_dir: PathLike | None = None,
device: str | None = None,
checkpoint: str | None = None,
task: str | None = None,
**kwargs,
) -> OVEngine:
"""OVEngine does not support construction from a recipe config.

OpenVINO models are selected by passing a ``.xml`` or ``.onnx``
weights path directly to :func:`~getitune.engine.create_engine` or
as the *model* argument to :class:`OVEngine`.

Args:
config_path: Unused — included for API compatibility.
data: Unused — included for API compatibility.
work_dir: Unused — included for API compatibility.
device: Unused — included for API compatibility.
checkpoint: Unused — included for API compatibility.
task: Unused — included for API compatibility.
**kwargs: Unused — included for API compatibility.

Raises:
NotImplementedError: Always raised.
"""
msg = (
f"OVEngine does not support construction from a recipe config '{config_path}'. "
"Pass a .xml or .onnx model path directly to create_engine() instead."
)
raise NotImplementedError(msg)

def _update_checkpoint(self, checkpoint: PathLike | None) -> OVModel:
"""Update the OVModel with the given checkpoint path.

Expand Down
Loading