diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.json
new file mode 100644
index 000000000000..c4edaa85a89d
--- /dev/null
+++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.json
@@ -0,0 +1,3 @@
+{
+ "comment": "Modify this file in a trivial way to cause this test suite to run"
+}
diff --git a/.github/trigger_files/beam_PreCommit_Java_Spark4_Versions.json b/.github/trigger_files/beam_PreCommit_Java_Spark4_Versions.json
new file mode 100644
index 000000000000..c4edaa85a89d
--- /dev/null
+++ b/.github/trigger_files/beam_PreCommit_Java_Spark4_Versions.json
@@ -0,0 +1,3 @@
+{
+ "comment": "Modify this file in a trivial way to cause this test suite to run"
+}
diff --git a/.github/workflows/README.md b/.github/workflows/README.md
index f8c77b7180db..b095bc2a2447 100644
--- a/.github/workflows/README.md
+++ b/.github/workflows/README.md
@@ -373,6 +373,7 @@ PostCommit Jobs run in a schedule against master branch and generally do not get
| [ PostCommit Java ValidatesRunner Spark Java8 ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml) | N/A |`beam_PostCommit_Java_ValidatesRunner_Spark_Java8.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml?query=event%3Aschedule) |
| [ PostCommit Java ValidatesRunner Spark ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml) | N/A |`beam_PostCommit_Java_ValidatesRunner_Spark.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml?query=event%3Aschedule) |
| [ PostCommit Java ValidatesRunner SparkStructuredStreaming ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml) | N/A |`beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml?query=event%3Aschedule) |
+| [ PostCommit Java ValidatesRunner Spark4StructuredStreaming ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml) | N/A |`beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml?query=event%3Aschedule) |
| [ PostCommit Java ValidatesRunner Twister2 ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml) | N/A |`beam_PostCommit_Java_ValidatesRunner_Twister2.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml?query=event%3Aschedule) |
| [ PostCommit Java ValidatesRunner ULR ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml) | N/A |`beam_PostCommit_Java_ValidatesRunner_ULR.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml?query=event%3Aschedule) |
| [ PostCommit Java ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java.yml) | N/A |`beam_PostCommit_Java.json`| [![.github/workflows/beam_PostCommit_Java.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java.yml?query=event%3Aschedule) |
diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml
new file mode 100644
index 000000000000..b595afe6f42c
--- /dev/null
+++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: PostCommit Java ValidatesRunner Spark4 StructuredStreaming
+
+on:
+ schedule:
+ - cron: '45 4/6 * * *'
+ pull_request_target:
+ paths: ['release/trigger_all_tests.json', '.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.json']
+ workflow_dispatch:
+
+#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event
+permissions:
+ actions: write
+ pull-requests: write
+ checks: write
+ contents: read
+ deployments: read
+ id-token: none
+ issues: write
+ discussions: read
+ packages: read
+ pages: read
+ repository-projects: read
+ security-events: read
+ statuses: read
+
+# This allows a subsequently queued workflow run to interrupt previous runs
+concurrency:
+ group: '${{ github.workflow }} @ ${{ github.event.pull_request.number || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}'
+ cancel-in-progress: true
+
+env:
+ DEVELOCITY_ACCESS_KEY: ${{ secrets.DEVELOCITY_ACCESS_KEY }}
+ GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }}
+ GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }}
+
+jobs:
+ beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming:
+ name: ${{ matrix.job_name }} (${{ matrix.job_phrase }})
+ runs-on: [self-hosted, ubuntu-24.04, main]
+ timeout-minutes: 120
+ strategy:
+ matrix:
+ job_name: [beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming]
+ job_phrase: [Run Spark4 StructuredStreaming ValidatesRunner]
+ if: |
+ github.event_name == 'workflow_dispatch' ||
+ github.event_name == 'pull_request_target' ||
+ (github.event_name == 'schedule' && github.repository == 'apache/beam') ||
+ github.event.comment.body == 'Run Spark4 StructuredStreaming ValidatesRunner'
+ steps:
+ - uses: actions/checkout@v4
+ - name: Setup repository
+ uses: ./.github/actions/setup-action
+ with:
+ comment_phrase: ${{ matrix.job_phrase }}
+ github_token: ${{ secrets.GITHUB_TOKEN }}
+ github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }})
+ - name: Setup environment
+ uses: ./.github/actions/setup-environment-action
+ with:
+ java-version: '17'
+ - name: run validatesStructuredStreamingRunnerBatch script
+ uses: ./.github/actions/gradle-command-self-hosted-action
+ with:
+ gradle-command: :runners:spark:4:validatesStructuredStreamingRunnerBatch
+ arguments: |
+ -PtestJavaVersion=17 \
+ -PdisableSpotlessCheck=true
+ - name: Archive JUnit Test Results
+ uses: actions/upload-artifact@v4
+ if: ${{ !success() }}
+ with:
+ name: JUnit Test Results
+ path: "**/build/reports/tests/"
+ - name: Publish JUnit Test Results
+ uses: EnricoMi/publish-unit-test-result-action@v2
+ if: always()
+ with:
+ commit: '${{ env.prsha || env.GITHUB_SHA }}'
+ comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }}
+ files: '**/build/test-results/**/*.xml'
+ large_files: true
diff --git a/CHANGES.md b/CHANGES.md
index e8e11e830d14..f092c81f92fb 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -60,7 +60,6 @@
## Highlights
* New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)).
-* New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)).
## I/Os
diff --git a/gradle.properties b/gradle.properties
index 2cff0b656fa7..04c85957e4eb 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -41,6 +41,6 @@ docker_image_default_repo_prefix=beam_
# supported flink versions
flink_versions=1.17,1.18,1.19,1.20,2.0
# supported spark versions
-spark_versions=3
+spark_versions=3,4
# supported python versions
python_versions=3.10,3.11,3.12,3.13,3.14
diff --git a/runners/spark/4/README.md b/runners/spark/4/README.md
new file mode 100644
index 000000000000..371a44512657
--- /dev/null
+++ b/runners/spark/4/README.md
@@ -0,0 +1,73 @@
+<!--
+    Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+-->
+
+# Apache Beam Spark 4 Runner
+
+Experimental Beam runner for Apache Spark 4 (batch-only). Built on the shared
+`runners/spark` source base via `spark_runner.gradle`'s per-version
+source-overrides mechanism: this module contributes the small set of files
+under `src/main/java/.../structuredstreaming/` that diverge from the Spark 3
+implementation. See the parent `runners/spark/` module for the bulk of the
+runner code.
+
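+For orientation, an abridged sketch of the layout (the authoritative list of
+overridden files is simply the contents of `runners/spark/4/src/`):
+
+```
+runners/spark/
+├── src/ # shared runner sources used by all Spark majors
+├── 3/   # Spark 3 module
+└── 4/   # this module
+    └── src/main/java/org/apache/beam/runners/spark/structuredstreaming/
+        ├── io/BoundedDatasetFactory.java
+        └── translation/ # batch translators and encoder helpers
+```
+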
+## Requirements
+
+* **Spark 4.0.2** (and other Spark 4.0.x patch releases)
+* **Scala 2.13**
+* **Java 17** — Spark 4 does not run on earlier JDKs
+
+## Status
+
+Batch only. Streaming is tracked in
+[#36841](https://github.com/apache/beam/issues/36841).
+
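+As a quick smoke test, a minimal batch pipeline sketch (this assumes the Spark 4
+module reuses the `SparkStructuredStreamingRunner` entry point registered by the
+shared runner code, as the Spark 3 module does):
+
+```java
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.Create;
+
+public class Spark4SmokeTest {
+  public static void main(String[] args) {
+    // Select the structured streaming runner; it defaults to a local master.
+    PipelineOptions options =
+        PipelineOptionsFactory.fromArgs("--runner=SparkStructuredStreamingRunner").create();
+    Pipeline pipeline = Pipeline.create(options);
+    pipeline
+        .apply(Create.of("spark", "beam", "spark"))
+        .apply(Count.perElement());
+    pipeline.run().waitUntilFinish();
+  }
+}
+```
+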
+## Known issues
+
+### `StackOverflowError` from `slf4j-jdk14` on the runtime classpath
+
+Spark 4 ships `org.slf4j:jul-to-slf4j` to route `java.util.logging` records
+into SLF4J. If `org.slf4j:slf4j-jdk14` is also resolved at runtime — it routes
+the other direction (SLF4J → JUL) — the first log line creates an infinite
+loop:
+
+```
+java.lang.StackOverflowError
+ at org.slf4j.bridge.SLF4JBridgeHandler.publish(...)
+ at java.util.logging.Logger.log(...)
+ at org.slf4j.impl.JDK14LoggerAdapter.log(...)
+ at org.slf4j.bridge.SLF4JBridgeHandler.publish(...)
+ ...
+```
+
+This is the same condition that broke the Spark 3 runner in
+[#26985](https://github.com/apache/beam/issues/26985), fixed in
+[#27001](https://github.com/apache/beam/pull/27001).
+
+The shared `spark_runner.gradle` already excludes `slf4j-jdk14` from the
+runner module's own `configurations.all`, so in-tree builds are unaffected.
+Downstream Gradle consumers that assemble a runtime classpath against
+`beam-runners-spark-4` should mirror that exclude:
+
+```groovy
+configurations.all {
+ exclude group: "org.slf4j", module: "slf4j-jdk14"
+}
+```
+
+For Maven, exclude `org.slf4j:slf4j-jdk14` from any dependency that pulls it
+transitively (commonly the Beam SDK harness and several IO connectors).
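+
+For illustration, a sketch of the exclusion (attach it to whichever dependency
+actually pulls in `slf4j-jdk14` in your build; `${beam.version}` is a
+placeholder property):
+
+```xml
+<dependency>
+  <groupId>org.apache.beam</groupId>
+  <artifactId>beam-runners-spark-4</artifactId>
+  <version>${beam.version}</version>
+  <exclusions>
+    <exclusion>
+      <groupId>org.slf4j</groupId>
+      <artifactId>slf4j-jdk14</artifactId>
+    </exclusion>
+  </exclusions>
+</dependency>
+```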
diff --git a/runners/spark/4/build.gradle b/runners/spark/4/build.gradle
new file mode 100644
index 000000000000..01fb3680b078
--- /dev/null
+++ b/runners/spark/4/build.gradle
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+def basePath = '..'
+/* All properties required for loading the Spark build script */
+project.ext {
+ spark_major = '4'
+ // Spark 4 version as defined in BeamModulePlugin; requires Scala 2.13 and Java 17
+ spark_version = spark4_version
+ spark_scala_version = '2.13'
+ archives_base_name = 'beam-runners-spark-4'
+}
+
+// Load the main build script which contains all build logic.
+// spark_runner.gradle handles the per-version source-overrides Copy:
+// shared base (runners/spark/src/) + previous majors + this module's ./src/ are
+// merged into build/source-overrides/src using DuplicatesStrategy.INCLUDE so the
+// 11 files under runners/spark/4/src/.../structuredstreaming/ override the
+// shared-base versions.
+apply from: "$basePath/spark_runner.gradle"
+
+// Spark 4 always requires Java 17, so unconditionally add the --add-opens flags
+// required by Kryo and other libraries that use reflection on JDK internals.
+test {
+ jvmArgs "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED",
+ "--add-opens=java.base/java.nio=ALL-UNNAMED",
+ "--add-opens=java.base/java.util=ALL-UNNAMED",
+ "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED"
+}
+
+// Exclude DStream-based streaming tests from the shared-base copy: the Spark 4 module
+// supports only structured streaming (batch) and does not include legacy DStream support.
+// Streaming test utilities also depend on kafka.server.KafkaServerStartable, which was
+// removed in Kafka 2.8.0.
+tasks.named("copyTestSourceOverrides") {
+ exclude "**/translation/streaming/**"
+}
+
diff --git a/runners/spark/4/job-server/build.gradle b/runners/spark/4/job-server/build.gradle
new file mode 100644
index 000000000000..598cf3b4913a
--- /dev/null
+++ b/runners/spark/4/job-server/build.gradle
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+def basePath = '../../job-server'
+
+project.ext {
+ // Look for the source code in the parent module
+ main_source_dirs = ["$basePath/src/main/java"]
+ test_source_dirs = ["$basePath/src/test/java"]
+ main_resources_dirs = ["$basePath/src/main/resources"]
+ test_resources_dirs = ["$basePath/src/test/resources"]
+ archives_base_name = 'beam-runners-spark-4-job-server'
+}
+
+// Load the main build script which contains all build logic.
+apply from: "$basePath/spark_job_server.gradle"
diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java
new file mode 100644
index 000000000000..d32dc14eccc0
--- /dev/null
+++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.io;
+
+import static java.util.stream.Collectors.toList;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList;
+import static org.apache.beam.sdk.values.WindowedValues.timestampedValueInGlobalWindow;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
+import static scala.collection.JavaConverters.asScalaIterator;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.IntSupplier;
+import java.util.function.Supplier;
+import javax.annotation.CheckForNull;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.io.BoundedSource.BoundedReader;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.values.WindowedValue;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import org.apache.spark.InterruptibleIterator;
+import org.apache.spark.Partition;
+import org.apache.spark.SparkContext;
+import org.apache.spark.TaskContext;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer;
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
+import org.apache.spark.sql.classic.Dataset$;
+import org.apache.spark.sql.connector.catalog.SupportsRead;
+import org.apache.spark.sql.connector.catalog.Table;
+import org.apache.spark.sql.connector.catalog.TableCapability;
+import org.apache.spark.sql.connector.read.Batch;
+import org.apache.spark.sql.connector.read.InputPartition;
+import org.apache.spark.sql.connector.read.PartitionReader;
+import org.apache.spark.sql.connector.read.PartitionReaderFactory;
+import org.apache.spark.sql.connector.read.Scan;
+import org.apache.spark.sql.connector.read.ScanBuilder;
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+import scala.Option;
+import scala.collection.Iterator;
+import scala.reflect.ClassTag;
+
+public class BoundedDatasetFactory {
+ private BoundedDatasetFactory() {}
+
+ /**
+ * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link Table}.
+ *
+ * <p>Unfortunately tables are expected to return an {@link InternalRow}, requiring serialization.
+ * This makes this approach for the time being significantly less performant than creating a
+ * dataset from an RDD.
+ */
+ public static <T> Dataset<WindowedValue<T>> createDatasetFromRows(
+ SparkSession session,
+ BoundedSource<T> source,
+ Supplier<PipelineOptions> options,
+ Encoder<WindowedValue<T>> encoder) {
+ Params<T> params = new Params<>(encoder, options, session.sparkContext().defaultParallelism());
+ BeamTable<T> table = new BeamTable<>(source, params);
+ LogicalPlan logicalPlan = DataSourceV2Relation.create(table, Option.empty(), Option.empty());
+ // In Spark 4.0+, Dataset$ moved to org.apache.spark.sql.classic and its ofRows() now
+ // takes the classic SparkSession subclass. The runtime instance returned by
+ // SparkSession.builder() is always a classic.SparkSession, so the downcast is safe and
+ // avoids reflection.
+ return (Dataset<WindowedValue<T>>)
+ Dataset$.MODULE$
+ .ofRows((org.apache.spark.sql.classic.SparkSession) session, logicalPlan)
+ .as(encoder);
+ }
+
+ /**
+ * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link RDD}.
+ *
+ * <p>This is currently the most efficient approach as it avoids any serialization overhead.
+ */
+ public static <T> Dataset<WindowedValue<T>> createDatasetFromRDD(
+ SparkSession session,
+ BoundedSource<T> source,
+ Supplier<PipelineOptions> options,
+ Encoder<WindowedValue<T>> encoder) {
+ Params<T> params = new Params<>(encoder, options, session.sparkContext().defaultParallelism());
+ RDD<WindowedValue<T>> rdd = new BoundedRDD<>(session.sparkContext(), source, params);
+ return session.createDataset(rdd, encoder);
+ }
+
+ /** An {@link RDD} for a bounded Beam source. */
+ private static class BoundedRDD<T> extends RDD<WindowedValue<T>> {
+ final BoundedSource<T> source;
+ final Params<T> params;
+
+ public BoundedRDD(SparkContext sc, BoundedSource<T> source, Params<T> params) {
+ super(sc, emptyList(), ClassTag.apply(WindowedValue.class));
+ this.source = source;
+ this.params = params;
+ }
+
+ @Override
+ public Iterator<WindowedValue<T>> compute(Partition split, TaskContext context) {
+ return new InterruptibleIterator<>(
+ context,
+ asScalaIterator(new SourcePartitionIterator<>((SourcePartition<T>) split, params)));
+ }
+
+ @Override
+ public Partition[] getPartitions() {
+ return SourcePartition.partitionsOf(source, params).toArray(new Partition[0]);
+ }
+ }
+
+ /** A Spark {@link Table} for a bounded Beam source supporting batch reads only. */
+ private static class BeamTable<T> implements Table, SupportsRead {
+ final BoundedSource<T> source;
+ final Params<T> params;
+
+ BeamTable(BoundedSource<T> source, Params<T> params) {
+ this.source = source;
+ this.params = params;
+ }
+
+ public Encoder<WindowedValue<T>> getEncoder() {
+ return params.encoder;
+ }
+
+ @Override
+ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap ignored) {
+ return () ->
+ new Scan() {
+ @Override
+ public StructType readSchema() {
+ return params.encoder.schema();
+ }
+
+ @Override
+ public Batch toBatch() {
+ return new BeamBatch<>(source, params);
+ }
+ };
+ }
+
+ @Override
+ public String name() {
+ return "BeamSource<" + source.getClass().getName() + ">";
+ }
+
+ @Override
+ public StructType schema() {
+ return params.encoder.schema();
+ }
+
+ @Override
+ public Set<TableCapability> capabilities() {
+ return ImmutableSet.of(TableCapability.BATCH_READ);
+ }
+
+ private static class BeamBatch<T> implements Batch, Serializable {
+ final BoundedSource<T> source;
+ final Params<T> params;
+
+ private BeamBatch(BoundedSource<T> source, Params<T> params) {
+ this.source = source;
+ this.params = params;
+ }
+
+ @Override
+ public InputPartition[] planInputPartitions() {
+ return SourcePartition.partitionsOf(source, params).toArray(new InputPartition[0]);
+ }
+
+ @Override
+ public PartitionReaderFactory createReaderFactory() {
+ return p -> new BeamPartitionReader<>(((SourcePartition<T>) p), params);
+ }
+ }
+
+ private static class BeamPartitionReader<T> implements PartitionReader<InternalRow> {
+ final SourcePartitionIterator<T> iterator;
+ final Serializer<WindowedValue<T>> serializer;
+ transient @Nullable InternalRow next;
+
+ BeamPartitionReader(SourcePartition<T> partition, Params<T> params) {
+ iterator = new SourcePartitionIterator<>(partition, params);
+ serializer = ((ExpressionEncoder<WindowedValue<T>>) params.encoder).createSerializer();
+ }
+
+ @Override
+ public boolean next() throws IOException {
+ if (iterator.hasNext()) {
+ next = serializer.apply(iterator.next());
+ return true;
+ }
+ return false;
+ }
+
+ @Override
+ public InternalRow get() {
+ if (next == null) {
+ throw new IllegalStateException("Next not available");
+ }
+ return next;
+ }
+
+ @Override
+ public void close() throws IOException {
+ next = null;
+ iterator.close();
+ }
+ }
+ }
+
+ /** A Spark partition wrapping the partitioned Beam {@link BoundedSource}. */
+ private static class SourcePartition<T> implements Partition, InputPartition {
+ final BoundedSource<T> source;
+ final int index;
+
+ SourcePartition(BoundedSource<T> source, IntSupplier idxSupplier) {
+ this.source = source;
+ this.index = idxSupplier.getAsInt();
+ }
+
+ static <T> List<SourcePartition<T>> partitionsOf(BoundedSource<T> source, Params<T> params) {
+ try {
+ PipelineOptions options = params.options.get();
+ long desiredSize = source.getEstimatedSizeBytes(options) / params.numPartitions;
+ List<? extends BoundedSource<T>> split = source.split(desiredSize, options);
+ IntSupplier idxSupplier = new AtomicInteger(0)::getAndIncrement;
+ return split.stream().map(s -> new SourcePartition<>(s, idxSupplier)).collect(toList());
+ } catch (Exception e) {
+ throw new RuntimeException(
+ "Error splitting BoundedSource " + source.getClass().getCanonicalName(), e);
+ }
+ }
+
+ @Override
+ public int index() {
+ return index;
+ }
+
+ @Override
+ public int hashCode() {
+ return index;
+ }
+ }
+
+ /** A partition iterator on a partitioned Beam {@link BoundedSource}. */
+ private static class SourcePartitionIterator<T> extends AbstractIterator<WindowedValue<T>>
+ implements Closeable {
+ BoundedReader<T> reader;
+ boolean started = false;
+
+ public SourcePartitionIterator(SourcePartition<T> partition, Params<T> params) {
+ try {
+ reader = partition.source.createReader(params.options.get());
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to create reader from a BoundedSource.", e);
+ }
+ }
+
+ @Override
+ @SuppressWarnings("nullness") // ok, reader not used any longer
+ public void close() throws IOException {
+ if (reader != null) {
+ try {
+ reader.close();
+ } finally {
+ reader = null;
+ }
+ }
+ }
+
+ @Override
+ protected @CheckForNull WindowedValue<T> computeNext() {
+ try {
+ if (started ? reader.advance() : start()) {
+ return timestampedValueInGlobalWindow(reader.getCurrent(), reader.getCurrentTimestamp());
+ } else {
+ close();
+ return endOfData();
+ }
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to start or advance reader.", e);
+ }
+ }
+
+ private boolean start() throws IOException {
+ started = true;
+ return reader.start();
+ }
+ }
+
+ /** Shared parameters. */
+ private static class Params<T> implements Serializable {
+ final Encoder<WindowedValue<T>> encoder;
+ final Supplier<PipelineOptions> options;
+ final int numPartitions;
+
+ Params(
+ Encoder<WindowedValue<T>> encoder, Supplier<PipelineOptions> options, int numPartitions) {
+ checkArgument(numPartitions > 0, "Number of partitions must be greater than zero.");
+ this.encoder = encoder;
+ this.options = options;
+ this.numPartitions = numPartitions;
+ }
+ }
+}
diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java
new file mode 100644
index 000000000000..e483c2db0df4
--- /dev/null
+++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGlobalGroupBy;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGroupByWindow;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.explodeWindowedKey;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.value;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueKey;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueValue;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.windowedKV;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
+
+import java.util.Collection;
+import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.WindowedValue;
+import org.apache.beam.sdk.values.WindowedValues;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.expressions.Aggregator;
+import scala.Tuple2;
+import scala.collection.IterableOnce;
+
+/**
+ * Translator for {@link Combine.PerKey} using {@link Dataset#groupByKey} with a Spark {@link
+ * Aggregator}.
+ *
+ * <ul>
+ * <li>When using the default global window, window information is dropped and restored after the
+ * aggregation.
+ * <li>For non-merging windows, windows are exploded and moved into a composite key for better
+ * distribution. After the aggregation, windowed values are restored from the composite key.
+ * <li>All other cases use an aggregator on windowed values that is optimized for the current
+ * windowing strategy.
+ * </ul>
+ *
+ * <p>TODOs:
+ *
+ * <ul>
+ * <li>combine with context (CombineFnWithContext)?
+ * <li>combine with sideInputs?
+ * <li>are there other missing features?
+ * </ul>
+ */
+class CombinePerKeyTranslatorBatch<K, InT, AccT, OutT>
+ extends TransformTranslator<
+ PCollection<KV<K, InT>>, PCollection<KV<K, OutT>>, Combine.PerKey<K, InT, OutT>> {
+
+ CombinePerKeyTranslatorBatch() {
+ super(0.2f);
+ }
+
+ @Override
+ public void translate(Combine.PerKey<K, InT, OutT> transform, Context cxt) {
+ WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy();
+ CombineFn<InT, AccT, OutT> combineFn = (CombineFn<InT, AccT, OutT>) transform.getFn();
+
+ KvCoder<K, InT> inputCoder = (KvCoder<K, InT>) cxt.getInput().getCoder();
+ KvCoder<K, OutT> outputCoder = (KvCoder<K, OutT>) cxt.getOutput().getCoder();
+
+ Encoder<K> keyEnc = cxt.keyEncoderOf(inputCoder);
+ Encoder<KV<K, InT>> inputEnc = cxt.encoderOf(inputCoder);
+ Encoder<WindowedValue<KV<K, OutT>>> wvOutputEnc = cxt.windowedEncoder(outputCoder);
+ Encoder<AccT> accumEnc = accumEncoder(combineFn, inputCoder.getValueCoder(), cxt);
+
+ final Dataset<WindowedValue<KV<K, OutT>>> result;
+
+ boolean globalGroupBy = eligibleForGlobalGroupBy(windowing, true);
+ boolean groupByWindow = eligibleForGroupByWindow(windowing, true);
+
+ if (globalGroupBy || groupByWindow) {
+ Aggregator<KV<K, InT>, ?, OutT> valueAgg =
+ Aggregators.value(combineFn, KV::getValue, accumEnc, cxt.valueEncoderOf(outputCoder));
+
+ if (globalGroupBy) {
+ // Drop window and group by key globally to run the aggregation (combineFn), afterwards the
+ // global window is restored
+ result =
+ cxt.getDataset(cxt.getInput())
+ .groupByKey(valueKey(), keyEnc)
+ .mapValues(value(), inputEnc)
+ .agg(valueAgg.toColumn())
+ .map(globalKV(), wvOutputEnc);
+ } else {
+ Encoder<Tuple2<BoundedWindow, K>> windowedKeyEnc =
+ cxt.tupleEncoder(cxt.windowEncoder(), keyEnc);
+
+ // Group by window and key to run the aggregation (combineFn)
+ result =
+ cxt.getDataset(cxt.getInput())
+ .flatMap(explodeWindowedKey(value()), cxt.tupleEncoder(windowedKeyEnc, inputEnc))
+ .groupByKey(fun1(Tuple2::_1), windowedKeyEnc)
+ .mapValues(fun1(Tuple2::_2), inputEnc)
+ .agg(valueAgg.toColumn())
+ .map(windowedKV(), wvOutputEnc);
+ }
+ } else {
+ // Optimized aggregator for non-merging and session window functions, all others depend on
+ // windowFn.mergeWindows
+ Aggregator<WindowedValue<KV<K, InT>>, ?, Collection<WindowedValue<OutT>>> aggregator =
+ Aggregators.windowedValue(
+ combineFn,
+ valueValue(),
+ windowing,
+ cxt.windowEncoder(),
+ accumEnc,
+ cxt.windowedEncoder(outputCoder.getValueCoder()));
+ result =
+ cxt.getDataset(cxt.getInput())
+ .groupByKey(valueKey(), keyEnc)
+ .agg(aggregator.toColumn())
+ .flatMap(explodeWindows(), wvOutputEnc);
+ }
+
+ cxt.putDataset(cxt.getOutput(), result);
+ }
+
+ private static <K, OutT>
+ Fun1<Tuple2<K, Collection<WindowedValue<OutT>>>, IterableOnce<WindowedValue<KV<K, OutT>>>>
+ explodeWindows() {
+ return t ->
+ ScalaInterop.scalaIterator(t._2).map(wv -> wv.withValue(KV.of(t._1, wv.getValue())));
+ }
+
+ private static <K, OutT> Fun1<Tuple2<K, OutT>, WindowedValue<KV<K, OutT>>> globalKV() {
+ return t -> WindowedValues.valueInGlobalWindow(KV.of(t._1, t._2));
+ }
+
+ private Encoder<AccT> accumEncoder(
+ CombineFn<InT, AccT, OutT> fn, Coder<InT> valueCoder, Context cxt) {
+ try {
+ CoderRegistry registry = cxt.getInput().getPipeline().getCoderRegistry();
+ return cxt.encoderOf(fn.getAccumulatorCoder(registry, valueCoder));
+ } catch (CannotProvideCoderException e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java
new file mode 100644
index 000000000000..f25121e1b478
--- /dev/null
+++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;
+import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING;
+import static org.apache.beam.sdk.transforms.windowing.TimestampCombiner.END_OF_WINDOW;
+
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.WindowedValue;
+import org.apache.beam.sdk.values.WindowedValues;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import scala.Tuple2;
+import scala.collection.IterableOnce;
+
+/**
+ * Package private helpers to support translating grouping transforms using `groupByKey` such as
+ * {@link GroupByKeyTranslatorBatch} or {@link CombinePerKeyTranslatorBatch}.
+ */
+class GroupByKeyHelpers {
+
+ private GroupByKeyHelpers() {}
+
+ /**
+ * Checks if it's possible to use an optimized `groupByKey` that also moves the window into the
+ * key.
+ *
+ * @param windowing The windowing strategy
+ * @param endOfWindowOnly Whether to limit this optimization to {@link
+ * TimestampCombiner#END_OF_WINDOW}.
+ */
+ static boolean eligibleForGroupByWindow(
+ WindowingStrategy<?, ?> windowing, boolean endOfWindowOnly) {
+ return !windowing.needsMerge()
+ && (!endOfWindowOnly || windowing.getTimestampCombiner() == END_OF_WINDOW)
+ && windowing.getWindowFn().windowCoder().consistentWithEquals();
+ }
+
+ /**
+ * Checks if it's possible to use an optimized `groupByKey` for the global window.
+ *
+ * @param windowing The windowing strategy
+ * @param endOfWindowOnly Whether to limit this optimization to {@link
+ * TimestampCombiner#END_OF_WINDOW}.
+ */
+ static boolean eligibleForGlobalGroupBy(
+ WindowingStrategy<?, ?> windowing, boolean endOfWindowOnly) {
+ return windowing.getWindowFn() instanceof GlobalWindows
+ && (!endOfWindowOnly || windowing.getTimestampCombiner() == END_OF_WINDOW);
+ }
+
+ /**
+ * Explodes a windowed {@link KV} assigned to potentially multiple {@link BoundedWindow}s to a
+ * traversable of composite keys {@code (BoundedWindow, Key)} and value.
+ */
+ static <K, V, T>
+ Fun1<WindowedValue<KV<K, V>>, IterableOnce<Tuple2<Tuple2<BoundedWindow, K>, T>>>
+ explodeWindowedKey(Fun1<WindowedValue<KV<K, V>>, T> valueFn) {
+ return v -> {
+ T value = valueFn.apply(v);
+ K key = v.getValue().getKey();
+ return ScalaInterop.scalaIterator(v.getWindows()).map(w -> tuple(tuple(w, key), value));
+ };
+ }
+
+ static <K, V> Fun1<Tuple2<Tuple2<BoundedWindow, K>, V>, WindowedValue<KV<K, V>>> windowedKV() {
+ return t -> windowedKV(t._1, t._2);
+ }
+
+ static <K, V> WindowedValue<KV<K, V>> windowedKV(Tuple2<BoundedWindow, K> key, V value) {
+ return WindowedValues.of(KV.of(key._2, value), key._1.maxTimestamp(), key._1, NO_FIRING);
+ }
+
+ static <V> Fun1<WindowedValue<V>, V> value() {
+ return v -> v.getValue();
+ }
+
+ static <K, V> Fun1<WindowedValue<KV<K, V>>, V> valueValue() {
+ return v -> v.getValue().getValue();
+ }
+
+ static <K, V> Fun1<WindowedValue<KV<K, V>>, K> valueKey() {
+ return v -> v.getValue().getKey();
+ }
+}
diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java
new file mode 100644
index 000000000000..7caf06cb38fd
--- /dev/null
+++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+
+import static org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGlobalGroupBy;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGroupByWindow;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.explodeWindowedKey;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueKey;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueValue;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.windowedKV;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers.toByteArray;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.concat;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun2;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.javaIterator;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf;
+import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.collect_list;
+import static org.apache.spark.sql.functions.explode;
+import static org.apache.spark.sql.functions.max;
+import static org.apache.spark.sql.functions.min;
+import static org.apache.spark.sql.functions.struct;
+
+import java.io.Serializable;
+import org.apache.beam.runners.core.InMemoryStateInternals;
+import org.apache.beam.runners.core.ReduceFnRunner;
+import org.apache.beam.runners.core.StateInternalsFactory;
+import org.apache.beam.runners.core.SystemReduceFn;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
+import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.GroupAlsoByWindowViaOutputBufferFn;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.WindowedValue;
+import org.apache.beam.sdk.values.WindowedValues;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.TypedColumn;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import scala.Tuple2;
+import scala.collection.Iterator;
+import scala.collection.JavaConverters;
+import scala.collection.immutable.List;
+
+/**
+ * Translator for {@link GroupByKey} using {@link Dataset#groupByKey} with the built-in aggregation
+ * function {@code collect_list} when applicable.
+ *
+ * <p>Note: Using {@code collect_list} isn't any worse than using {@link ReduceFnRunner}. In the
+ * latter case the entire group (iterator) has to be loaded into memory as well. Either way there's
+ * a risk of OOM errors. When enabling {@link
+ * SparkCommonPipelineOptions#getPreferGroupByKeyToHandleHugeValues()}, a more memory sensitive
+ * iterable is used that can be traversed just once. Attempting to traverse the iterable again will
+ * throw.
+ *
+ * <ul>
+ * <li>When using the default global window, window information is dropped and restored after the
+ * aggregation.
+ * <li>For non-merging windows, windows are exploded and moved into a composite key for better
+ * distribution. Though, to keep the amount of shuffled data low, this is only done if values
+ * are assigned to a single window or if there are only few keys and distributing data is
+ * important. After the aggregation, windowed values are restored from the composite key.
+ * <li>All other cases are implemented using the SDK {@link ReduceFnRunner}.
+ * </ul>
+ */
+class GroupByKeyTranslatorBatch<K, V>
+ extends TransformTranslator<
+ PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupByKey<K, V>> {
+
+ /** Literal of binary encoded Pane info. */
+ private static final Column PANE_NO_FIRING = lit(toByteArray(NO_FIRING, PaneInfoCoder.of()));
+
+ /** Defaults for value in single global window. */
+ private static final List<Column> GLOBAL_WINDOW_DETAILS =
+ windowDetails(lit(new byte[][] {EMPTY_BYTE_ARRAY}));
+
+ GroupByKeyTranslatorBatch() {
+ super(0.2f);
+ }
+
+ @Override
+ public void translate(GroupByKey<K, V> transform, Context cxt) {
+ WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy();
+ TimestampCombiner tsCombiner = windowing.getTimestampCombiner();
+
+ Dataset<WindowedValue<KV<K, V>>> input = cxt.getDataset(cxt.getInput());
+
+ KvCoder<K, V> inputCoder = (KvCoder<K, V>) cxt.getInput().getCoder();
+ KvCoder<K, Iterable<V>> outputCoder = (KvCoder<K, Iterable<V>>) cxt.getOutput().getCoder();
+
+ Encoder<V> valueEnc = cxt.valueEncoderOf(inputCoder);
+ Encoder<K> keyEnc = cxt.keyEncoderOf(inputCoder);
+
+ // In batch we can ignore triggering and allowed lateness parameters
+ final Dataset<WindowedValue<KV<K, Iterable<V>>>> result;
+
+ boolean useCollectList =
+ !cxt.getOptions()
+ .as(SparkCommonPipelineOptions.class)
+ .getPreferGroupByKeyToHandleHugeValues();
+ if (useCollectList && eligibleForGlobalGroupBy(windowing, false)) {
+ // Collects all values per key in memory. This might be problematic if there's few keys
+ // only or some highly skewed distribution.
+ result =
+ input
+ .groupBy(col("value.key").as("key"))
+ .agg(collect_list(col("value.value")).as("values"), timestampAggregator(tsCombiner))
+ .select(
+ inGlobalWindow(
+ keyValue(col("key").as(keyEnc), col("values").as(iterableEnc(valueEnc))),
+ windowTimestamp(tsCombiner)));
+
+ } else if (eligibleForGlobalGroupBy(windowing, true)) {
+ // Produces an iterable that can be traversed exactly once. However, on the plus side,
+ // data is not collected in memory until serialized or done by the user.
+ result =
+ cxt.getDataset(cxt.getInput())
+ .groupByKey(valueKey(), keyEnc)
+ .mapValues(valueValue(), cxt.valueEncoderOf(inputCoder))
+ .mapGroups(fun2((k, it) -> KV.of(k, iterableOnce(it))), cxt.kvEncoderOf(outputCoder))
+ .map(fun1(WindowedValues::valueInGlobalWindow), cxt.windowedEncoder(outputCoder));
+
+ } else if (useCollectList
+ && eligibleForGroupByWindow(windowing, false)
+ && (windowing.getWindowFn().assignsToOneWindow() || transform.fewKeys())) {
+ // Using the window as part of the key should help to better distribute the data.
+ // However, if values are assigned to multiple windows, more data would be shuffled
+ // around. If there's few keys only, this is still valuable.
+ // Collects all values per key & window in memory.
+ result =
+ input
+ .select(explode(col("windows")).as("window"), col("value"), col("timestamp"))
+ .groupBy(col("value.key").as("key"), col("window"))
+ .agg(collect_list(col("value.value")).as("values"), timestampAggregator(tsCombiner))
+ .select(
+ inSingleWindow(
+ keyValue(col("key").as(keyEnc), col("values").as(iterableEnc(valueEnc))),
+ col("window").as(cxt.windowEncoder()),
+ windowTimestamp(tsCombiner)));
+
+ } else if (eligibleForGroupByWindow(windowing, true)
+ && (windowing.getWindowFn().assignsToOneWindow() || transform.fewKeys())) {
+ // Using the window as part of the key should help to better distribute the data.
+ // However, if values are assigned to multiple windows, more data would be shuffled
+ // around. If there's few keys only, this is still valuable.
+ // Produces an iterable that can be traversed exactly once. However, on the plus side,
+ // data is not collected in memory until serialized or done by the user.
+ Encoder<Tuple2<BoundedWindow, K>> windowedKeyEnc =
+ cxt.tupleEncoder(cxt.windowEncoder(), keyEnc);
+ result =
+ cxt.getDataset(cxt.getInput())
+ .flatMap(explodeWindowedKey(valueValue()), cxt.tupleEncoder(windowedKeyEnc, valueEnc))
+ .groupByKey(fun1(t -> t._1()), windowedKeyEnc)
+ .mapValues(fun1(t -> t._2()), valueEnc)
+ .mapGroups(
+ fun2((wKey, it) -> windowedKV(wKey, iterableOnce(it))),
+ cxt.windowedEncoder(outputCoder));
+
+ } else {
+ // Collects all values per key in memory. This might be problematic if there's few keys
+ // only or some highly skewed distribution.
+
+ // FIXME Revisit this case, implementation is far from ideal:
+ // - iterator traversed at least twice, forcing materialization in memory
+
+ // group by key, then by windows
+ result =
+ input
+ .groupByKey(valueKey(), keyEnc)
+ .flatMapGroups(
+ new GroupAlsoByWindowViaOutputBufferFn<>(
+ windowing,
+ (SerStateInternalsFactory<K>) key -> InMemoryStateInternals.forKey(key),
+ SystemReduceFn.buffering(inputCoder.getValueCoder()),
+ cxt.getOptionsSupplier()),
+ cxt.windowedEncoder(outputCoder));
+ }
+
+ cxt.putDataset(cxt.getOutput(), result);
+ }
+
+ /** Serializable In-memory state internals factory. */
+ private interface SerStateInternalsFactory<K> extends StateInternalsFactory<K>, Serializable {}
+
+ private <T> Encoder<Iterable<T>> iterableEnc(Encoder<T> enc) {
+ // safe to use list encoder with collect list
+ return (Encoder) collectionEncoder(enc);
+ }
+
+ private static Column[] timestampAggregator(TimestampCombiner tsCombiner) {
+ if (tsCombiner.equals(TimestampCombiner.END_OF_WINDOW)) {
+ return new Column[0]; // no aggregation needed
+ }
+ Column agg =
+ tsCombiner.equals(TimestampCombiner.EARLIEST)
+ ? min(col("timestamp"))
+ : max(col("timestamp"));
+ return new Column[] {agg.as("timestamp")};
+ }
+
+ private static Column windowTimestamp(TimestampCombiner tsCombiner) {
+ if (tsCombiner.equals(TimestampCombiner.END_OF_WINDOW)) {
+ // null will be set to END_OF_WINDOW by the respective deserializer
+ return litNull(DataTypes.LongType);
+ }
+ return col("timestamp");
+ }
+
+ /**
+ * Java {@link Iterable} from Scala {@link Iterator} that can be iterated just once so that we
+ * don't have to load all data into memory.
+ */
+ private static <T> Iterable<T> iterableOnce(Iterator<T> it) {
+ return () -> {
+ checkState(!it.isEmpty(), "Iterator on values can only be consumed once!");
+ return javaIterator(it);
+ };
+ }
+
+ private <T> TypedColumn<?, KV<K, T>> keyValue(TypedColumn<?, K> key, TypedColumn<?, T> value) {
+ return struct(key.as("key"), value.as("value")).as(kvEncoder(key.encoder(), value.encoder()));
+ }
+
+ private static <T> TypedColumn<Object, WindowedValue<T>> inGlobalWindow(
+ TypedColumn<?, T> value, Column ts) {
+ List<Column> fields = concat(timestampedValue(value, ts), GLOBAL_WINDOW_DETAILS);
+ Encoder<WindowedValue<T>> enc =
+ windowedValueEncoder(value.encoder(), encoderOf(GlobalWindow.class));
+ return (TypedColumn<Object, WindowedValue<T>>)
+ struct(JavaConverters.asJavaCollection(fields).toArray(new Column[0])).as(enc);
+ }
+
+ public static <T> TypedColumn<Object, WindowedValue<T>> inSingleWindow(
+ TypedColumn<?, T> value, TypedColumn<?, ? extends BoundedWindow> window, Column ts) {
+ Column windows = org.apache.spark.sql.functions.array(window);
+ List<Column> fields = concat(timestampedValue(value, ts), windowDetails(windows));
+ Encoder<WindowedValue<T>> enc = windowedValueEncoder(value.encoder(), window.encoder());
+ return (TypedColumn<Object, WindowedValue<T>>)
+ struct(JavaConverters.asJavaCollection(fields).toArray(new Column[0])).as(enc);
+ }
+
+ private static List<Column> timestampedValue(Column value, Column ts) {
+ return seqOf(value.as("value"), ts.as("timestamp")).toList();
+ }
+
+ private static List<Column> windowDetails(Column windows) {
+ return seqOf(windows.as("windows"), PANE_NO_FIRING.as("paneInfo")).toList();
+ }
+
+ private static <T> Column lit(T t) {
+ return org.apache.spark.sql.functions.lit(t);
+ }
+
+ @SuppressWarnings("nullness") // NULL literal
+ private static Column litNull(DataType dataType) {
+ return org.apache.spark.sql.functions.lit(null).cast(dataType);
+ }
+}
diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
new file mode 100644
index 000000000000..6565c2a01c63
--- /dev/null
+++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.replace;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf;
+
+import java.lang.reflect.Constructor;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal;
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder;
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders;
+import org.apache.spark.sql.catalyst.encoders.AgnosticExpressionPathEncoder;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.catalyst.expressions.BoundReference;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke;
+import org.apache.spark.sql.catalyst.expressions.objects.NewInstance;
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import scala.Option;
+import scala.collection.Iterator;
+import scala.collection.immutable.Seq;
+import scala.reflect.ClassTag;
+
+public class EncoderFactory {
+ // Resolve the Scala case-class primary constructor (the one with the most parameters).
+ // Constructor ordering returned by Class.getConstructors() is JVM-defined and not stable
+ // across Spark versions, so we pick the widest constructor explicitly and then dispatch on
+ // parameter count below to pick the right argument shape per Spark version.
+ private static final Constructor<StaticInvoke> STATIC_INVOKE_CONSTRUCTOR =
+ primaryConstructor(StaticInvoke.class);
+
+ private static final Constructor<Invoke> INVOKE_CONSTRUCTOR = primaryConstructor(Invoke.class);
+
+ private static final Constructor<NewInstance> NEW_INSTANCE_CONSTRUCTOR =
+ primaryConstructor(NewInstance.class);
+
+ @SuppressWarnings("unchecked")
+ private static <T> Constructor<T> primaryConstructor(Class<T> cls) {
+ Constructor<?>[] ctors = cls.getConstructors();
+ Constructor<?> widest = ctors[0];
+ for (int i = 1; i < ctors.length; i++) {
+ if (ctors[i].getParameterCount() > widest.getParameterCount()) {
+ widest = ctors[i];
+ }
+ }
+ return (Constructor<T>) widest;
+ }
+
+ @SuppressWarnings({"nullness", "unchecked"})
+ static <T> ExpressionEncoder<T> create(
+ Expression serializer, Expression deserializer, Class<? super T> clazz) {
+ AgnosticEncoder<T> agnosticEncoder = new BeamAgnosticEncoder<>(serializer, deserializer, clazz);
+ return ExpressionEncoder.apply(agnosticEncoder, serializer, deserializer);
+ }
+
+ /**
+ * An {@link AgnosticEncoder} that implements both {@link AgnosticExpressionPathEncoder} (so that
+ * {@code SerializerBuildHelper} / {@code DeserializerBuildHelper} delegate to our pre-built
+ * expressions) and {@link AgnosticEncoders.StructEncoder} (so that {@code
+ * Dataset.select(TypedColumn)} creates an N-attribute plan instead of a 1-attribute wrapped plan,
+ * preventing {@code FIELD_NUMBER_MISMATCH} errors).
+ *
+ * <p>The {@code toCatalyst} / {@code fromCatalyst} methods substitute the {@code input}
+ * expression into the pre-built serializer / deserializer via {@code transformUp}, so that when
+ * this encoder is nested inside a composite encoder (e.g. {@code Encoders.tuple}) the correct
+ * field-level expression is used in place of the root {@code BoundReference} / {@code
+ * GetColumnByOrdinal}.
+ */
+ @SuppressWarnings({"nullness", "unchecked", "deprecation"})
+ private static final class BeamAgnosticEncoder
+ implements AgnosticExpressionPathEncoder, AgnosticEncoders.StructEncoder {
+
+ private final Expression serializer;
+ private final Expression deserializer;
+ private final Class<? super T> clazz;
+ private final Seq<AgnosticEncoders.EncoderField> encoderFields;
+
+ BeamAgnosticEncoder(Expression serializer, Expression deserializer, Class<? super T> clazz) {
+ this.serializer = serializer;
+ this.deserializer = deserializer;
+ this.clazz = clazz;
+ this.encoderFields = buildFields(serializer.dataType());
+ }
+
+ private static Seq<AgnosticEncoders.EncoderField> buildFields(DataType dt) {
+ if (dt instanceof StructType) {
+ StructField[] structFields = ((StructType) dt).fields();
+ List<AgnosticEncoders.EncoderField> fields = new ArrayList<>(structFields.length);
+ for (StructField sf : structFields) {
+ fields.add(
+ new AgnosticEncoders.EncoderField(
+ sf.name(),
+ new FieldEncoder<>(sf.dataType(), sf.nullable()),
+ sf.nullable(),
+ sf.metadata(),
+ Option.empty(),
+ Option.empty()));
+ }
+ return seqOf(fields.toArray(new AgnosticEncoders.EncoderField[0]));
+ } else {
+ // Non-struct: wrap in a single "value" field so StructEncoder sees one field.
+ return seqOf(
+ new AgnosticEncoders.EncoderField(
+ "value",
+ new FieldEncoder<>(dt, true),
+ true,
+ Metadata.empty(),
+ Option.empty(),
+ Option.empty()));
+ }
+ }
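+
+ // Illustration: for a serializer of type struct<key: string, value: bigint>, buildFields
+ // yields two EncoderFields named "key" and "value", each wrapping a FieldEncoder that
+ // carries only that field's DataType and nullability.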
+
+ // --- AgnosticExpressionPathEncoder ---
+
+ @Override
+ public Expression toCatalyst(Expression input) {
+ return serializer.transformUp(replace(BoundReference.class, input));
+ }
+
+ @Override
+ public Expression fromCatalyst(Expression input) {
+ return deserializer.transformUp(replace(GetColumnByOrdinal.class, input));
+ }
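+
+ // Illustration (hypothetical nesting): wrapped in a composite such as Encoders.tuple,
+ // Spark may call toCatalyst with input = GetStructField(ref, 0); transformUp then swaps
+ // the root BoundReference of the pre-built serializer for that expression, so it reads
+ // the tuple field rather than the row root.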
+
+ // --- AgnosticEncoders.StructEncoder ---
+
+ @Override
+ public Seq<AgnosticEncoders.EncoderField> fields() {
+ return encoderFields;
+ }
+
+ @Override
+ public boolean isStruct() {
+ return true;
+ }
+
+ /**
+ * Setter required by the Scala compiler when implementing the {@link
+ * AgnosticEncoders.StructEncoder} trait from Java. Scala traits with concrete {@code val}
+ * fields generate a synthetic mangled setter ({@code <trait>$_setter_$<field>_$eq}) that the
+ * trait's initializer invokes on subclasses. Java cannot declare {@code val} fields, so we
+ * implement {@link #isStruct()} directly above and accept-but-ignore the trait setter here. The
+ * mangled name is brittle and tied to Spark's Scala source layout — if Spark removes the {@code
+ * isStruct} field from {@code StructEncoder}, this method becomes dead code; if Spark renames
+ * it, compilation will fail and the new mangled name must be substituted.
+ */
+ @Override
+ public void
+ org$apache$spark$sql$catalyst$encoders$AgnosticEncoders$StructEncoder$_setter_$isStruct_$eq(
+ boolean v) {
+ // no-op: isStruct() is implemented directly above
+ }
+
+ // --- AgnosticEncoder / Encoder (explicit to resolve default-method ambiguity) ---
+
+ @Override
+ public boolean isPrimitive() {
+ return false;
+ }
+
+ @Override
+ public StructType schema() {
+ // Build StructType from fields — mirrors the StructEncoder.schema() default.
+ List<StructField> sfs = new ArrayList<>(encoderFields.size());
+ Iterator<AgnosticEncoders.EncoderField> it = encoderFields.iterator();
+ while (it.hasNext()) {
+ sfs.add(it.next().structField());
+ }
+ return new StructType(sfs.toArray(new StructField[0]));
+ }
+
+ @Override
+ public DataType dataType() {
+ return schema();
+ }
+
+ @Override
+ public ClassTag<T> clsTag() {
+ return (ClassTag<T>) ClassTag.apply(clazz);
+ }
+ }
+
+ /**
+ * Minimal {@link AgnosticEncoder} stub used to carry per-field {@link DataType} metadata inside
+ * {@link AgnosticEncoders.EncoderField}. The actual serialization / deserialization is handled by
+ * {@link BeamAgnosticEncoder#toCatalyst} and {@link BeamAgnosticEncoder#fromCatalyst}.
+ */
+ @SuppressWarnings({"nullness", "unchecked"})
+ private static final class FieldEncoder<F> implements AgnosticEncoder<F> {
+ private final DataType fieldDataType;
+ private final boolean fieldNullable;
+
+ FieldEncoder(DataType dataType, boolean nullable) {
+ this.fieldDataType = dataType;
+ this.fieldNullable = nullable;
+ }
+
+ @Override
+ public boolean isPrimitive() {
+ return false;
+ }
+
+ @Override
+ public DataType dataType() {
+ return fieldDataType;
+ }
+
+ @Override
+ public StructType schema() {
+ return new StructType().add("value", fieldDataType, fieldNullable);
+ }
+
+ @Override
+ public boolean nullable() {
+ return fieldNullable;
+ }
+
+ @Override
+ public ClassTag<F> clsTag() {
+ return (ClassTag<F>) ClassTag.apply(Object.class);
+ }
+ }
+
+ /**
+ * Invoke method {@code fun} on Class {@code cls}, immediately propagating {@code null} if any
+ * input arg is {@code null}.
+ */
+ static Expression invokeIfNotNull(Class<?> cls, String fun, DataType type, Expression... args) {
+ return invoke(cls, fun, type, true, args);
+ }
+
+ /** Invoke method {@code fun} on Class {@code cls}. */
+ static Expression invoke(Class<?> cls, String fun, DataType type, Expression... args) {
+ return invoke(cls, fun, type, false, args);
+ }
+
+ private static Expression invoke(
+ Class<?> cls, String fun, DataType type, boolean propagateNull, Expression... args) {
+ try {
+ // To address breaking interfaces between various versions of Spark, expressions are
+ // created reflectively. This is fine as it's just needed once to create the query plan.
+ switch (STATIC_INVOKE_CONSTRUCTOR.getParameterCount()) {
+ case 6:
+ // Spark 3.1.x
+ return STATIC_INVOKE_CONSTRUCTOR.newInstance(
+ cls, type, fun, seqOf(args), propagateNull, true);
+ case 7:
+ // Spark 3.2.0
+ return STATIC_INVOKE_CONSTRUCTOR.newInstance(
+ cls, type, fun, seqOf(args), emptyList(), propagateNull, true);
+ case 8:
+ // Spark 3.2.x, 3.3.x
+ return STATIC_INVOKE_CONSTRUCTOR.newInstance(
+ cls, type, fun, seqOf(args), emptyList(), propagateNull, true, true);
+ case 9:
+ // Spark 4.0.x: added Option<ScalarFunction<?>> parameter
+ return STATIC_INVOKE_CONSTRUCTOR.newInstance(
+ cls, type, fun, seqOf(args), emptyList(), propagateNull, true, true, Option.empty());
+ default:
+ throw new RuntimeException("Unsupported version of Spark");
+ }
+ } catch (IllegalArgumentException | ReflectiveOperationException ex) {
+ throw new RuntimeException(ex);
+ }
+ }
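+
+ // For example, EncoderHelpers.paneInfoEncoder uses this helper to build the catalyst
+ // equivalent of a null-propagating static call `Utils.paneInfoToBytes(value)`:
+ //   invokeIfNotNull(Utils.class, "paneInfoToBytes", BinaryType, valueExpr);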
+
+ /** Invoke method {@code fun} on {@code obj} with provided {@code args}. */
+ static Expression invoke(
+ Expression obj, String fun, DataType type, boolean nullable, Expression... args) {
+ try {
+ // To address breaking interfaces between various versions of Spark, expressions are
+ // created reflectively. This is fine as it's just needed once to create the query plan.
+ switch (INVOKE_CONSTRUCTOR.getParameterCount()) {
+ case 6:
+ // Spark 3.1.x
+ return INVOKE_CONSTRUCTOR.newInstance(obj, fun, type, seqOf(args), false, nullable);
+ case 7:
+ // Spark 3.2.0
+ return INVOKE_CONSTRUCTOR.newInstance(
+ obj, fun, type, seqOf(args), emptyList(), false, nullable);
+ case 8:
+ // Spark 3.2.x, 3.3.x, 4.0.x: Invoke constructor is 8 params across all these versions
+ return INVOKE_CONSTRUCTOR.newInstance(
+ obj, fun, type, seqOf(args), emptyList(), false, nullable, true);
+ default:
+ throw new RuntimeException("Unsupported version of Spark");
+ }
+ } catch (IllegalArgumentException | ReflectiveOperationException ex) {
+ throw new RuntimeException(ex);
+ }
+ }
+
+ static Expression newInstance(Class<?> cls, DataType type, Expression... args) {
+ try {
+ // To address breaking interfaces between various versions of Spark, expressions are
+ // created reflectively. This is fine as it's just needed once to create the query plan.
+ switch (NEW_INSTANCE_CONSTRUCTOR.getParameterCount()) {
+ case 5:
+ return NEW_INSTANCE_CONSTRUCTOR.newInstance(cls, seqOf(args), true, type, Option.empty());
+ case 6:
+ // Spark 3.2.x, 3.3.x, 4.0.x: added immutable.Seq parameter
+ return NEW_INSTANCE_CONSTRUCTOR.newInstance(
+ cls, seqOf(args), emptyList(), true, type, Option.empty());
+ default:
+ throw new RuntimeException("Unsupported version of Spark");
+ }
+ } catch (IllegalArgumentException | ReflectiveOperationException ex) {
+ throw new RuntimeException(ex);
+ }
+ }
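+
+ // Illustrative (hypothetical arguments): newInstance(IntervalWindow.class, windowType,
+ // startExpr, endExpr) would build the catalyst equivalent of `new IntervalWindow(start, end)`
+ // when deserializing a window from its struct representation.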
+}
diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
new file mode 100644
index 000000000000..173b4653a19b
--- /dev/null
+++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
@@ -0,0 +1,617 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.invoke;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.invokeIfNotNull;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.match;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.replace;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;
+import static org.apache.spark.sql.types.DataTypes.BinaryType;
+import static org.apache.spark.sql.types.DataTypes.IntegerType;
+import static org.apache.spark.sql.types.DataTypes.LongType;
+
+import java.math.BigDecimal;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.WindowedValue;
+import org.apache.beam.sdk.values.WindowedValues;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.catalyst.SerializerBuildHelper;
+import org.apache.spark.sql.catalyst.SerializerBuildHelper.MapElementInformation;
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal;
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.catalyst.expressions.BoundReference;
+import org.apache.spark.sql.catalyst.expressions.Coalesce;
+import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct;
+import org.apache.spark.sql.catalyst.expressions.EqualTo;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.GetStructField;
+import org.apache.spark.sql.catalyst.expressions.If;
+import org.apache.spark.sql.catalyst.expressions.IsNotNull;
+import org.apache.spark.sql.catalyst.expressions.IsNull;
+import org.apache.spark.sql.catalyst.expressions.Literal;
+import org.apache.spark.sql.catalyst.expressions.Literal$;
+import org.apache.spark.sql.catalyst.expressions.MapKeys;
+import org.apache.spark.sql.catalyst.expressions.MapValues;
+import org.apache.spark.sql.catalyst.expressions.objects.MapObjects$;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.types.ArrayType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.MapType;
+import org.apache.spark.sql.types.ObjectType;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.util.MutablePair;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Instant;
+import scala.Option;
+import scala.Some;
+import scala.Tuple2;
+import scala.collection.IndexedSeq;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
+
+/** {@link Encoders} utility class. */
+public class EncoderHelpers {
+ private static final DataType OBJECT_TYPE = new ObjectType(Object.class);
+ private static final DataType TUPLE2_TYPE = new ObjectType(Tuple2.class);
+ private static final DataType WINDOWED_VALUE = new ObjectType(WindowedValue.class);
+ private static final DataType KV_TYPE = new ObjectType(KV.class);
+ private static final DataType MUTABLE_PAIR_TYPE = new ObjectType(MutablePair.class);
+ private static final DataType LIST_TYPE = new ObjectType(List.class);
+
+ // Collections / maps of these types can be (de)serialized without (de)serializing each member
+ private static final Set<Class<?>> PRIMITIVE_TYPES =
+ ImmutableSet.of(
+ Boolean.class,
+ Byte.class,
+ Short.class,
+ Integer.class,
+ Long.class,
+ Float.class,
+ Double.class);
+
+ // Default encoders by class
+ private static final Map<Class<?>, Encoder<?>> DEFAULT_ENCODERS = new ConcurrentHashMap<>();
+
+ // Factory for default encoders by class
+ private static @Nullable Encoder<?> encoderFactory(Class<?> cls) {
+ if (cls.equals(PaneInfo.class)) {
+ return paneInfoEncoder();
+ } else if (cls.equals(GlobalWindow.class)) {
+ return binaryEncoder(GlobalWindow.Coder.INSTANCE, false);
+ } else if (cls.equals(IntervalWindow.class)) {
+ return binaryEncoder(IntervalWindowCoder.of(), false);
+ } else if (cls.equals(Instant.class)) {
+ return instantEncoder();
+ } else if (cls.equals(String.class)) {
+ return Encoders.STRING();
+ } else if (cls.equals(Boolean.class)) {
+ return Encoders.BOOLEAN();
+ } else if (cls.equals(Integer.class)) {
+ return Encoders.INT();
+ } else if (cls.equals(Long.class)) {
+ return Encoders.LONG();
+ } else if (cls.equals(Float.class)) {
+ return Encoders.FLOAT();
+ } else if (cls.equals(Double.class)) {
+ return Encoders.DOUBLE();
+ } else if (cls.equals(BigDecimal.class)) {
+ return Encoders.DECIMAL();
+ } else if (cls.equals(byte[].class)) {
+ return Encoders.BINARY();
+ } else if (cls.equals(Byte.class)) {
+ return Encoders.BYTE();
+ } else if (cls.equals(Short.class)) {
+ return Encoders.SHORT();
+ }
+ return null;
+ }
+
+ @SuppressWarnings({"nullness", "methodref.return"}) // computeIfAbsent allows null returns
+ private static <T> @Nullable Encoder<T> getOrCreateDefaultEncoder(Class<? super T> cls) {
+ return (Encoder<T>) DEFAULT_ENCODERS.computeIfAbsent(cls, EncoderHelpers::encoderFactory);
+ }
+
+ /** Gets or creates a default {@link Encoder} for {@link T}. */
+ public static <T> Encoder<T> encoderOf(Class<? super T> cls) {
+ Encoder<T> enc = getOrCreateDefaultEncoder(cls);
+ if (enc == null) {
+ throw new IllegalArgumentException("No default encoder available for class " + cls);
+ }
+ return enc;
+ }
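+
+ // Usage sketch:
+ //   Encoder<Instant> ts = encoderOf(Instant.class);     // long-backed Joda Instant encoder
+ //   Encoder<PaneInfo> pane = encoderOf(PaneInfo.class); // binary-backed encoder
+ // Classes without a default encoder throw IllegalArgumentException; encoderFor(Coder)
+ // below is the general-purpose fallback.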
+
+ /**
+ * Creates a Spark {@link Encoder} for {@link T} of {@link DataTypes#BinaryType BinaryType}
+ * delegating to a Beam {@link Coder} underneath.
+ *
+ * <p>Note: For common types, if available, default Spark {@link Encoder}s are used instead.
+ *
+ * @param coder Beam {@link Coder}
+ */
+ public static <T> Encoder<T> encoderFor(Coder<T> coder) {
+ Encoder<T> enc = getOrCreateDefaultEncoder(coder.getEncodedTypeDescriptor().getRawType());
+ return enc != null ? enc : binaryEncoder(coder, true);
+ }
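+
+ // Usage sketch (StringUtf8Coder is Beam's stock string coder; MyType is hypothetical):
+ //   encoderFor(StringUtf8Coder.of());               // raw type String -> Encoders.STRING()
+ //   encoderFor(SerializableCoder.of(MyType.class)); // no default -> binary encoder via coder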
+
+ /**
+ * Creates a Spark {@link Encoder} for {@link T} of {@link StructType} with fields {@code value},
+ * {@code timestamp}, {@code windows} and {@code paneInfo}.
+ *
+ * @param value {@link Encoder} to encode field `{@code value}`.
+ * @param window {@link Encoder} to encode individual windows in field `{@code windows}`
+ */
+ public static <T, W extends BoundedWindow> Encoder<WindowedValue<T>> windowedValueEncoder(
+ Encoder<T> value, Encoder<W> window) {
+ Encoder<Instant> timestamp = encoderOf(Instant.class);
+ Encoder<PaneInfo> paneInfo = encoderOf(PaneInfo.class);
+ Encoder<Collection<W>> windows = collectionEncoder(window);
+ Expression serializer =
+ serializeWindowedValue(rootRef(WINDOWED_VALUE, true), value, timestamp, windows, paneInfo);
+ Expression deserializer =
+ deserializeWindowedValue(
+ rootCol(serializer.dataType()), value, timestamp, windows, paneInfo);
+ return EncoderFactory.create(serializer, deserializer, WindowedValue.class);
+ }
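+
+ // Illustration: with a binary value encoder and GlobalWindow windows, the resulting schema
+ // is struct<value: binary, timestamp: bigint, windows: array<binary>, paneInfo: binary>.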
+
+ /**
+ * Creates a one-of Spark {@link Encoder} of {@link StructType} where each alternative is
+ * represented as a column / field named by its index, each with a separate {@link Encoder}.
+ *
+ * <p>Externally this is represented as the tuple {@code (index, data)}, where the index
+ * corresponds to an {@link Encoder} in the provided list.
+ *
+ * @param encoders {@link Encoder}s for each alternative.
+ */
+ public static <T> Encoder<Tuple2<Integer, T>> oneOfEncoder(List<Encoder<T>> encoders) {
+ Expression serializer = serializeOneOf(rootRef(TUPLE2_TYPE, true), encoders);
+ Expression deserializer = deserializeOneOf(rootCol(serializer.dataType()), encoders);
+ return EncoderFactory.create(serializer, deserializer, Tuple2.class);
+ }
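+
+ // Illustration: for three alternatives the schema is struct<0: ..., 1: ..., 2: ...>; a
+ // value Tuple2.apply(1, data) populates only column "1" (the others stay null), and the
+ // deserializer coalesces across columns to recover the single populated alternative.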
+
+ /**
+ * Creates a Spark {@link Encoder} for {@link KV} of {@link StructType} with fields {@code key}
+ * and {@code value}.
+ *
+ * @param key {@link Encoder} to encode field `{@code key}`.
+ * @param value {@link Encoder} to encode field `{@code value}`
+ */
+ public static <K, V> Encoder<KV<K, V>> kvEncoder(Encoder<K> key, Encoder<V> value) {
+ Expression serializer = serializeKV(rootRef(KV_TYPE, true), key, value);
+ Expression deserializer = deserializeKV(rootCol(serializer.dataType()), key, value);
+ return EncoderFactory.create(serializer, deserializer, KV.class);
+ }
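+
+ // Usage sketch:
+ //   Encoder<KV<String, Long>> enc =
+ //       kvEncoder(encoderOf(String.class), encoderOf(Long.class));
+ //   // resulting schema: struct<key: string, value: bigint>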
+
+ /**
+ * Creates a Spark {@link Encoder} of {@link ArrayType} for Java {@link Collection}s with nullable
+ * elements.
+ *
+ * @param enc {@link Encoder} to encode collection elements
+ */
+ public static <T> Encoder<Collection<T>> collectionEncoder(Encoder<T> enc) {
+ return collectionEncoder(enc, true);
+ }
+
+ /**
+ * Creates a Spark {@link Encoder} of {@link ArrayType} for Java {@link Collection}s.
+ *
+ * @param enc {@link Encoder} to encode collection elements
+ * @param nullable Allow nullable collection elements
+ */
+ public static <T> Encoder<Collection<T>> collectionEncoder(Encoder<T> enc, boolean nullable) {
+ DataType type = new ObjectType(Collection.class);
+ Expression serializer = serializeSeq(rootRef(type, true), enc, nullable);
+ Expression deserializer = deserializeSeq(rootCol(serializer.dataType()), enc, nullable, true);
+ return EncoderFactory.create(serializer, deserializer, Collection.class);
+ }
+
+ /**
+ * Creates a Spark {@link Encoder} of {@link MapType} that deserializes to {@code MapT}.
+ *
+ * @param key {@link Encoder} to encode keys
+ * @param value {@link Encoder} to encode values
+ * @param cls Specific class to use; supported classes are {@link HashMap} and {@link TreeMap}
+ */
+ public static <MapT extends Map<K, V>, K, V> Encoder<MapT> mapEncoder(
+ Encoder<K> key, Encoder<V> value, Class<MapT> cls) {
+ Expression serializer = mapSerializer(rootRef(new ObjectType(cls), true), key, value);
+ Expression deserializer = mapDeserializer(rootCol(serializer.dataType()), key, value, cls);
+ return EncoderFactory.create(serializer, deserializer, cls);
+ }
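+
+ // Usage sketch (illustrative; exact generics elided):
+ //   mapEncoder(encoderOf(String.class), encoderOf(Long.class), TreeMap.class);
+ //   // schema: map<string, bigint>; deserialization rebuilds the requested TreeMap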
+
+ /**
+ * Creates a Spark {@link Encoder} for Spark's {@link MutablePair} of {@link StructType} with
+ * fields `{@code _1}` and `{@code _2}`.
+ *
+ * <p>This is intended to be used in places such as aggregators.
+ *
+ * @param enc1 {@link Encoder} to encode `{@code _1}`
+ * @param enc2 {@link Encoder} to encode `{@code _2}`
+ */
+ public static <T1, T2> Encoder<MutablePair<T1, T2>> mutablePairEncoder(
+ Encoder<T1> enc1, Encoder<T2> enc2) {
+ Expression serializer = serializeMutablePair(rootRef(MUTABLE_PAIR_TYPE, true), enc1, enc2);
+ Expression deserializer = deserializeMutablePair(rootCol(serializer.dataType()), enc1, enc2);
+ return EncoderFactory.create(serializer, deserializer, MutablePair.class);
+ }
+
+ /**
+ * Creates a Spark {@link Encoder} for {@link PaneInfo} of {@link DataTypes#BinaryType
+ * BinaryType}.
+ */
+ private static Encoder<PaneInfo> paneInfoEncoder() {
+ DataType type = new ObjectType(PaneInfo.class);
+ return EncoderFactory.create(
+ invokeIfNotNull(Utils.class, "paneInfoToBytes", BinaryType, rootRef(type, false)),
+ invokeIfNotNull(Utils.class, "paneInfoFromBytes", type, rootCol(BinaryType)),
+ PaneInfo.class);
+ }
+
+ /**
+ * Creates a Spark {@link Encoder} for Joda {@link Instant} of {@link DataTypes#LongType
+ * LongType}.
+ */
+ private static Encoder<Instant> instantEncoder() {
+ DataType type = new ObjectType(Instant.class);
+ Expression instant = rootRef(type, true);
+ Expression millis = rootCol(LongType);
+ return EncoderFactory.create(
+ nullSafe(instant, invoke(instant, "getMillis", LongType, false)),
+ nullSafe(millis, invoke(Instant.class, "ofEpochMilli", type, millis)),
+ Instant.class);
+ }
+
+ /**
+ * Creates a Spark {@link Encoder} for {@link T} of {@link DataTypes#BinaryType BinaryType}
+ * delegating to a Beam {@link Coder} underneath.
+ *
+ * @param coder Beam {@link Coder}
+ * @param nullable Whether to allow nullable items
+ */
+ private static <T> Encoder<T> binaryEncoder(Coder<T> coder, boolean nullable) {
+ Literal litCoder = lit(coder, Coder.class);
+ // T could be private, use OBJECT_TYPE for code generation to not risk an IllegalAccessError
+ return EncoderFactory.create(
+ invokeIfNotNull(
+ CoderHelpers.class,
+ "toByteArray",
+ BinaryType,
+ rootRef(OBJECT_TYPE, nullable),
+ litCoder),
+ invokeIfNotNull(
+ CoderHelpers.class, "fromByteArray", OBJECT_TYPE, rootCol(BinaryType), litCoder),
+ coder.getEncodedTypeDescriptor().getRawType());
+ }
+
+ private static <T> Expression serializeWindowedValue(
+ Expression in,
+ Encoder<T> valueEnc,
+ Encoder<Instant> timestampEnc,
+ Encoder<? extends Collection<? extends BoundedWindow>> windowsEnc,
+ Encoder<PaneInfo> paneEnc) {
+ return serializerObject(
+ in,
+ tuple("value", serializeField(in, valueEnc, "getValue")),
+ tuple("timestamp", serializeField(in, timestampEnc, "getTimestamp")),
+ tuple("windows", serializeField(in, windowsEnc, "getWindows")),
+ tuple("paneInfo", serializeField(in, paneEnc, "getPaneInfo")));
+ }
+
+ private static Expression serializerObject(Expression in, Tuple2<String, Expression>... fields) {
+ return SerializerBuildHelper.createSerializerForObject(in, seqOf(fields));
+ }
+
+ private static <T> Expression deserializeWindowedValue(
+ Expression in,
+ Encoder<T> valueEnc,
+ Encoder<Instant> timestampEnc,
+ Encoder<? extends Collection<? extends BoundedWindow>> windowsEnc,
+ Encoder<PaneInfo> paneEnc) {
+ Expression value = deserializeField(in, valueEnc, 0, "value");
+ Expression windows = deserializeField(in, windowsEnc, 2, "windows");
+ Expression timestamp = deserializeField(in, timestampEnc, 1, "timestamp");
+ Expression paneInfo = deserializeField(in, paneEnc, 3, "paneInfo");
+ // set timestamp to end of window (maxTimestamp) if null
+ timestamp =
+ ifNotNull(timestamp, invoke(Utils.class, "maxTimestamp", timestamp.dataType(), windows));
+ Expression[] fields = new Expression[] {value, timestamp, windows, paneInfo};
+
+ return nullSafe(paneInfo, invoke(WindowedValues.class, "of", WINDOWED_VALUE, fields));
+ }
+
+ private static <T1, T2> Expression serializeMutablePair(
+ Expression in, Encoder<T1> enc1, Encoder<T2> enc2) {
+ return serializerObject(
+ in,
+ tuple("_1", serializeField(in, enc1, "_1")),
+ tuple("_2", serializeField(in, enc2, "_2")));
+ }
+
+ private static <T1, T2> Expression deserializeMutablePair(
+ Expression in, Encoder<T1> enc1, Encoder<T2> enc2) {
+ Expression field1 = deserializeField(in, enc1, 0, "_1");
+ Expression field2 = deserializeField(in, enc2, 1, "_2");
+ return invoke(MutablePair.class, "apply", MUTABLE_PAIR_TYPE, field1, field2);
+ }
+
+ private static <K, V> Expression serializeKV(
+ Expression in, Encoder<K> keyEnc, Encoder<V> valueEnc) {
+ return serializerObject(
+ in,
+ tuple("key", serializeField(in, keyEnc, "getKey")),
+ tuple("value", serializeField(in, valueEnc, "getValue")));
+ }
+
+ private static <K, V> Expression deserializeKV(
+ Expression in, Encoder<K> keyEnc, Encoder<V> valueEnc) {
+ Expression key = deserializeField(in, keyEnc, 0, "key");
+ Expression value = deserializeField(in, valueEnc, 1, "value");
+ return invoke(KV.class, "of", KV_TYPE, key, value);
+ }
+
+ public static <T> Expression serializeOneOf(Expression in, List<Encoder<T>> encoders) {
+ Expression type = invoke(in, "_1", IntegerType, false);
+ Expression[] args = new Expression[encoders.size() * 2];
+ for (int i = 0; i < encoders.size(); i++) {
+ args[i * 2] = lit(String.valueOf(i));
+ args[i * 2 + 1] = serializeOneOfField(in, type, encoders.get(i), i);
+ }
+ return new CreateNamedStruct(seqOf(args));
+ }
+
+ public static <T> Expression deserializeOneOf(Expression in, List<Encoder<T>> encoders) {
+ Expression[] args = new Expression[encoders.size()];
+ for (int i = 0; i < encoders.size(); i++) {
+ args[i] = deserializeOneOfField(in, encoders.get(i), i);
+ }
+ return new Coalesce(seqOf(args));
+ }
+
+ private static