Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions python/arcticdb/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,17 +191,35 @@ class OutputFormat(str, Enum):
PYARROW = "PYARROW"
POLARS = "POLARS"

@staticmethod
def resolve(
    value: Union["OutputFormat", str, None], default: Optional["OutputFormat"] = None
) -> Optional["OutputFormat"]:
    """Convert a string or OutputFormat to an OutputFormat enum value.

    String matching is case-insensitive for backwards compatibility.

    Parameters
    ----------
    value
        An OutputFormat member, the name of a member as a string (any letter case), or None.
    default
        Returned unchanged when ``value`` is None. May itself be None.

    Returns
    -------
    Optional[OutputFormat]
        ``default`` when ``value`` is None, ``value`` itself when it is already an OutputFormat,
        otherwise the member whose value equals ``value.upper()``.

    Raises
    ------
    ValueError
        If ``value`` is a string that matches no member, or is neither a string, an OutputFormat
        nor None.
    """
    if value is None:
        return default
    if isinstance(value, OutputFormat):
        return value
    # Only real strings are upper-cased and looked up. Anything else (ints, bytes, ...) is rejected
    # explicitly instead of relying on an AttributeError raised by a missing ``.upper()``.
    if isinstance(value, str):
        try:
            return OutputFormat(value.upper())
        except ValueError:
            # Fall through to the single error below so the caller sees one clean ValueError
            # rather than a chained "During handling of the above exception" traceback.
            pass
    raise ValueError(f"Unknown OutputFormat: {value!r}. Expected OutputFormat enum or string.")


def output_format_to_internal(output_format: Union[OutputFormat, str]) -> InternalOutputFormat:
if output_format.lower() == OutputFormat.PANDAS.lower():
fmt = OutputFormat.resolve(output_format)
if fmt == OutputFormat.PANDAS:
return InternalOutputFormat.PANDAS
elif output_format.lower() == OutputFormat.PYARROW.lower():
elif fmt == OutputFormat.PYARROW:
if not _PYARROW_AVAILABLE:
raise ModuleNotFoundError(
"ArcticDB's pyarrow optional dependency missing but is required to use arrow output format."
)
return InternalOutputFormat.ARROW
elif output_format.lower() == OutputFormat.POLARS.lower():
elif fmt == OutputFormat.POLARS:
if not _PYARROW_AVAILABLE or not _POLARS_AVAILABLE:
raise ModuleNotFoundError(
"ArcticDB's pyarrow or polars optional dependencies are missing but are required to use polars output format."
Comment thread
jamesblackburn marked this conversation as resolved.
Expand Down Expand Up @@ -267,7 +285,7 @@ def arrow_output_string_format_to_internal(
or _PYARROW_AVAILABLE
and arrow_string_format == pa.string()
):
if output_format.lower() == OutputFormat.POLARS.lower():
if OutputFormat.resolve(output_format) == OutputFormat.POLARS:
raise ValueError(
"SMALL_STRING is not supported with POLARS output format. Please use LARGE_STRING instead."
)
Expand All @@ -283,11 +301,11 @@ def __init__(
output_format: Union[OutputFormat, str] = OutputFormat.PANDAS,
arrow_string_format_default: ArrowOutputStringFormat = ArrowOutputStringFormat.LARGE_STRING,
):
self.output_format = output_format
self.output_format = OutputFormat.resolve(output_format)
self.arrow_string_format_default = arrow_string_format_default

def set_output_format(self, output_format: Union[OutputFormat, str]):
self.output_format = output_format
self.output_format = OutputFormat.resolve(output_format)

def set_arrow_string_format_default(self, arrow_string_format_default: ArrowOutputStringFormat):
self.arrow_string_format_default = arrow_string_format_default
Expand Down
10 changes: 4 additions & 6 deletions python/arcticdb/version_store/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2283,15 +2283,13 @@ def _get_read_queries(

return read_queries

def _get_read_options_and_output_format(
self, **kwargs
) -> Tuple[_PythonVersionStoreReadOptions, Union[OutputFormat, str]]:
def _get_read_options_and_output_format(self, **kwargs) -> Tuple[_PythonVersionStoreReadOptions, OutputFormat]:
proto_cfg = self._lib_cfg.lib_desc.version.write_options
read_options = _PythonVersionStoreReadOptions()
read_options.set_force_strings_to_object(_assume_false("force_string_to_object", kwargs))
read_options.set_optimise_string_memory(_assume_false("optimise_string_memory", kwargs))
output_format = self.resolve_runtime_defaults(
"output_format", proto_cfg, global_default=OutputFormat.PANDAS, **kwargs
output_format = OutputFormat.resolve(
self.resolve_runtime_defaults("output_format", proto_cfg, global_default=OutputFormat.PANDAS, **kwargs)
)
read_options.set_output_format(output_format_to_internal(output_format))
read_options.set_dynamic_schema(resolve_defaults("dynamic_schema", proto_cfg, global_default=False, **kwargs))
Expand Down Expand Up @@ -2895,7 +2893,7 @@ def _adapt_frame_data(self, frame_data, norm, output_format):
)
if self._test_convert_arrow_back_to_pandas:
data = convert_arrow_to_pandas_for_tests(data)
if output_format.lower() == OutputFormat.POLARS.lower():
if output_format == OutputFormat.POLARS:
data = pl.from_arrow(data, rechunk=False)
else:
data = self._normalizer.denormalize(frame_data, norm)
Expand Down
7 changes: 4 additions & 3 deletions python/tests/unit/arcticdb/test_arrow_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ def expected_output_type(arctic_output_format, library_output_format, output_for
expected_output_format = (
output_format_override or library_output_format or arctic_output_format or OutputFormat.PANDAS
)
if expected_output_format.lower() == OutputFormat.PANDAS.lower():
fmt = OutputFormat.resolve(expected_output_format, default=OutputFormat.PANDAS)
if fmt == OutputFormat.PANDAS:
return pd.DataFrame
if expected_output_format.lower() == OutputFormat.PYARROW.lower():
if fmt == OutputFormat.PYARROW:
return pa.Table
if expected_output_format.lower() == OutputFormat.POLARS.lower():
if fmt == OutputFormat.POLARS:
return pl.DataFrame
raise ValueError("Unexpected format")

Expand Down
68 changes: 68 additions & 0 deletions python/tests/unit/arcticdb/test_output_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import pytest

from arcticdb.options import OutputFormat, RuntimeOptions, output_format_to_internal
from arcticdb_ext.version_store import InternalOutputFormat


@pytest.mark.parametrize(
    "value",
    [
        OutputFormat.PANDAS,
        OutputFormat.PYARROW,
        OutputFormat.POLARS,
    ],
)
def test_resolve_enum_passthrough(value):
    """An OutputFormat member handed to resolve() comes back as the very same object."""
    resolved = OutputFormat.resolve(value)
    assert resolved is value


@pytest.mark.parametrize(
    "value, expected",
    [
        # Upper, lower and mixed case spellings of each member name.
        ("PANDAS", OutputFormat.PANDAS),
        ("pandas", OutputFormat.PANDAS),
        ("Pandas", OutputFormat.PANDAS),
        ("PYARROW", OutputFormat.PYARROW),
        ("pyarrow", OutputFormat.PYARROW),
        ("PyArrow", OutputFormat.PYARROW),
        ("POLARS", OutputFormat.POLARS),
        ("polars", OutputFormat.POLARS),
    ],
)
def test_resolve_case_insensitive_strings(value, expected):
    """A member name given as a string resolves to that member regardless of letter case."""
    resolved = OutputFormat.resolve(value)
    assert resolved == expected


def test_resolve_none_returns_none():
    """Resolving None without supplying a default yields None."""
    resolved = OutputFormat.resolve(None)
    assert resolved is None


def test_resolve_none_with_default():
    """Resolving None hands back the caller-supplied default."""
    fallback = OutputFormat.PANDAS
    resolved = OutputFormat.resolve(None, default=fallback)
    assert resolved == OutputFormat.PANDAS


@pytest.mark.parametrize(
    "value",
    ["INVALID", "", "arrow", "dataframe", 123],
)
def test_resolve_invalid_raises(value):
    """Unrecognised strings and non-string inputs are rejected with a ValueError."""
    expect_unknown_format = pytest.raises(ValueError, match="Unknown OutputFormat")
    with expect_unknown_format:
        OutputFormat.resolve(value)


def test_output_format_to_internal_pandas():
    """The PANDAS member maps to the internal PANDAS format."""
    internal = output_format_to_internal(OutputFormat.PANDAS)
    assert internal == InternalOutputFormat.PANDAS


def test_output_format_to_internal_pyarrow():
    """The PYARROW member maps to the internal ARROW format."""
    internal = output_format_to_internal(OutputFormat.PYARROW)
    assert internal == InternalOutputFormat.ARROW


def test_output_format_to_internal_string_input():
    """A lower-case string name is accepted and maps like the corresponding member."""
    internal = output_format_to_internal("pyarrow")
    assert internal == InternalOutputFormat.ARROW


def test_runtime_options_default_output_format():
    """RuntimeOptions built with no arguments reports PANDAS as its output format."""
    default_options = RuntimeOptions()
    configured = default_options.output_format
    assert configured == OutputFormat.PANDAS


def test_runtime_options_string_resolved_on_init():
    """A string passed to the constructor is exposed as the matching OutputFormat value."""
    options = RuntimeOptions(output_format="pyarrow")
    configured = options.output_format
    assert configured == OutputFormat.PYARROW


def test_runtime_options_set_output_format_resolves_string():
    """A string passed to set_output_format() is exposed as the matching OutputFormat value."""
    options = RuntimeOptions()
    options.set_output_format("polars")
    configured = options.output_format
    assert configured == OutputFormat.POLARS
Loading