@@ -78,6 +78,7 @@ def __init__(
filters: Filter | None = None,
context: DataLoaderContext | None = None,
max_attempts: int = 3,
batch_size: int | None = None,
):
"""
Args:
@@ -90,6 +91,9 @@ def __init__(
filters: Row filter expression, defaults to always_true() (all rows)
context: Data loader context
max_attempts: Total number of attempts including the initial try (default 3)
batch_size: Number of rows per RecordBatch yielded by each split.
Controls memory usage per worker — smaller values reduce peak memory
but increase per-batch overhead. None uses the PyArrow default (~131K rows).
"""
self._catalog = catalog
self._table_id = TableIdentifier(database, table, branch)
@@ -98,6 +102,7 @@ def __init__(
self._filters = filters if filters is not None else always_true()
self._context = context or DataLoaderContext()
self._max_attempts = max_attempts
self._batch_size = batch_size

@cached_property
def _iceberg_table(self) -> Table:
@@ -163,4 +168,5 @@ def __iter__(self) -> Iterator[DataLoaderSplit]:
yield DataLoaderSplit(
file_scan_task=scan_task,
scan_context=scan_context,
batch_size=self._batch_size,
)
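
A minimal usage sketch of the new parameter (not part of this diff): the constructor arguments mirror the tests later in this PR, while `catalog` is assumed to be an already-configured OpenHouse catalog and `process(...)` is a hypothetical per-batch consumer.

loader = OpenHouseDataLoader(
    catalog=catalog,
    database="db",
    table="tbl",
    batch_size=10_000,  # assumption: cap each RecordBatch at 10K rows to bound worker memory
)
for split in loader:      # one DataLoaderSplit per file scan task
    for batch in split:   # pyarrow.RecordBatch, streamed incrementally
        process(batch)    # hypothetical consumer
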
@@ -7,7 +7,7 @@
from datafusion.plan import LogicalPlan
from pyarrow import RecordBatch
from pyiceberg.io.pyarrow import ArrowScan
from pyiceberg.table import FileScanTask
from pyiceberg.table import ArrivalOrder, FileScanTask

from openhouse.dataloader._table_scan_context import TableScanContext
from openhouse.dataloader.udf_registry import NoOpRegistry, UDFRegistry
@@ -22,11 +22,13 @@ def __init__(
scan_context: TableScanContext,
plan: LogicalPlan | None = None,
udf_registry: UDFRegistry | None = None,
batch_size: int | None = None,
):
self._plan = plan
self._file_scan_task = file_scan_task
self._udf_registry = udf_registry or NoOpRegistry()
self._scan_context = scan_context
self._batch_size = batch_size

@property
def id(self) -> str:
@@ -45,7 +47,8 @@ def __iter__(self) -> Iterator[RecordBatch]:
"""Reads the file scan task and yields Arrow RecordBatches.

Uses PyIceberg's ArrowScan to handle format dispatch, schema resolution,
delete files, and partition spec lookups.
delete files, and partition spec lookups. Batches are streamed
incrementally (not materialized into memory) via ArrivalOrder.
"""
ctx = self._scan_context
arrow_scan = ArrowScan(
@@ -54,4 +57,7 @@
projected_schema=ctx.projected_schema,
row_filter=ctx.row_filter,
)
yield from arrow_scan.to_record_batches([self._file_scan_task])
yield from arrow_scan.to_record_batches(
[self._file_scan_task],
order=ArrivalOrder(concurrent_streams=1, batch_size=self._batch_size),
)
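
Because batches are now streamed in arrival order rather than materialized, a consumer can fold over a split with bounded memory. A sketch under the same assumptions as above (`split` comes from iterating the loader; the `id` column name matches the test schema later in this PR):

import pyarrow.compute as pc

total_rows = 0
id_sum = 0
for batch in split:  # each RecordBatch holds at most batch_size rows
    total_rows += batch.num_rows
    id_sum += pc.sum(batch.column("id")).as_py() or 0
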
73 changes: 73 additions & 0 deletions integrations/python/dataloader/tests/test_data_loader.py
@@ -322,3 +322,76 @@ def test_snapshot_id_with_columns_and_filters(tmp_path):
assert scan_kwargs["snapshot_id"] == 99
assert scan_kwargs["selected_fields"] == (COL_ID,)
assert "row_filter" in scan_kwargs


# --- batch_size tests ---


def test_batch_size_default_returns_all_data(tmp_path):
"""Without batch_size, all data is returned correctly (backwards compatibility)."""
catalog = _make_real_catalog(tmp_path)

loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl")
result = _materialize(loader)

assert result.num_rows == 3
result = result.sort_by(COL_ID)
assert result.column(COL_ID).to_pylist() == TEST_DATA[COL_ID]


def test_batch_size_limits_rows_per_batch(tmp_path):
"""When batch_size is set, each RecordBatch has at most batch_size rows."""
many_rows = {
COL_ID: list(range(100)),
COL_NAME: [f"name_{i}" for i in range(100)],
COL_VALUE: [float(i) for i in range(100)],
}
catalog = _make_real_catalog(tmp_path, data=many_rows)

loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", batch_size=10)
batches = [batch for split in loader for batch in split]

assert len(batches) >= 2, "Expected multiple batches with batch_size=10 and 100 rows"
for batch in batches:
assert batch.num_rows <= 10, f"Batch has {batch.num_rows} rows, expected at most 10"

total_rows = sum(b.num_rows for b in batches)
assert total_rows == 100


def test_batch_size_returns_correct_data(tmp_path):
"""batch_size controls chunking but doesn't alter the data returned."""
catalog = _make_real_catalog(tmp_path)

loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", batch_size=1)
result = _materialize(loader)

assert result.num_rows == 3
result = result.sort_by(COL_ID)
assert result.column(COL_ID).to_pylist() == TEST_DATA[COL_ID]
assert result.column(COL_NAME).to_pylist() == TEST_DATA[COL_NAME]
assert result.column(COL_VALUE).to_pylist() == TEST_DATA[COL_VALUE]


def test_batch_size_with_columns_and_filters(tmp_path):
"""batch_size works alongside column selection and row filters."""
catalog = _make_real_catalog(tmp_path)

loader = OpenHouseDataLoader(
catalog=catalog, database="db", table="tbl", columns=[COL_ID], filters=col(COL_ID) == 1, batch_size=1
)
result = _materialize(loader)

assert result.num_rows == 1
assert set(result.column_names) == {COL_ID}
assert result.column(COL_ID).to_pylist() == [1]


def test_batch_size_with_empty_table(tmp_path):
"""batch_size on an empty table yields no batches."""
catalog = _make_real_catalog(tmp_path, data=EMPTY_DATA)

loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", batch_size=10)
result = _materialize(loader)

assert result.num_rows == 0
46 changes: 46 additions & 0 deletions integrations/python/dataloader/tests/test_data_loader_split.py
@@ -30,6 +30,7 @@ def _create_test_split(
iceberg_schema: Schema,
io_properties: dict[str, str] | None = None,
filename: str | None = None,
batch_size: int | None = None,
) -> DataLoaderSplit:
"""Create a DataLoaderSplit for testing by writing data to disk.

@@ -88,6 +89,7 @@ def _create_test_split(
plan=plan,
file_scan_task=task,
scan_context=scan_context,
batch_size=batch_size,
)


@@ -199,3 +201,47 @@ def test_split_id_ignores_default_netloc(tmp_path):
split._scan_context.io.fs_by_scheme = MagicMock(return_value=local_fs)
list(split)
split._scan_context.io.fs_by_scheme.assert_called_with("hdfs", expected_netloc)


# --- batch_size tests ---

_BATCH_SCHEMA = Schema(
NestedField(field_id=1, name="id", field_type=LongType(), required=False),
)


def _make_large_table(num_rows: int) -> pa.Table:
return pa.table({"id": pa.array(list(range(num_rows)), type=pa.int64())})


def test_split_batch_size_limits_rows_per_batch(tmp_path):
"""When batch_size is set, each RecordBatch has at most that many rows."""
table = _make_large_table(100)
split = _create_test_split(tmp_path, table, FileFormat.PARQUET, _BATCH_SCHEMA, batch_size=10)

batches = list(split)

assert len(batches) >= 2, "Expected multiple batches with batch_size=10 and 100 rows"
for batch in batches:
assert batch.num_rows <= 10
assert sum(b.num_rows for b in batches) == 100


def test_split_batch_size_none_returns_all_rows(tmp_path):
"""Default batch_size (None) returns all data correctly."""
table = _make_large_table(50)
split = _create_test_split(tmp_path, table, FileFormat.PARQUET, _BATCH_SCHEMA)

result = pa.Table.from_batches(list(split))
assert result.num_rows == 50
assert sorted(result.column("id").to_pylist()) == list(range(50))


def test_split_batch_size_preserves_data(tmp_path):
"""batch_size controls chunking but all data is preserved."""
table = _make_large_table(25)
split = _create_test_split(tmp_path, table, FileFormat.PARQUET, _BATCH_SCHEMA, batch_size=7)

result = pa.Table.from_batches(list(split))
assert result.num_rows == 25
assert sorted(result.column("id").to_pylist()) == list(range(25))