diff --git a/spark/v3.1/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala b/spark/v3.1/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
index 30b5df5317..5460e83d21 100644
--- a/spark/v3.1/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
+++ b/spark/v3.1/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.RewriteMergeInto
 import org.apache.spark.sql.catalyst.optimizer.RewriteUpdate
 import org.apache.spark.sql.catalyst.parser.extensions.IcebergSparkSqlExtensionsParser
 import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Strategy
+import org.apache.spark.sql.execution.datasources.v2.ExtendedV2Writes
 
 class IcebergSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
 
@@ -51,6 +52,11 @@ class IcebergSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
     extensions.injectOptimizerRule { spark => RewriteUpdate(spark) }
     extensions.injectOptimizerRule { spark => RewriteMergeInto(spark) }
 
+    // pre-CBO extensions
+    // attach the Iceberg table's required ordering to V2 writes after the optimizer
+    // has resolved the write target but before physical planning
+    extensions.injectPreCBORule { _ => ExtendedV2Writes }
+
     // planner extensions
     extensions.injectPlannerStrategy { spark => ExtendedDataSourceV2Strategy(spark) }
   }
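The injected rule only runs when the Iceberg extensions are registered on the session. A minimal sketch of that wiring, assuming a local session and a Hadoop catalog named `demo` (the catalog name and warehouse path are illustrative):

import org.apache.spark.sql.SparkSession

// Register IcebergSparkSessionExtensions so the pre-CBO rule above takes effect.
val spark = SparkSession.builder()
  .master("local[2]")
  .config("spark.sql.extensions",
    "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
  .config("spark.sql.catalog.demo", "org.apache.iceberg.spark.SparkCatalog") // illustrative catalog
  .config("spark.sql.catalog.demo.type", "hadoop")
  .config("spark.sql.catalog.demo.warehouse", "file:/tmp/iceberg-warehouse")
  .getOrCreate()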
diff --git a/spark/v3.1/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedV2Writes.scala b/spark/v3.1/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedV2Writes.scala
new file mode 100644
index 0000000000..bdd45f843d
--- /dev/null
+++ b/spark/v3.1/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedV2Writes.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.iceberg.spark.Spark3Util
+import org.apache.spark.sql.catalyst.plans.logical.AppendData
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.OverwriteByExpression
+import org.apache.spark.sql.catalyst.plans.logical.OverwritePartitionsDynamic
+import org.apache.spark.sql.catalyst.plans.logical.RepartitionByExpression
+import org.apache.spark.sql.catalyst.plans.logical.Sort
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.utils.DistributionAndOrderingUtils
+import org.apache.spark.sql.catalyst.utils.PlanUtils.isIcebergRelation
+import org.apache.spark.sql.connector.iceberg.distributions.Distribution
+import org.apache.spark.sql.connector.iceberg.distributions.Distributions
+import org.apache.spark.sql.connector.iceberg.distributions.OrderedDistribution
+import org.apache.spark.sql.connector.iceberg.expressions.SortOrder
+
+/**
+ * Backport of Spark 3.2's V2Writes for Spark 3.1's AppendData, OverwriteByExpression, and
+ * OverwritePartitionsDynamic. It attaches a local Sort only when the table has an explicit
+ * sort order or a RANGE distribution; it never attaches an Exchange. Unsorted partitioned
+ * writes therefore require write.spark.fanout.enabled=true or pre-clustered input.
+ *
+ * No distribution is injected: Spark 3.1 only has strict RepartitionByExpression. Spark 3.4+'s
+ * RebalancePartitions (used by Spark 3.5's V2Writes) does not exist here, and a strict
+ * repartition would turn skewed partition keys into stragglers.
+ *
+ * MERGE/UPDATE/DELETE are skipped: RewriteRowLevelOperationHelper.buildWritePlan already
+ * prepares those queries (using the unwrapped Spark3Util.buildRequiredOrdering, which still
+ * synthesizes the partition prefix), and alreadyPrepared() detects that shape to avoid
+ * wrapping the query twice.
+ */
+object ExtendedV2Writes extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+    case a @ AppendData(r: DataSourceV2Relation, query, _, _)
+        if isIcebergRelation(r) && !alreadyPrepared(query) =>
+      a.withNewQuery(prepareQuery(r, query))
+
+    case o @ OverwriteByExpression(r: DataSourceV2Relation, _, query, _, _)
+        if isIcebergRelation(r) && !alreadyPrepared(query) =>
+      o.withNewQuery(prepareQuery(r, query))
+
+    case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _, _)
+        if isIcebergRelation(r) && !alreadyPrepared(query) =>
+      o.withNewQuery(prepareQuery(r, query))
+  }
+
+  // Matches the shapes RewriteRowLevelOperationHelper.buildWritePlan produces. A bare
+  // Sort(_, false, _) is intentionally NOT matched: treating it as prepared would let a
+  // user's sortWithinPartitions on the wrong columns skip the table-required ordering.
+  private def alreadyPrepared(query: LogicalPlan): Boolean = query match {
+    case Sort(_, false, RepartitionByExpression(_, _, None)) => true
+    case RepartitionByExpression(_, _, None) => true
+    case _ => false
+  }
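To illustrate the distinction, a sketch (the table `demo.db.t` is hypothetical and assumed to declare a sort order on `category`): a user's sortWithinPartitions produces a bare global = false Sort, which alreadyPrepared does not match, so prepareQuery still wraps the query with the table-required ordering.

// Sketch: a user-supplied local sort is NOT mistaken for a prepared write plan.
val df = spark.range(100)
  .selectExpr("CAST(id AS INT) AS id", "CAST(id % 3 AS STRING) AS category")
df.sortWithinPartitions("id") // bare Sort(global = false) over the scan
  .writeTo("demo.db.t")       // hypothetical Iceberg table sorted by category
  .append()                   // the rule adds Sort(category, global = false) on top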
+
+  private def prepareQuery(r: DataSourceV2Relation, query: LogicalPlan): LogicalPlan = {
+    val icebergTable = Spark3Util.toIcebergTable(r.table)
+    // The distribution is computed only to surface OrderedDistribution.ordering; we always
+    // pass unspecified() to prepareQuery, so no Exchange is attached.
+    val tableDistribution = Spark3Util.buildRequiredDistribution(icebergTable)
+    val ordering = requiredOrdering(tableDistribution, icebergTable)
+    DistributionAndOrderingUtils.prepareQuery(
+      Distributions.unspecified(), ordering, query, conf)
+  }
+
+  // Delegate to Spark3Util.buildRequiredOrdering only for an OrderedDistribution or a sorted
+  // table. Unsorted tables get an empty ordering; fanout or pre-clustered input is required.
+  private def requiredOrdering(
+      distribution: Distribution,
+      icebergTable: org.apache.iceberg.Table): Array[SortOrder] = {
+    if (distribution.isInstanceOf[OrderedDistribution] || !icebergTable.sortOrder().isUnsorted) {
+      Spark3Util.buildRequiredOrdering(distribution, icebergTable)
+    } else {
+      Array.empty[SortOrder]
+    }
+  }
+}
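To make the branches concrete, here is how the distribution modes play out for a table partitioned by bucket(4, id); a sketch assuming a catalog named `demo` (the table is hypothetical; the property keys are the spelled-out TableProperties constants used throughout the tests below):

spark.sql("CREATE TABLE demo.db.t (id INT, category STRING) USING iceberg " +
  "PARTITIONED BY (bucket(4, id))")

// Freshly created and unsorted: with mode 'none' (or 'hash') the rule leaves the query
// untouched, so fanout or pre-clustered input is needed.
spark.sql("ALTER TABLE demo.db.t SET TBLPROPERTIES " +
  "('write.distribution-mode'='none', 'write.spark.fanout.enabled'='true')")

// WRITE ORDERED BY switches to a range distribution: appends now gain a task-local Sort
// over (bucket(4, id), category, id); still no Exchange.
spark.sql("ALTER TABLE demo.db.t WRITE ORDERED BY category, id")

// Hash distribution with an explicit order: the ordering is kept, the distribution dropped.
spark.sql("ALTER TABLE demo.db.t WRITE DISTRIBUTED BY PARTITION ORDERED BY category")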
diff --git a/spark/v3.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRequiredDistributionAndOrdering.java b/spark/v3.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRequiredDistributionAndOrdering.java
new file mode 100644
index 0000000000..4671d0dab3
--- /dev/null
+++ b/spark/v3.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRequiredDistributionAndOrdering.java
@@ -0,0 +1,674 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.extensions;
+
+import java.math.BigDecimal;
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.NullOrder;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.spark.source.ThreeColumnRecord;
+import org.apache.spark.SparkException;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.assertj.core.api.Assertions;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestRequiredDistributionAndOrdering extends SparkExtensionsTestBase {
+
+  public TestRequiredDistributionAndOrdering(
+      String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @After
+  public void dropTestTable() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+  }
+
+  // Unsorted partitioned table: the rule must not synthesize a partition-spec sort. Fanout is
+  // enabled so the FanoutDataWriter accepts rows whose bucket values are not clustered.
+  @Test
+  public void testNoSyntheticPartitionSortWithBucketTransforms() throws NoSuchTableException {
+    sql(
+        "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(2, c1))",
+        tableName);
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES ('%s'='true')",
+        tableName, TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED);
+
+    List<ThreeColumnRecord> data =
+        ImmutableList.of(
+            new ThreeColumnRecord(1, null, "A"),
+            new ThreeColumnRecord(2, "BBBBBBBBBB", "B"),
+            new ThreeColumnRecord(3, "BBBBBBBBBB", "A"),
+            new ThreeColumnRecord(4, "BBBBBBBBBB", "B"),
+            new ThreeColumnRecord(5, "BBBBBBBBBB", "A"),
+            new ThreeColumnRecord(6, "BBBBBBBBBB", "B"),
+            new ThreeColumnRecord(7, "BBBBBBBBBB", "A"));
+    Dataset<Row> ds = spark.createDataFrame(data, ThreeColumnRecord.class);
+    Dataset<Row> inputDF = ds.coalesce(1).sortWithinPartitions("c1");
+
+    inputDF.writeTo(tableName).append();
+
+    assertEquals(
+        "Row count must match",
+        ImmutableList.of(row(7L)),
+        sql("SELECT count(*) FROM %s", tableName));
+  }
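The test above flips fanout through the table property; the same switch should also be reachable per write as a writer option (assuming the `fanout-enabled` key from Iceberg's SparkWriteOptions; source and table names hypothetical):

// Per-write alternative to the ALTER TABLE in the test above (a sketch).
spark.table("demo.db.src")
  .writeTo("demo.db.t")
  .option("fanout-enabled", "true") // SparkWriteOptions.FANOUT_ENABLED, assumed key
  .append()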
+
+  @Test
+  public void testPartitionColumnsArePrependedForRangeDistribution() throws NoSuchTableException {
+    sql(
+        "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(2, c1))",
+        tableName);
+
+    List<ThreeColumnRecord> data =
+        ImmutableList.of(
+            new ThreeColumnRecord(1, null, "A"),
+            new ThreeColumnRecord(2, "BBBBBBBBBB", "B"),
+            new ThreeColumnRecord(3, "BBBBBBBBBB", "A"),
+            new ThreeColumnRecord(4, "BBBBBBBBBB", "B"),
+            new ThreeColumnRecord(5, "BBBBBBBBBB", "A"),
+            new ThreeColumnRecord(6, "BBBBBBBBBB", "B"),
+            new ThreeColumnRecord(7, "BBBBBBBBBB", "A"));
+    Dataset<Row> ds = spark.createDataFrame(data, ThreeColumnRecord.class);
+    Dataset<Row> inputDF = ds.coalesce(1).sortWithinPartitions("c1");
+
+    // should automatically prepend partition columns to the ordering
+    sql("ALTER TABLE %s WRITE ORDERED BY c1, c2", tableName);
+
+    inputDF.writeTo(tableName).append();
+
+    assertEquals(
+        "Row count must match",
+        ImmutableList.of(row(7L)),
+        sql("SELECT count(*) FROM %s", tableName));
+  }
+
+  @Test
+  public void testSortOrderIncludesPartitionColumns() throws NoSuchTableException {
+    sql(
+        "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(2, c1))",
+        tableName);
+
+    List<ThreeColumnRecord> data =
+        ImmutableList.of(
+            new ThreeColumnRecord(1, null, "A"),
+            new ThreeColumnRecord(2, "BBBBBBBBBB", "B"),
+            new ThreeColumnRecord(3, "BBBBBBBBBB", "A"),
+            new ThreeColumnRecord(4, "BBBBBBBBBB", "B"),
+            new ThreeColumnRecord(5, "BBBBBBBBBB", "A"),
+            new ThreeColumnRecord(6, "BBBBBBBBBB", "B"),
+            new ThreeColumnRecord(7, "BBBBBBBBBB", "A"));
+    Dataset<Row> ds = spark.createDataFrame(data, ThreeColumnRecord.class);
+    Dataset<Row> inputDF = ds.coalesce(1).sortWithinPartitions("c1");
+
+    // should succeed with a sort order that already covers the partition transform
+    sql("ALTER TABLE %s WRITE ORDERED BY bucket(2, c1), c1, c2", tableName);
+
+    inputDF.writeTo(tableName).append();
+
+    assertEquals(
+        "Row count must match",
+        ImmutableList.of(row(7L)),
+        sql("SELECT count(*) FROM %s", tableName));
+  }
+
+  @Test
+  public void testHashDistributionOnBucketedColumn() throws NoSuchTableException {
+    sql(
+        "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(2, c1))",
+        tableName);
+
+    List<ThreeColumnRecord> data =
+        ImmutableList.of(
+            new ThreeColumnRecord(1, null, "A"),
+            new ThreeColumnRecord(2, "BBBBBBBBBB", "B"),
+            new ThreeColumnRecord(3, "BBBBBBBBBB", "A"),
+            new ThreeColumnRecord(4, "BBBBBBBBBB", "B"),
+            new ThreeColumnRecord(5, "BBBBBBBBBB", "A"),
+            new ThreeColumnRecord(6, "BBBBBBBBBB", "B"),
+            new ThreeColumnRecord(7, "BBBBBBBBBB", "A"));
+    Dataset<Row> ds = spark.createDataFrame(data, ThreeColumnRecord.class);
+    Dataset<Row> inputDF = ds.coalesce(1).sortWithinPartitions("c1");
+
+    // should automatically prepend partition columns to the local ordering after hash distribution
+    sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION ORDERED BY c1, c2", tableName);
+
+    inputDF.writeTo(tableName).append();
+
+    assertEquals(
+        "Row count must match",
+        ImmutableList.of(row(7L)),
+        sql("SELECT count(*) FROM %s", tableName));
+  }
+
+  // INSERT VALUES into an unsorted partitioned table across various transform types. Fanout is
+  // enabled because the rule no longer clusters unsorted tables.
+
+  @Test
+  public void testInsertValuesOnDecimalBucketedColumn() {
+    sql(
+        "CREATE TABLE %s (c1 INT, c2 DECIMAL(20, 2)) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(2, c2)) "
+            + "TBLPROPERTIES ('%s'='true')",
+        tableName, TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED);
+
+    sql("INSERT INTO %s VALUES (1, 20.2), (2, 40.2), (3, 60.2)", tableName);
+
+    List<Object[]> expected =
+        ImmutableList.of(
+            row(1, new BigDecimal("20.20")),
+            row(2, new BigDecimal("40.20")),
+            row(3, new BigDecimal("60.20")));
+
+    assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName));
+  }
+
+  @Test
+  public void testInsertValuesOnStringBucketedColumn() {
+    sql(
+        "CREATE TABLE %s (c1 INT, c2 STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(2, c2)) "
+            + "TBLPROPERTIES ('%s'='true')",
+        tableName, TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED);
+
+    sql("INSERT INTO %s VALUES (1, 'A'), (2, 'B')", tableName);
+
+    List<Object[]> expected = ImmutableList.of(row(1, "A"), row(2, "B"));
+
+    assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName));
+  }
+
+  @Test
+  public void testInsertValuesOnDecimalTruncatedColumn() {
+    sql(
+        "CREATE TABLE %s (c1 INT, c2 DECIMAL(20, 2)) "
+            + "USING iceberg "
+            + "PARTITIONED BY (truncate(2, c2)) "
+            + "TBLPROPERTIES ('%s'='true')",
+        tableName, TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED);
+
+    sql("INSERT INTO %s VALUES (1, 20.2), (2, 40.2)", tableName);
+
+    List<Object[]> expected =
+        ImmutableList.of(row(1, new BigDecimal("20.20")), row(2, new BigDecimal("40.20")));
+
+    assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName));
+  }
+
+  @Test
+  public void testInsertValuesOnLongTruncatedColumn() {
+    sql(
+        "CREATE TABLE %s (c1 INT, c2 BIGINT) "
+            + "USING iceberg "
+            + "PARTITIONED BY (truncate(2, c2)) "
+            + "TBLPROPERTIES ('%s'='true')",
+        tableName, TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED);
+
+    sql("INSERT INTO %s VALUES (1, 22222222222222), (2, 444444444444)", tableName);
+
+    List<Object[]> expected = ImmutableList.of(row(1, 22222222222222L), row(2, 444444444444L));
+
+    assertEquals("Rows must match", expected, sql("SELECT * FROM %s ORDER BY c1", tableName));
+  }
+
+  // testRangeDistributionWithQuotedColumnNames from the v3.2 suite is intentionally omitted:
+  // v3.1 SortOrderToSpark passes raw column names through Expressions.column, which cannot
+  // parse dotted identifiers. v3.2 fixed this by indexing schema-quoted names; that fix is out
+  // of scope for this backport.
+
+  // Unclustered input. The rule attaches a local Sort (with the partition prefix) for
+  // sorted/RANGE tables; HASH/NONE on unsorted tables is a no-op, so those tests enable
+  // fanout. A distribution is never injected; see the ExtendedV2Writes class doc. The
+  // *FailsWithoutRule tests below pin down the baseline writer behavior when the rule is
+  // disabled.
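As a sketch of what "unclustered" means for these bucketed tables (table name hypothetical): sequential ids hash into the four buckets in a scattered order, so after repartition(4) every task sees several buckets with repeated transitions between them, exactly the pattern ClusteredDataWriter refuses.

spark.range(20)
  .selectExpr("CAST(id AS INT) AS id", "CAST(id AS STRING) AS data")
  .repartition(4)       // round-robin: each task mixes bucket values
  .writeTo("demo.db.t") // hypothetical table PARTITIONED BY (bucket(4, id))
  .append()             // needs fanout or a declared sort order to succeed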
+
+  // HASH on an unsorted table: the rule is a no-op, so fanout is enabled. To get a
+  // rule-injected local sort here, declare a sort order; see
+  // testHashDistributionWithExplicitSortOrder.
+  @Test
+  public void testHashDistributionModeViaTableProperty() throws NoSuchTableException {
+    sql(
+        "CREATE TABLE %s (id INT, category STRING, data STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(4, id))",
+        tableName);
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES ('%s'='hash', '%s'='true')",
+        tableName,
+        TableProperties.WRITE_DISTRIBUTION_MODE,
+        TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    Assert.assertEquals(
+        "Distribution mode must be hash",
+        "hash",
+        table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE));
+
+    Dataset<Row> inputDF = unclusteredInput();
+
+    inputDF.writeTo(tableName).append();
+
+    assertEquals(
+        "Row count must match",
+        ImmutableList.of(row(20L)),
+        sql("SELECT count(*) FROM %s", tableName));
+  }
+
+  @Test
+  public void testRangeDistributionModeViaSortOrder() throws NoSuchTableException {
+    sql(
+        "CREATE TABLE %s (id INT, category STRING, data STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(4, id))",
+        tableName);
+
+    // WRITE ORDERED BY implicitly sets the distribution mode to range
+    sql("ALTER TABLE %s WRITE ORDERED BY category, id", tableName);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    Assert.assertEquals(
+        "Distribution mode must be range",
+        "range",
+        table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE));
+    SortOrder expectedOrder =
+        SortOrder.builderFor(table.schema())
+            .withOrderId(1)
+            .asc("category", NullOrder.NULLS_FIRST)
+            .asc("id", NullOrder.NULLS_FIRST)
+            .build();
+    Assert.assertEquals("Sort order must match", expectedOrder, table.sortOrder());
+
+    Dataset<Row> inputDF = unclusteredInput();
+
+    inputDF.writeTo(tableName).append();
+
+    assertEquals(
+        "Row count must match",
+        ImmutableList.of(row(20L)),
+        sql("SELECT count(*) FROM %s", tableName));
+  }
+
+  @Test
+  public void testHashDistributionWithExplicitSortOrder() throws NoSuchTableException {
+    sql(
+        "CREATE TABLE %s (id INT, category STRING, data STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(4, id))",
+        tableName);
+
+    sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION ORDERED BY category", tableName);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    Assert.assertEquals(
+        "Distribution mode must be hash",
+        "hash",
+        table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE));
+
+    Dataset<Row> inputDF = unclusteredInput();
+
+    inputDF.writeTo(tableName).append();
+
+    assertEquals(
+        "Row count must match",
+        ImmutableList.of(row(20L)),
+        sql("SELECT count(*) FROM %s", tableName));
+  }
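For the hash-plus-explicit-order case above, one way to eyeball the prepared plan is EXPLAIN on an equivalent SQL insert; a sketch (view and table names hypothetical; the expected plan is described rather than quoted):

val df = spark.range(20).selectExpr(
  "CAST(id AS INT) AS id",
  "CAST(id % 3 AS STRING) AS category",
  "CAST(id AS STRING) AS data")
df.createOrReplaceTempView("src")
spark.sql("EXPLAIN EXTENDED INSERT INTO demo.db.t SELECT * FROM src").show(false)
// Expect a Sort with global = false whose keys start with the bucket(4, id) partition
// transform, then category; no Exchange should be introduced by the write.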
+
+  @Test
+  public void testNoneDistributionModeViaTableProperty() throws NoSuchTableException {
+    sql(
+        "CREATE TABLE %s (id INT, category STRING, data STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(4, id))",
+        tableName);
+    // `none` distribution: the rule attaches no repartition and no synthesized sort. Fanout
+    // is enabled so the FanoutDataWriter accepts the unclustered input directly.
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES ('%s'='none', '%s'='true')",
+        tableName,
+        TableProperties.WRITE_DISTRIBUTION_MODE,
+        TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    Assert.assertEquals(
+        "Distribution mode must be none",
+        "none",
+        table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE));
+
+    Dataset<Row> inputDF = unclusteredInput();
+
+    inputDF.writeTo(tableName).append();
+
+    assertEquals(
+        "Row count must match",
+        ImmutableList.of(row(20L)),
+        sql("SELECT count(*) FROM %s", tableName));
+  }
+
+  @Test
+  public void testRangeDistributionModeViaTableProperty() throws NoSuchTableException {
+    sql(
+        "CREATE TABLE %s (id INT, category STRING, data STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(4, id))",
+        tableName);
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES ('%s'='range')",
+        tableName, TableProperties.WRITE_DISTRIBUTION_MODE);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    Assert.assertEquals(
+        "Distribution mode must be range",
+        "range",
+        table.properties().get(TableProperties.WRITE_DISTRIBUTION_MODE));
+
+    Dataset<Row> inputDF = unclusteredInput();
+
+    inputDF.writeTo(tableName).append();
+
+    assertEquals(
+        "Row count must match",
+        ImmutableList.of(row(20L)),
+        sql("SELECT count(*) FROM %s", tableName));
+  }
+
+  // With the rule excluded, these tests pin the pre-rule baseline: ClusteredDataWriter
+  // rejects unclustered input.
+
+  @Test
+  public void testNoneDistributionFailsWithoutRule() {
+    sql(
+        "CREATE TABLE %s (id INT, category STRING, data STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(4, id))",
+        tableName);
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES ('%s'='none')",
+        tableName, TableProperties.WRITE_DISTRIBUTION_MODE);
+
+    assertWriterRejectsUnclusteredInput();
+  }
+
+  @Test
+  public void testHashDistributionFailsWithoutRule() {
+    sql(
+        "CREATE TABLE %s (id INT, category STRING, data STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(4, id))",
+        tableName);
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES ('%s'='hash')",
+        tableName, TableProperties.WRITE_DISTRIBUTION_MODE);
+
+    assertWriterRejectsUnclusteredInput();
+  }
+
+  @Test
+  public void testRangeDistributionFailsWithoutRule() {
+    sql(
+        "CREATE TABLE %s (id INT, category STRING, data STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(4, id))",
+        tableName);
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES ('%s'='range')",
+        tableName, TableProperties.WRITE_DISTRIBUTION_MODE);
+
+    assertWriterRejectsUnclusteredInput();
+  }
+
+  private void assertWriterRejectsUnclusteredInput() {
+    Dataset<Row> inputDF = unclusteredInput();
+    spark
+        .conf()
+        .set(
+            "spark.sql.optimizer.excludedRules",
+            "org.apache.spark.sql.execution.datasources.v2.ExtendedV2Writes");
+    try {
+      Assertions.assertThatThrownBy(() -> inputDF.writeTo(tableName).append())
+          .as(
+              "ClusteredDataWriter should reject unclustered input when ExtendedV2Writes is disabled")
+          .isInstanceOf(SparkException.class)
+          .hasStackTraceContaining("Incoming records violate the writer assumption");
+    } finally {
+      spark.conf().unset("spark.sql.optimizer.excludedRules");
+    }
+  }
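The helper scopes the exclusion with try/finally so a failed assertion cannot leak the setting into later tests. The same toggle is also available from SQL for ad-hoc experiments in a shell, as a sketch:

spark.sql("SET spark.sql.optimizer.excludedRules=" +
  "org.apache.spark.sql.execution.datasources.v2.ExtendedV2Writes")
// ... run the write that should now fail ...
spark.sql("RESET spark.sql.optimizer.excludedRules")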
+
+  // Empty input on the rule's no-op path (HASH, unsorted): an empty append must still
+  // commit a clean snapshot.
+  @Test
+  public void testEmptyInputWithHashDistribution() throws NoSuchTableException {
+    sql(
+        "CREATE TABLE %s (id INT, category STRING, data STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(4, id))",
+        tableName);
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES ('%s'='hash')",
+        tableName, TableProperties.WRITE_DISTRIBUTION_MODE);
+
+    Dataset<Row> emptyDF = unclusteredInput().where("1 = 0");
+    emptyDF.writeTo(tableName).append();
+
+    assertEquals(
+        "Row count must be zero",
+        ImmutableList.of(row(0L)),
+        sql("SELECT count(*) FROM %s", tableName));
+  }
+
+  // Null partition values, one row per task, so the input is trivially clustered. Guards
+  // the null-handling path.
+  @Test
+  public void testNullPartitionValuesWithHashDistribution() throws NoSuchTableException {
+    sql(
+        "CREATE TABLE %s (id INT, category STRING, data STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (category)",
+        tableName);
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES ('%s'='hash')",
+        tableName, TableProperties.WRITE_DISTRIBUTION_MODE);
+
+    List<ThreeColumnRecord> data =
+        ImmutableList.of(
+            new ThreeColumnRecord(1, null, "d1"),
+            new ThreeColumnRecord(2, null, "d2"),
+            new ThreeColumnRecord(3, null, "d3"),
+            new ThreeColumnRecord(4, null, "d4"));
+    Dataset<Row> inputDF =
+        spark
+            .createDataFrame(data, ThreeColumnRecord.class)
+            .selectExpr("c1 AS id", "c2 AS category", "c3 AS data")
+            .repartition(4);
+
+    inputDF.writeTo(tableName).append();
+
+    assertEquals(
+        "Row count must match",
+        ImmutableList.of(row(4L)),
+        sql("SELECT count(*) FROM %s", tableName));
+  }
+
+  // High-cardinality bucket transform: the rule is a no-op (unsorted table), and fanout
+  // handles the unclustered input.
+  @Test
+  public void testHighCardinalityBucketWithHashDistribution() throws NoSuchTableException {
+    sql(
+        "CREATE TABLE %s (id INT, category STRING, data STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(64, id))",
+        tableName);
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES ('%s'='hash', '%s'='true')",
+        tableName,
+        TableProperties.WRITE_DISTRIBUTION_MODE,
+        TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED);
+
+    Dataset<Row> inputDF = unclusteredInput();
+    inputDF.writeTo(tableName).append();
+
+    assertEquals(
+        "Row count must match",
+        ImmutableList.of(row(20L)),
+        sql("SELECT count(*) FROM %s", tableName));
+  }
+
+  // Single shuffle partition → all bucket values in one task, unclustered. Fanout required.
+  @Test
+  public void testHashDistributionWithSingleShufflePartition() throws NoSuchTableException {
+    sql(
+        "CREATE TABLE %s (id INT, category STRING, data STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(4, id))",
+        tableName);
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES ('%s'='hash', '%s'='true')",
+        tableName,
+        TableProperties.WRITE_DISTRIBUTION_MODE,
+        TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED);
+
+    String original = spark.conf().get("spark.sql.shuffle.partitions");
+    spark.conf().set("spark.sql.shuffle.partitions", "1");
+    try {
+      unclusteredInput().writeTo(tableName).append();
+    } finally {
+      spark.conf().set("spark.sql.shuffle.partitions", original);
+    }
+
+    assertEquals(
+        "Row count must match",
+        ImmutableList.of(row(20L)),
+        sql("SELECT count(*) FROM %s", tableName));
+  }
+
+  // AQE coalesces post-shuffle partitions → multiple buckets in one task. Fanout required.
+  @Test
+  public void testHashDistributionWithAQEEnabled() throws NoSuchTableException {
+    sql(
+        "CREATE TABLE %s (id INT, category STRING, data STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(4, id))",
+        tableName);
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES ('%s'='hash', '%s'='true')",
+        tableName,
+        TableProperties.WRITE_DISTRIBUTION_MODE,
+        TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED);
+
+    String original = spark.conf().get("spark.sql.adaptive.enabled");
+    spark.conf().set("spark.sql.adaptive.enabled", "true");
+    try {
+      unclusteredInput().writeTo(tableName).append();
+    } finally {
+      spark.conf().set("spark.sql.adaptive.enabled", original);
+    }
+
+    assertEquals(
+        "Row count must match",
+        ImmutableList.of(row(20L)),
+        sql("SELECT count(*) FROM %s", tableName));
+  }
+
+  // The fanout writer accepts unclustered input directly. The rule is a no-op here
+  // (unsorted table, no distribution injected).
+  @Test
+  public void testFanoutWriterWithHashDistribution() throws NoSuchTableException {
+    sql(
+        "CREATE TABLE %s (id INT, category STRING, data STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(4, id))",
+        tableName);
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES ('%s'='hash', '%s'='true')",
+        tableName,
+        TableProperties.WRITE_DISTRIBUTION_MODE,
+        TableProperties.SPARK_WRITE_PARTITIONED_FANOUT_ENABLED);
+
+    unclusteredInput().writeTo(tableName).append();
+
+    assertEquals(
+        "Row count must match",
+        ImmutableList.of(row(20L)),
+        sql("SELECT count(*) FROM %s", tableName));
+  }
+
+  // saveAsTable("append") on an existing V2 table produces AppendData, the same plan node
+  // the rule matches for writeTo(...).append(). RANGE distribution keeps the rule active
+  // (sort attached), so this exercises the full path rather than the no-op branch.
+  @Test
+  public void testSaveAsTableAppendWithRangeDistribution() throws NoSuchTableException {
+    sql(
+        "CREATE TABLE %s (id INT, category STRING, data STRING) "
+            + "USING iceberg "
+            + "PARTITIONED BY (bucket(4, id))",
+        tableName);
+    sql("ALTER TABLE %s WRITE ORDERED BY category, id", tableName);
+
+    unclusteredInput().write().mode("append").saveAsTable(tableName);
+
+    assertEquals(
+        "Row count must match",
+        ImmutableList.of(row(20L)),
+        sql("SELECT count(*) FROM %s", tableName));
+  }
+
+  // 20 rows across 4 buckets, spread round-robin across 4 Spark partitions: worst-case
+  // unclustered input for ClusteredDataWriter.
+  private Dataset<Row> unclusteredInput() {
+    List<ThreeColumnRecord> data =
+        ImmutableList.of(
+            new ThreeColumnRecord(0, "B", "d0"),
+            new ThreeColumnRecord(1, "A", "d1"),
+            new ThreeColumnRecord(2, "C", "d2"),
+            new ThreeColumnRecord(3, "B", "d3"),
+            new ThreeColumnRecord(4, "A", "d4"),
+            new ThreeColumnRecord(5, "C", "d5"),
+            new ThreeColumnRecord(6, "B", "d6"),
+            new ThreeColumnRecord(7, "A", "d7"),
+            new ThreeColumnRecord(8, "C", "d8"),
+            new ThreeColumnRecord(9, "B", "d9"),
+            new ThreeColumnRecord(10, "A", "d10"),
+            new ThreeColumnRecord(11, "C", "d11"),
+            new ThreeColumnRecord(12, "B", "d12"),
+            new ThreeColumnRecord(13, "A", "d13"),
+            new ThreeColumnRecord(14, "C", "d14"),
+            new ThreeColumnRecord(15, "B", "d15"),
+            new ThreeColumnRecord(16, "A", "d16"),
+            new ThreeColumnRecord(17, "C", "d17"),
+            new ThreeColumnRecord(18, "B", "d18"),
+            new ThreeColumnRecord(19, "A", "d19"));
+    return spark
+        .createDataFrame(data, ThreeColumnRecord.class)
+        .selectExpr("c1 AS id", "c2 AS category", "c3 AS data")
+        .repartition(4);
+  }
+}