Skip to content

Commit 6e68c9c

Browse files
authored
Initial work for file format writer API (#3119)
Initial work for #3100. Since this is a large change, doing it in parts similar to the `AuthManager` so it's easier to review and move the existing code around. # Rationale for this change Introduces the pluggable file format writer API: `FileFormatWriter`, `FileFormatModel`, and `FileFormatFactory` in `pyiceberg/io/fileformat.py`. Moves `DataFileStatistics` from `pyarrow.py` with a re-export for backward compatibility. The move is more forward looking and the idea is to keep the stats generic in the future as we add additional formats too. This is the first part of work for #3100. No behavioral changes; the write path remains hardcoded to Parquet. ## Are these changes tested? Yes. `tests/io/test_fileformat.py` tests backward-compatible import of `DataFileStatistics` ## Are there any user-facing changes? No
1 parent 240e519 commit 6e68c9c

3 files changed

Lines changed: 265 additions & 74 deletions

File tree

pyiceberg/io/fileformat.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""File Format API for writing Iceberg data files."""
19+
20+
from __future__ import annotations
21+
22+
from abc import ABC, abstractmethod
23+
from dataclasses import dataclass
24+
from typing import TYPE_CHECKING, Any
25+
26+
from pyiceberg.io import OutputFile
27+
from pyiceberg.manifest import FileFormat
28+
from pyiceberg.partitioning import PartitionField, PartitionSpec, partition_record_value
29+
from pyiceberg.schema import Schema
30+
from pyiceberg.typedef import Properties, Record
31+
32+
if TYPE_CHECKING:
33+
import pyarrow as pa
34+
35+
from pyiceberg.io.pyarrow import StatsAggregator
36+
37+
38+
@dataclass(frozen=True)
class DataFileStatistics:
    """Column-level statistics collected while writing a single data file.

    Format-agnostic container: the write path fills it from file-format
    metadata (currently Parquet) and uses it to populate manifest entries
    and to infer partition values.
    """

    record_count: int
    column_sizes: dict[int, int]
    value_counts: dict[int, int]
    null_value_counts: dict[int, int]
    nan_value_counts: dict[int, int]
    column_aggregates: dict[int, StatsAggregator]
    split_offsets: list[int]

    def _partition_value(self, partition_field: PartitionField, schema: Schema) -> Any:
        """Infer the single partition value of *partition_field* from the column aggregates.

        Returns None when no aggregate exists for the source column. Raises
        ValueError for non-order-preserving transforms, or when the min and
        max bounds disagree (i.e. the file spans multiple partition values).
        """
        if partition_field.source_id not in self.column_aggregates:
            return None

        field = schema.find_field(partition_field.source_id)
        transform = partition_field.transform

        # Only order-preserving (linear) transforms let us derive the
        # partition value from min/max column bounds.
        if not transform.preserves_order:
            raise ValueError(
                f"Cannot infer partition value from parquet metadata for a non-linear Partition Field: "
                f"{partition_field.name} with transform {partition_field.transform}"
            )

        apply_transform = transform.transform(field.field_type)
        aggregate = self.column_aggregates[partition_field.source_id]

        bounds = [
            apply_transform(
                partition_record_value(
                    partition_field=partition_field,
                    value=raw_bound,
                    schema=schema,
                )
            )
            for raw_bound in (aggregate.current_min, aggregate.current_max)
        ]
        lower_bound, upper_bound = bounds

        if lower_bound != upper_bound:
            raise ValueError(
                f"Cannot infer partition value from parquet metadata as there are more than one partition values "
                f"for Partition Field: {partition_field.name}. lower_value={lower_bound!r}, upper_value={upper_bound!r}"
            )

        return lower_bound

    def partition(self, partition_spec: PartitionSpec, schema: Schema) -> Record:
        """Build the partition Record for this file under *partition_spec*."""
        return Record(*(self._partition_value(field, schema) for field in partition_spec.fields))

    def to_serialized_dict(self) -> dict[str, Any]:
        """Return the statistics as a plain dict, with min/max bounds serialized to bytes.

        Columns whose aggregate yields no bound (e.g. all-null) are simply
        omitted from lower_bounds/upper_bounds.
        """
        lower_bounds: dict[int, bytes] = {}
        upper_bounds: dict[int, bytes] = {}

        for field_id, aggregate in self.column_aggregates.items():
            minimum = aggregate.min_as_bytes()
            if minimum is not None:
                lower_bounds[field_id] = minimum
            maximum = aggregate.max_as_bytes()
            if maximum is not None:
                upper_bounds[field_id] = maximum

        return {
            "record_count": self.record_count,
            "column_sizes": self.column_sizes,
            "value_counts": self.value_counts,
            "null_value_counts": self.null_value_counts,
            "nan_value_counts": self.nan_value_counts,
            "lower_bounds": lower_bounds,
            "upper_bounds": upper_bounds,
            "split_offsets": self.split_offsets,
        }
109+
110+
111+
class FileFormatWriter(ABC):
    """Writes data to a single file in a specific format.

    Concrete writers implement write()/close(); the context-manager protocol
    caches the final statistics so callers can retrieve them via result()
    after the with-block exits.
    """

    # Statistics cached by a successful context-manager exit; None until then.
    _result: DataFileStatistics | None = None

    @abstractmethod
    def write(self, table: pa.Table) -> None:
        """Write a batch of data. May be called multiple times."""

    @abstractmethod
    def close(self) -> DataFileStatistics:
        """Finalize the file and return statistics."""

    def result(self) -> DataFileStatistics:
        """Return the statistics cached by close(); raise if not closed yet."""
        stats = self._result
        if stats is None:
            raise RuntimeError("Writer has not been closed yet")
        return stats

    def __enter__(self) -> FileFormatWriter:
        """Enter the context manager."""
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Exit the context manager, closing the writer and caching statistics.

        On the success path, close() statistics are cached for result(). When
        an exception is already propagating, close() is attempted best-effort
        and its own failures are suppressed so the original error surfaces.
        """
        if exc_type is None:
            self._result = self.close()
            return
        try:
            self.close()
        except Exception:
            # Deliberate: never mask the in-flight exception with a close error.
            pass
143+
144+
145+
class FileFormatModel(ABC):
    """Represents a file format's capabilities. Creates writers.

    One model instance is registered per FileFormat (see FileFormatFactory);
    the write path asks it for a writer bound to a concrete output file.
    """

    @property
    @abstractmethod
    def format(self) -> FileFormat:
        """The FileFormat enum value this model handles (used as registry key)."""
        ...

    @abstractmethod
    def file_extension(self) -> str:
        """Return file extension without dot, e.g. 'parquet', 'orc'."""

    @abstractmethod
    def create_writer(
        self,
        output_file: OutputFile,
        file_schema: Schema,
        properties: Properties,
    ) -> FileFormatWriter:
        """Create a writer that writes *file_schema* rows to *output_file*.

        *properties* carries table/write properties the format may honor
        (e.g. compression settings) — interpretation is format-specific.
        """
        ...
163+
164+
165+
class FileFormatFactory:
    """Registry of FileFormatModel implementations.

    Class-level registry mapping each FileFormat to the single model that
    handles it. Models register at import time; the write path looks them
    up by format.
    """

    # Shared across the process: one model per file format.
    _registry: dict[FileFormat, FileFormatModel] = {}

    @classmethod
    def register(cls, model: FileFormatModel) -> None:
        """Register *model* for its format; reject a second model for the same format."""
        existing = cls._registry.get(model.format)
        if existing is not None:
            raise ValueError(
                f"Cannot register {type(model).__name__}: {type(existing).__name__} is already registered for {model.format}"
            )
        cls._registry[model.format] = model

    @classmethod
    def get(cls, file_format: FileFormat) -> FileFormatModel:
        """Return the model registered for *file_format*; raise ValueError if none."""
        if file_format in cls._registry:
            return cls._registry[file_format]
        raise ValueError(f"No writer registered for {file_format}. Available: {list(cls._registry.keys())}")

    @classmethod
    def available_formats(cls) -> list[FileFormat]:
        """List every format that currently has a registered model."""
        return [*cls._registry]

pyiceberg/io/pyarrow.py

Lines changed: 2 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,13 @@
120120
OutputFile,
121121
OutputStream,
122122
)
123+
from pyiceberg.io.fileformat import DataFileStatistics as DataFileStatistics
123124
from pyiceberg.manifest import (
124125
DataFile,
125126
DataFileContent,
126127
FileFormat,
127128
)
128-
from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value
129+
from pyiceberg.partitioning import PartitionFieldValue, PartitionKey, PartitionSpec
129130
from pyiceberg.schema import (
130131
PartnerAccessor,
131132
PreOrderSchemaVisitor,
@@ -2473,79 +2474,6 @@ def parquet_path_to_id_mapping(
24732474
return result
24742475

24752476

2476-
@dataclass(frozen=True)
2477-
class DataFileStatistics:
2478-
record_count: int
2479-
column_sizes: dict[int, int]
2480-
value_counts: dict[int, int]
2481-
null_value_counts: dict[int, int]
2482-
nan_value_counts: dict[int, int]
2483-
column_aggregates: dict[int, StatsAggregator]
2484-
split_offsets: list[int]
2485-
2486-
def _partition_value(self, partition_field: PartitionField, schema: Schema) -> Any:
2487-
if partition_field.source_id not in self.column_aggregates:
2488-
return None
2489-
2490-
source_field = schema.find_field(partition_field.source_id)
2491-
iceberg_transform = partition_field.transform
2492-
2493-
if not iceberg_transform.preserves_order:
2494-
raise ValueError(
2495-
f"Cannot infer partition value from parquet metadata for a non-linear Partition Field: "
2496-
f"{partition_field.name} with transform {partition_field.transform}"
2497-
)
2498-
2499-
transform_func = iceberg_transform.transform(source_field.field_type)
2500-
2501-
lower_value = transform_func(
2502-
partition_record_value(
2503-
partition_field=partition_field,
2504-
value=self.column_aggregates[partition_field.source_id].current_min,
2505-
schema=schema,
2506-
)
2507-
)
2508-
upper_value = transform_func(
2509-
partition_record_value(
2510-
partition_field=partition_field,
2511-
value=self.column_aggregates[partition_field.source_id].current_max,
2512-
schema=schema,
2513-
)
2514-
)
2515-
if lower_value != upper_value:
2516-
raise ValueError(
2517-
f"Cannot infer partition value from parquet metadata as there are more than one partition values "
2518-
f"for Partition Field: {partition_field.name}. {lower_value=}, {upper_value=}"
2519-
)
2520-
2521-
return lower_value
2522-
2523-
def partition(self, partition_spec: PartitionSpec, schema: Schema) -> Record:
2524-
return Record(*[self._partition_value(field, schema) for field in partition_spec.fields])
2525-
2526-
def to_serialized_dict(self) -> dict[str, Any]:
2527-
lower_bounds = {}
2528-
upper_bounds = {}
2529-
2530-
for k, agg in self.column_aggregates.items():
2531-
_min = agg.min_as_bytes()
2532-
if _min is not None:
2533-
lower_bounds[k] = _min
2534-
_max = agg.max_as_bytes()
2535-
if _max is not None:
2536-
upper_bounds[k] = _max
2537-
return {
2538-
"record_count": self.record_count,
2539-
"column_sizes": self.column_sizes,
2540-
"value_counts": self.value_counts,
2541-
"null_value_counts": self.null_value_counts,
2542-
"nan_value_counts": self.nan_value_counts,
2543-
"lower_bounds": lower_bounds,
2544-
"upper_bounds": upper_bounds,
2545-
"split_offsets": self.split_offsets,
2546-
}
2547-
2548-
25492477
def data_file_statistics_from_parquet_metadata(
25502478
parquet_metadata: pq.FileMetaData,
25512479
stats_columns: dict[int, StatisticsCollector],

tests/io/test_fileformat.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from typing import Any
19+
20+
import pytest
21+
22+
from pyiceberg.io.fileformat import DataFileStatistics, FileFormatFactory, FileFormatModel, FileFormatWriter
23+
from pyiceberg.manifest import FileFormat
24+
25+
26+
def test_get_unregistered_format_raises() -> None:
    """Getting an unregistered format should raise ValueError."""
    with pytest.raises(ValueError) as exc_info:
        FileFormatFactory.get(FileFormat.AVRO)
    assert "No writer registered for" in str(exc_info.value)
30+
31+
32+
def test_backward_compat_import() -> None:
    """DataFileStatistics can still be imported from pyiceberg.io.pyarrow."""
    import pyiceberg.io.fileformat as fileformat_module
    import pyiceberg.io.pyarrow as pyarrow_module

    # The pyarrow module re-exports the class moved to fileformat; both names
    # must resolve to the exact same object.
    assert pyarrow_module.DataFileStatistics is fileformat_module.DataFileStatistics
38+
39+
40+
def test_duplicate_registration_raises() -> None:
    """Registering the same format twice should raise ValueError."""

    class _OrcModel(FileFormatModel):
        @property
        def format(self) -> FileFormat:
            return FileFormat.ORC

        def file_extension(self) -> str:
            return "orc"

        def create_writer(self, output_file: Any, file_schema: Any, properties: Any) -> Any:
            raise NotImplementedError

    # Snapshot the shared registry so this test leaves no trace behind.
    saved_registry = dict(FileFormatFactory._registry)
    try:
        orc_model = _OrcModel()
        FileFormatFactory.register(orc_model)
        with pytest.raises(ValueError, match="already registered"):
            FileFormatFactory.register(orc_model)
    finally:
        FileFormatFactory._registry = saved_registry
62+
63+
64+
def test_result_before_close_raises() -> None:
    """Calling result before close should raise an error."""

    class _NoopWriter(FileFormatWriter):
        def write(self, table: Any) -> None:
            return None

        def close(self) -> DataFileStatistics:
            raise NotImplementedError

    # result() must refuse until a close() has populated the cached stats.
    with pytest.raises(RuntimeError, match="Writer has not been closed yet"):
        _NoopWriter().result()

0 commit comments

Comments
 (0)