diff --git a/python/arcticdb/options.py b/python/arcticdb/options.py
index e2d44f0db49..9eb3848b2c1 100644
--- a/python/arcticdb/options.py
+++ b/python/arcticdb/options.py
@@ -191,17 +191,35 @@ class OutputFormat(str, Enum):
     PYARROW = "PYARROW"
     POLARS = "POLARS"
 
+    @staticmethod
+    def resolve(value: Union["OutputFormat", str, None], default: "OutputFormat" = None) -> Optional["OutputFormat"]:
+        """Convert a string or OutputFormat to an OutputFormat enum value.
+
+        Case-insensitive string matching for backwards compatibility.
+        Returns default (which may be None) when value is None.
+        Raises ValueError for unknown strings and for values that are neither str nor OutputFormat.
+        """
+        if value is None:
+            return default
+        if isinstance(value, OutputFormat):
+            return value
+        try:
+            return OutputFormat(value.upper())
+        except (ValueError, AttributeError):  # AttributeError: value has no .upper(), e.g. an int
+            raise ValueError(f"Unknown OutputFormat: {value!r}. Expected OutputFormat enum or string.") from None
+
 
 def output_format_to_internal(output_format: Union[OutputFormat, str]) -> InternalOutputFormat:
-    if output_format.lower() == OutputFormat.PANDAS.lower():
+    fmt = OutputFormat.resolve(output_format)
+    if fmt == OutputFormat.PANDAS:
         return InternalOutputFormat.PANDAS
-    elif output_format.lower() == OutputFormat.PYARROW.lower():
+    elif fmt == OutputFormat.PYARROW:
         if not _PYARROW_AVAILABLE:
             raise ModuleNotFoundError(
                 "ArcticDB's pyarrow optional dependency missing but is required to use arrow output format."
             )
         return InternalOutputFormat.ARROW
-    elif output_format.lower() == OutputFormat.POLARS.lower():
+    elif fmt == OutputFormat.POLARS:
         if not _PYARROW_AVAILABLE or not _POLARS_AVAILABLE:
             raise ModuleNotFoundError(
                 "ArcticDB's pyarrow or polars optional dependencies are missing but are required to use polars output format."
@@ -267,7 +285,7 @@ def arrow_output_string_format_to_internal(
         or _PYARROW_AVAILABLE
         and arrow_string_format == pa.string()
     ):
-        if output_format.lower() == OutputFormat.POLARS.lower():
+        if OutputFormat.resolve(output_format) == OutputFormat.POLARS:
             raise ValueError(
                 "SMALL_STRING is not supported with POLARS output format. Please use LARGE_STRING instead."
             )
@@ -283,11 +301,11 @@ def __init__(
         output_format: Union[OutputFormat, str] = OutputFormat.PANDAS,
         arrow_string_format_default: ArrowOutputStringFormat = ArrowOutputStringFormat.LARGE_STRING,
     ):
-        self.output_format = output_format
+        self.output_format = OutputFormat.resolve(output_format)
         self.arrow_string_format_default = arrow_string_format_default
 
     def set_output_format(self, output_format: Union[OutputFormat, str]):
-        self.output_format = output_format
+        self.output_format = OutputFormat.resolve(output_format)
 
     def set_arrow_string_format_default(self, arrow_string_format_default: ArrowOutputStringFormat):
         self.arrow_string_format_default = arrow_string_format_default
diff --git a/python/arcticdb/version_store/_store.py b/python/arcticdb/version_store/_store.py
index f0222f38914..78f922b1fd2 100644
--- a/python/arcticdb/version_store/_store.py
+++ b/python/arcticdb/version_store/_store.py
@@ -2283,15 +2283,13 @@ def _get_read_queries(
 
         return read_queries
 
-    def _get_read_options_and_output_format(
-        self, **kwargs
-    ) -> Tuple[_PythonVersionStoreReadOptions, Union[OutputFormat, str]]:
+    def _get_read_options_and_output_format(self, **kwargs) -> Tuple[_PythonVersionStoreReadOptions, OutputFormat]:
         proto_cfg = self._lib_cfg.lib_desc.version.write_options
         read_options = _PythonVersionStoreReadOptions()
         read_options.set_force_strings_to_object(_assume_false("force_string_to_object", kwargs))
         read_options.set_optimise_string_memory(_assume_false("optimise_string_memory", kwargs))
-        output_format = self.resolve_runtime_defaults(
-            "output_format", proto_cfg, global_default=OutputFormat.PANDAS, **kwargs
+        output_format = OutputFormat.resolve(
+            self.resolve_runtime_defaults("output_format", proto_cfg, global_default=OutputFormat.PANDAS, **kwargs)
         )
         read_options.set_output_format(output_format_to_internal(output_format))
         read_options.set_dynamic_schema(resolve_defaults("dynamic_schema", proto_cfg, global_default=False, **kwargs))
@@ -2895,7 +2893,7 @@ def _adapt_frame_data(self, frame_data, norm, output_format):
             )
             if self._test_convert_arrow_back_to_pandas:
                 data = convert_arrow_to_pandas_for_tests(data)
-            if output_format.lower() == OutputFormat.POLARS.lower():
+            if OutputFormat.resolve(output_format) == OutputFormat.POLARS:  # keep accepting raw strings
                 data = pl.from_arrow(data, rechunk=False)
         else:
             data = self._normalizer.denormalize(frame_data, norm)
diff --git a/python/tests/unit/arcticdb/test_arrow_api.py b/python/tests/unit/arcticdb/test_arrow_api.py
index 135fbfb0f82..a83e76401b7 100644
--- a/python/tests/unit/arcticdb/test_arrow_api.py
+++ b/python/tests/unit/arcticdb/test_arrow_api.py
@@ -35,11 +35,12 @@ def expected_output_type(arctic_output_format, library_output_format, output_for
     expected_output_format = (
         output_format_override or library_output_format or arctic_output_format or OutputFormat.PANDAS
     )
-    if expected_output_format.lower() == OutputFormat.PANDAS.lower():
+    fmt = OutputFormat.resolve(expected_output_format, default=OutputFormat.PANDAS)
+    if fmt == OutputFormat.PANDAS:
         return pd.DataFrame
-    if expected_output_format.lower() == OutputFormat.PYARROW.lower():
+    if fmt == OutputFormat.PYARROW:
         return pa.Table
-    if expected_output_format.lower() == OutputFormat.POLARS.lower():
+    if fmt == OutputFormat.POLARS:
         return pl.DataFrame
     raise ValueError("Unexpected format")
 
diff --git a/python/tests/unit/arcticdb/test_output_format.py b/python/tests/unit/arcticdb/test_output_format.py
new file mode 100644
index 00000000000..c796643a3b0
--- /dev/null
+++ b/python/tests/unit/arcticdb/test_output_format.py
@@ -0,0 +1,70 @@
+import pytest
+
+from arcticdb.options import OutputFormat, RuntimeOptions, output_format_to_internal
+from arcticdb_ext.version_store import InternalOutputFormat
+
+
+@pytest.mark.parametrize("value", [OutputFormat.PANDAS, OutputFormat.PYARROW, OutputFormat.POLARS])
+def test_resolve_enum_passthrough(value):
+    assert OutputFormat.resolve(value) is value
+
+
+@pytest.mark.parametrize(
+    "value, expected",
+    [
+        ("PANDAS", OutputFormat.PANDAS),
+        ("pandas", OutputFormat.PANDAS),
+        ("Pandas", OutputFormat.PANDAS),
+        ("PYARROW", OutputFormat.PYARROW),
+        ("pyarrow", OutputFormat.PYARROW),
+        ("PyArrow", OutputFormat.PYARROW),
+        ("POLARS", OutputFormat.POLARS),
+        ("polars", OutputFormat.POLARS),
+    ],
+)
+def test_resolve_case_insensitive_strings(value, expected):
+    assert OutputFormat.resolve(value) == expected
+
+
+def test_resolve_none_returns_none():
+    assert OutputFormat.resolve(None) is None
+
+
+def test_resolve_none_with_default():
+    assert OutputFormat.resolve(None, default=OutputFormat.PANDAS) == OutputFormat.PANDAS
+
+
+@pytest.mark.parametrize("value", ["INVALID", "", "arrow", "dataframe", 123])
+def test_resolve_invalid_raises(value):
+    with pytest.raises(ValueError, match="Unknown OutputFormat"):
+        OutputFormat.resolve(value)
+
+
+def test_output_format_to_internal_pandas():
+    assert output_format_to_internal(OutputFormat.PANDAS) == InternalOutputFormat.PANDAS
+
+
+def test_output_format_to_internal_pyarrow():
+    pytest.importorskip("pyarrow")  # the conversion raises ModuleNotFoundError without the optional dependency
+    assert output_format_to_internal(OutputFormat.PYARROW) == InternalOutputFormat.ARROW
+
+
+def test_output_format_to_internal_string_input():
+    pytest.importorskip("pyarrow")  # the conversion raises ModuleNotFoundError without the optional dependency
+    assert output_format_to_internal("pyarrow") == InternalOutputFormat.ARROW
+
+
+def test_runtime_options_default_output_format():
+    opts = RuntimeOptions()
+    assert opts.output_format == OutputFormat.PANDAS
+
+
+def test_runtime_options_string_resolved_on_init():
+    opts = RuntimeOptions(output_format="pyarrow")
+    assert opts.output_format == OutputFormat.PYARROW
+
+
+def test_runtime_options_set_output_format_resolves_string():
+    opts = RuntimeOptions()
+    opts.set_output_format("polars")
+    assert opts.output_format == OutputFormat.POLARS