diff --git a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py index fdeb2fcde..ad92e6cb9 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py @@ -78,6 +78,7 @@ def __init__( filters: Filter | None = None, context: DataLoaderContext | None = None, max_attempts: int = 3, + batch_size: int | None = None, ): """ Args: @@ -90,6 +91,9 @@ def __init__( filters: Row filter expression, defaults to always_true() (all rows) context: Data loader context max_attempts: Total number of attempts including the initial try (default 3) + batch_size: Number of rows per RecordBatch yielded by each split. + Controls memory usage per worker — smaller values reduce peak memory + but increase per-batch overhead. None uses the PyArrow default (~131K rows). """ self._catalog = catalog self._table_id = TableIdentifier(database, table, branch) @@ -98,6 +102,7 @@ def __init__( self._filters = filters if filters is not None else always_true() self._context = context or DataLoaderContext() self._max_attempts = max_attempts + self._batch_size = batch_size @cached_property def _iceberg_table(self) -> Table: @@ -163,4 +168,5 @@ def __iter__(self) -> Iterator[DataLoaderSplit]: yield DataLoaderSplit( file_scan_task=scan_task, scan_context=scan_context, + batch_size=self._batch_size, ) diff --git a/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py b/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py index 12c79af78..e814c3281 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py @@ -7,7 +7,7 @@ from datafusion.plan import LogicalPlan from pyarrow import RecordBatch from pyiceberg.io.pyarrow import ArrowScan -from pyiceberg.table import FileScanTask +from pyiceberg.table import ArrivalOrder, FileScanTask from openhouse.dataloader._table_scan_context import TableScanContext from openhouse.dataloader.udf_registry import NoOpRegistry, UDFRegistry @@ -22,11 +22,13 @@ def __init__( scan_context: TableScanContext, plan: LogicalPlan | None = None, udf_registry: UDFRegistry | None = None, + batch_size: int | None = None, ): self._plan = plan self._file_scan_task = file_scan_task self._udf_registry = udf_registry or NoOpRegistry() self._scan_context = scan_context + self._batch_size = batch_size @property def id(self) -> str: @@ -45,7 +47,8 @@ def __iter__(self) -> Iterator[RecordBatch]: """Reads the file scan task and yields Arrow RecordBatches. Uses PyIceberg's ArrowScan to handle format dispatch, schema resolution, - delete files, and partition spec lookups. + delete files, and partition spec lookups. Batches are streamed + incrementally (not materialized into memory) via ArrivalOrder. 
""" ctx = self._scan_context arrow_scan = ArrowScan( @@ -54,4 +57,7 @@ def __iter__(self) -> Iterator[RecordBatch]: projected_schema=ctx.projected_schema, row_filter=ctx.row_filter, ) - yield from arrow_scan.to_record_batches([self._file_scan_task]) + yield from arrow_scan.to_record_batches( + [self._file_scan_task], + order=ArrivalOrder(concurrent_streams=1, batch_size=self._batch_size), + ) diff --git a/integrations/python/dataloader/tests/test_data_loader.py b/integrations/python/dataloader/tests/test_data_loader.py index b092c1c0d..a7958b706 100644 --- a/integrations/python/dataloader/tests/test_data_loader.py +++ b/integrations/python/dataloader/tests/test_data_loader.py @@ -322,3 +322,76 @@ def test_snapshot_id_with_columns_and_filters(tmp_path): assert scan_kwargs["snapshot_id"] == 99 assert scan_kwargs["selected_fields"] == (COL_ID,) assert "row_filter" in scan_kwargs + + +# --- batch_size tests --- + + +def test_batch_size_default_returns_all_data(tmp_path): + """Without batch_size, all data is returned correctly (backwards compatibility).""" + catalog = _make_real_catalog(tmp_path) + + loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl") + result = _materialize(loader) + + assert result.num_rows == 3 + result = result.sort_by(COL_ID) + assert result.column(COL_ID).to_pylist() == TEST_DATA[COL_ID] + + +def test_batch_size_limits_rows_per_batch(tmp_path): + """When batch_size is set, each RecordBatch has at most batch_size rows.""" + many_rows = { + COL_ID: list(range(100)), + COL_NAME: [f"name_{i}" for i in range(100)], + COL_VALUE: [float(i) for i in range(100)], + } + catalog = _make_real_catalog(tmp_path, data=many_rows) + + loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", batch_size=10) + batches = [batch for split in loader for batch in split] + + assert len(batches) >= 2, "Expected multiple batches with batch_size=10 and 100 rows" + for batch in batches: + assert batch.num_rows <= 10, f"Batch has {batch.num_rows} rows, expected at most 10" + + total_rows = sum(b.num_rows for b in batches) + assert total_rows == 100 + + +def test_batch_size_returns_correct_data(tmp_path): + """batch_size controls chunking but doesn't alter the data returned.""" + catalog = _make_real_catalog(tmp_path) + + loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", batch_size=1) + result = _materialize(loader) + + assert result.num_rows == 3 + result = result.sort_by(COL_ID) + assert result.column(COL_ID).to_pylist() == TEST_DATA[COL_ID] + assert result.column(COL_NAME).to_pylist() == TEST_DATA[COL_NAME] + assert result.column(COL_VALUE).to_pylist() == TEST_DATA[COL_VALUE] + + +def test_batch_size_with_columns_and_filters(tmp_path): + """batch_size works alongside column selection and row filters.""" + catalog = _make_real_catalog(tmp_path) + + loader = OpenHouseDataLoader( + catalog=catalog, database="db", table="tbl", columns=[COL_ID], filters=col(COL_ID) == 1, batch_size=1 + ) + result = _materialize(loader) + + assert result.num_rows == 1 + assert set(result.column_names) == {COL_ID} + assert result.column(COL_ID).to_pylist() == [1] + + +def test_batch_size_with_empty_table(tmp_path): + """batch_size on an empty table yields no batches.""" + catalog = _make_real_catalog(tmp_path, data=EMPTY_DATA) + + loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", batch_size=10) + result = _materialize(loader) + + assert result.num_rows == 0 diff --git a/integrations/python/dataloader/tests/test_data_loader_split.py 
diff --git a/integrations/python/dataloader/tests/test_data_loader_split.py b/integrations/python/dataloader/tests/test_data_loader_split.py
index 0a5f94190..bcca0de0b 100644
--- a/integrations/python/dataloader/tests/test_data_loader_split.py
+++ b/integrations/python/dataloader/tests/test_data_loader_split.py
@@ -30,6 +30,7 @@ def _create_test_split(
     iceberg_schema: Schema,
     io_properties: dict[str, str] | None = None,
     filename: str | None = None,
+    batch_size: int | None = None,
 ) -> DataLoaderSplit:
     """Create a DataLoaderSplit for testing by writing data to disk.
 
@@ -88,6 +89,7 @@ def _create_test_split(
         plan=plan,
         file_scan_task=task,
         scan_context=scan_context,
+        batch_size=batch_size,
     )
 
 
@@ -199,3 +201,47 @@ def test_split_id_ignores_default_netloc(tmp_path):
     split._scan_context.io.fs_by_scheme = MagicMock(return_value=local_fs)
     list(split)
     split._scan_context.io.fs_by_scheme.assert_called_with("hdfs", expected_netloc)
+
+
+# --- batch_size tests ---
+
+_BATCH_SCHEMA = Schema(
+    NestedField(field_id=1, name="id", field_type=LongType(), required=False),
+)
+
+
+def _make_large_table(num_rows: int) -> pa.Table:
+    return pa.table({"id": pa.array(list(range(num_rows)), type=pa.int64())})
+
+
+def test_split_batch_size_limits_rows_per_batch(tmp_path):
+    """When batch_size is set, each RecordBatch has at most that many rows."""
+    table = _make_large_table(100)
+    split = _create_test_split(tmp_path, table, FileFormat.PARQUET, _BATCH_SCHEMA, batch_size=10)
+
+    batches = list(split)
+
+    assert len(batches) >= 2, "Expected multiple batches with batch_size=10 and 100 rows"
+    for batch in batches:
+        assert batch.num_rows <= 10
+    assert sum(b.num_rows for b in batches) == 100
+
+
+def test_split_batch_size_none_returns_all_rows(tmp_path):
+    """Default batch_size (None) returns all data correctly."""
+    table = _make_large_table(50)
+    split = _create_test_split(tmp_path, table, FileFormat.PARQUET, _BATCH_SCHEMA)
+
+    result = pa.Table.from_batches(list(split))
+    assert result.num_rows == 50
+    assert sorted(result.column("id").to_pylist()) == list(range(50))
+
+
+def test_split_batch_size_preserves_data(tmp_path):
+    """batch_size controls chunking but all data is preserved."""
+    table = _make_large_table(25)
+    split = _create_test_split(tmp_path, table, FileFormat.PARQUET, _BATCH_SCHEMA, batch_size=7)
+
+    result = pa.Table.from_batches(list(split))
+    assert result.num_rows == 25
+    assert sorted(result.column("id").to_pylist()) == list(range(25))
diff --git a/integrations/spark/spark-3.5/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/catalogtest/SnapshotExpirationRefsTest.java b/integrations/spark/spark-3.5/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/catalogtest/SnapshotExpirationRefsTest.java
new file mode 100644
index 000000000..aa8fdd692
--- /dev/null
+++ b/integrations/spark/spark-3.5/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/catalogtest/SnapshotExpirationRefsTest.java
@@ -0,0 +1,238 @@
+package com.linkedin.openhouse.spark.catalogtest;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import com.linkedin.openhouse.tablestest.OpenHouseSparkITest;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.MethodOrderer;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestMethodOrder;
+import org.junit.jupiter.api.parallel.Execution;
+import org.junit.jupiter.api.parallel.ExecutionMode;
+
+/**
+ * Tests that snapshot expiration does NOT remove snapshots reachable from tags or branches. This
+ * demonstrates that member data can remain queryable through refs even after expire_snapshots runs,
+ * which is the compliance gap motivating the RFC for automatic ref-aware purging.
+ */
+@TestMethodOrder(MethodOrderer.MethodName.class)
+@Execution(ExecutionMode.SAME_THREAD)
+public class SnapshotExpirationRefsTest extends OpenHouseSparkITest {
+
+  private static final String DATABASE = "d1_expiration_refs";
+  private static final String TEST_PREFIX = "exp_refs_";
+
+  @AfterEach
+  public void cleanupAfterTest() {
+    try (SparkSession spark = getSparkSession()) {
+      List<Row> tables = spark.sql("SHOW TABLES IN openhouse." + DATABASE).collectAsList();
+      for (Row table : tables) {
+        String name = table.getString(1);
+        spark.sql("DROP TABLE IF EXISTS openhouse." + DATABASE + "." + name);
+      }
+    } catch (Exception e) {
+      System.err.println("Warning: cleanup failed: " + e.getMessage());
+    }
+  }
+
+  /**
+   * A tag pointing at a snapshot should prevent expire_snapshots from removing that snapshot's
+   * data. After expiration, querying the table at the tagged snapshot should still return the
+   * original rows.
+   */
+  @Test
+  public void testTagPreservesSnapshotThroughExpiration() throws Exception {
+    try (SparkSession spark = getSparkSession()) {
+      String tableId = TEST_PREFIX + System.currentTimeMillis();
+      String tableName = "openhouse." + DATABASE + "." + tableId;
+
+      // Create table and insert Wave 1 member data
+      spark.sql("CREATE TABLE " + tableName + " (member_id bigint, event_type string)");
+      spark.sql(
+          "INSERT INTO " + tableName + " VALUES (1001, 'login'), (1002, 'click'), (1003, 'view')");
+
+      // Capture the Wave 1 snapshot ID
+      long wave1SnapshotId = getLatestSnapshotId(spark, tableName);
+
+      // Tag the Wave 1 snapshot
+      spark.sql(
+          "ALTER TABLE "
+              + tableName
+              + " CREATE TAG wave1_members AS OF VERSION "
+              + wave1SnapshotId);
+
+      // Insert Wave 2 and Wave 3 to push Wave 1 into history
+      spark.sql("INSERT INTO " + tableName + " VALUES (2001, 'login'), (2002, 'click')");
+      spark.sql("INSERT INTO " + tableName + " VALUES (3001, 'view')");
+
+      int snapshotsBeforeExpiry = getSnapshotIds(spark, tableName).size();
+      assertEquals(3, snapshotsBeforeExpiry, "Should have 3 snapshots before expiration");
+
+      // Expire all snapshots aggressively, keeping only 1 on the main lineage
+      spark.sql(
+          "CALL openhouse.system.expire_snapshots(table => '"
+              + tableName
+              + "', older_than => TIMESTAMP '2099-01-01 00:00:00', retain_last => 1)");
+
+      // The tag should still be present in refs
+      List<String> refNames = getRefNames(spark, tableName);
+      assertTrue(refNames.contains("wave1_members"), "Tag wave1_members should still exist");
+
+      // KEY ASSERTION: Wave 1 data should still be queryable through the tagged snapshot
+      List<Row> taggedData =
+          spark
+              .sql("SELECT * FROM " + tableName + " VERSION AS OF " + wave1SnapshotId)
+              .collectAsList();
+      assertEquals(
+          3,
+          taggedData.size(),
+          "Tag should preserve all 3 Wave 1 rows through snapshot expiration");
+    }
+  }
+
+  /**
+   * A branch pointing at a snapshot should prevent expire_snapshots from removing that snapshot's
+   * data. After expiration, querying the table at the branch's snapshot should still return the
+   * original rows.
+   */
+  @Test
+  public void testBranchPreservesSnapshotThroughExpiration() throws Exception {
+    try (SparkSession spark = getSparkSession()) {
+      String tableId = TEST_PREFIX + System.currentTimeMillis();
+      String tableName = "openhouse." + DATABASE + "." + tableId;
+
+      // Create table and insert Wave 1 member data
+      spark.sql("CREATE TABLE " + tableName + " (member_id bigint, event_type string)");
+      spark.sql(
+          "INSERT INTO " + tableName + " VALUES (4001, 'login'), (4002, 'click'), (4003, 'view')");
+
+      // Capture the Wave 1 snapshot ID
+      long wave1SnapshotId = getLatestSnapshotId(spark, tableName);
+
+      // Create a branch at the Wave 1 snapshot
+      spark.sql(
+          "ALTER TABLE "
+              + tableName
+              + " CREATE BRANCH audit_branch AS OF VERSION "
+              + wave1SnapshotId);
+
+      // Insert Wave 2 and Wave 3 on main to push Wave 1 into history
+      spark.sql("INSERT INTO " + tableName + " VALUES (5001, 'login'), (5002, 'click')");
+      spark.sql("INSERT INTO " + tableName + " VALUES (6001, 'view')");
+
+      int snapshotsBeforeExpiry = getSnapshotIds(spark, tableName).size();
+      assertEquals(3, snapshotsBeforeExpiry, "Should have 3 snapshots before expiration");
+
+      // Expire all snapshots aggressively, keeping only 1 on the main lineage
+      spark.sql(
+          "CALL openhouse.system.expire_snapshots(table => '"
+              + tableName
+              + "', older_than => TIMESTAMP '2099-01-01 00:00:00', retain_last => 1)");
+
+      // The branch should still be present in refs
+      List<String> refNames = getRefNames(spark, tableName);
+      assertTrue(refNames.contains("audit_branch"), "Branch audit_branch should still exist");
+
+      // KEY ASSERTION: Wave 1 data should still be queryable through the branch
+      List<Row> branchData =
+          spark.sql("SELECT * FROM " + tableName + " VERSION AS OF 'audit_branch'").collectAsList();
+      assertEquals(
+          3,
+          branchData.size(),
+          "Branch should preserve all 3 Wave 1 rows through snapshot expiration");
+    }
+  }
+
+  /**
+   * Option B validation: Setting the table property history.expire.max-ref-age-ms to a small value
+   * should cause expire_snapshots to drop refs older than the threshold, making their snapshots
+   * eligible for expiration. This is the zero-syntax-change solution for compliance.
+   */
+  @Test
+  public void testMaxRefAgeMsPropertyDropsExpiredTagAndBranch() throws Exception {
+    try (SparkSession spark = getSparkSession()) {
+      String tableId = TEST_PREFIX + System.currentTimeMillis();
+      String tableName = "openhouse." + DATABASE + "." + tableId;
+
+      // Create table and insert Wave 1 member data
+      spark.sql("CREATE TABLE " + tableName + " (member_id bigint, event_type string)");
+      spark.sql(
+          "INSERT INTO " + tableName + " VALUES (1001, 'login'), (1002, 'click'), (1003, 'view')");
+
+      long wave1SnapshotId = getLatestSnapshotId(spark, tableName);
+
+      // Create a tag and a branch at the Wave 1 snapshot
+      spark.sql(
+          "ALTER TABLE " + tableName + " CREATE TAG old_tag AS OF VERSION " + wave1SnapshotId);
+      spark.sql(
+          "ALTER TABLE "
+              + tableName
+              + " CREATE BRANCH old_branch AS OF VERSION "
+              + wave1SnapshotId);
+
+      // Insert Wave 2 on main so main has a newer snapshot
+      spark.sql("INSERT INTO " + tableName + " VALUES (2001, 'login'), (2002, 'click')");
+
+      // Verify both refs exist before expiration
+      List<String> refsBefore = getRefNames(spark, tableName);
+      assertTrue(refsBefore.contains("old_tag"), "Tag should exist before expiration");
+      assertTrue(refsBefore.contains("old_branch"), "Branch should exist before expiration");
+      assertEquals(2, getSnapshotIds(spark, tableName).size(), "Should have 2 snapshots");
+
+      // Set max-ref-age-ms to 1ms; any ref older than that is dropped on the next expiration
+      spark.sql(
+          "ALTER TABLE "
+              + tableName
+              + " SET TBLPROPERTIES ('history.expire.max-ref-age-ms' = '1')");
+
+      // Small delay to ensure refs are older than 1ms
+      Thread.sleep(10);
+
+      // Run expire_snapshots, which should drop the tag and branch, then expire the Wave 1 snapshot
+      spark.sql(
+          "CALL openhouse.system.expire_snapshots(table => '"
+              + tableName
+              + "', older_than => TIMESTAMP '2099-01-01 00:00:00', retain_last => 1)");
+
+      // KEY ASSERTIONS: refs should be gone
+      List<String> refsAfter = getRefNames(spark, tableName);
+      assertFalse(refsAfter.contains("old_tag"), "Tag should be dropped by max-ref-age-ms");
+      assertFalse(refsAfter.contains("old_branch"), "Branch should be dropped by max-ref-age-ms");
+      assertTrue(refsAfter.contains("main"), "Main branch should always be retained");
+
+      // Wave 1 snapshot should be expired since no ref protects it anymore
+      List<Long> remainingSnapshots = getSnapshotIds(spark, tableName);
+      assertEquals(1, remainingSnapshots.size(), "Only the latest main snapshot should remain");
+      assertFalse(
+          remainingSnapshots.contains(wave1SnapshotId),
+          "Wave 1 snapshot should be expired after ref was dropped");
+    }
+  }
+
+  private static long getLatestSnapshotId(SparkSession spark, String tableName) {
+    List<Row> snapshots =
+        spark
+            .sql("SELECT snapshot_id FROM " + tableName + ".snapshots ORDER BY committed_at")
+            .collectAsList();
+    return snapshots.get(snapshots.size() - 1).getLong(0);
+  }
+
+  private static List<Long> getSnapshotIds(SparkSession spark, String tableName) {
+    return spark.sql("SELECT snapshot_id FROM " + tableName + ".snapshots ORDER BY committed_at")
+        .collectAsList().stream()
+        .map(r -> r.getLong(0))
+        .collect(Collectors.toList());
+  }
+
+  private static List<String> getRefNames(SparkSession spark, String tableName) {
+    return spark.sql("SELECT name FROM " + tableName + ".refs").collectAsList().stream()
+        .map(r -> r.getString(0))
+        .collect(Collectors.toList());
+  }
+}
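Reviewer note: the Option B flow that the last test validates can also be sketched outside JUnit. A minimal PySpark sketch, assuming a SparkSession whose openhouse catalog is already configured; the table name is illustrative, and the SQL mirrors the statements in the test above.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    tbl = "openhouse.d1_expiration_refs.demo"  # illustrative table name

    # Tags and branches normally pin their snapshots through expire_snapshots.
    # history.expire.max-ref-age-ms lets expiration drop refs older than the
    # threshold first, making their snapshots eligible for removal.
    spark.sql(
        f"ALTER TABLE {tbl} SET TBLPROPERTIES ('history.expire.max-ref-age-ms' = '1')"
    )
    spark.sql(
        f"CALL openhouse.system.expire_snapshots(table => '{tbl}', "
        f"older_than => TIMESTAMP '2099-01-01 00:00:00', retain_last => 1)"
    )

    # `main` is always retained; any other ref older than 1 ms is gone.
    spark.sql(f"SELECT name FROM {tbl}.refs").show()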