diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.json new file mode 100644 index 000000000000..c4edaa85a89d --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.json @@ -0,0 +1,3 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite to run" +} diff --git a/.github/trigger_files/beam_PreCommit_Java_Spark4_Versions.json b/.github/trigger_files/beam_PreCommit_Java_Spark4_Versions.json new file mode 100644 index 000000000000..c4edaa85a89d --- /dev/null +++ b/.github/trigger_files/beam_PreCommit_Java_Spark4_Versions.json @@ -0,0 +1,3 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite to run" +} diff --git a/.github/workflows/README.md b/.github/workflows/README.md index f8c77b7180db..b095bc2a2447 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -373,6 +373,7 @@ PostCommit Jobs run in a schedule against master branch and generally do not get | [ PostCommit Java ValidatesRunner Spark Java8 ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml) | N/A |`beam_PostCommit_Java_ValidatesRunner_Spark_Java8.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml?query=event%3Aschedule) | | [ PostCommit Java ValidatesRunner Spark ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml) | N/A |`beam_PostCommit_Java_ValidatesRunner_Spark.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml?query=event%3Aschedule) | | [ PostCommit Java ValidatesRunner SparkStructuredStreaming ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml) | N/A |`beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml?query=event%3Aschedule) | +| [ PostCommit Java ValidatesRunner Spark4StructuredStreaming ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml) | N/A |`beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml?query=event%3Aschedule) | | [ PostCommit Java ValidatesRunner Twister2 
](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml) | N/A |`beam_PostCommit_Java_ValidatesRunner_Twister2.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml?query=event%3Aschedule) | | [ PostCommit Java ValidatesRunner ULR ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml) | N/A |`beam_PostCommit_Java_ValidatesRunner_ULR.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml?query=event%3Aschedule) | | [ PostCommit Java ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java.yml) | N/A |`beam_PostCommit_Java.json`| [![.github/workflows/beam_PostCommit_Java.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java.yml?query=event%3Aschedule) | diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml new file mode 100644 index 000000000000..b595afe6f42c --- /dev/null +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
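+# Note: in addition to the schedule, this suite is wired to the PR comment phrase
+# 'Run Spark4 StructuredStreaming ValidatesRunner' (see job_phrase and the if condition below)
+# and to trivial edits of its trigger file,
+# .github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.json,
+# via the pull_request_target paths filter.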
+ +name: PostCommit Java ValidatesRunner Spark4 StructuredStreaming + +on: + schedule: + - cron: '45 4/6 * * *' + pull_request_target: + paths: ['release/trigger_all_tests.json', '.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.json'] + workflow_dispatch: + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: write + checks: write + contents: read + deployments: read + id-token: none + issues: write + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.number || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' + cancel-in-progress: true + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.DEVELOCITY_ACCESS_KEY }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + +jobs: + beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming: + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + runs-on: [self-hosted, ubuntu-24.04, main] + timeout-minutes: 120 + strategy: + matrix: + job_name: [beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming] + job_phrase: [Run Spark4 StructuredStreaming ValidatesRunner] + if: | + github.event_name == 'workflow_dispatch' || + github.event_name == 'pull_request_target' || + (github.event_name == 'schedule' && github.repository == 'apache/beam') || + github.event.comment.body == 'Run Spark4 StructuredStreaming ValidatesRunner' + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + - name: Setup environment + uses: ./.github/actions/setup-environment-action + with: + java-version: '17' + - name: run validatesStructuredStreamingRunnerBatch script + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :runners:spark:4:validatesStructuredStreamingRunnerBatch + arguments: | + -PtestJavaVersion=17 \ + -PdisableSpotlessCheck=true \ + - name: Archive JUnit Test Results + uses: actions/upload-artifact@v4 + if: ${{ !success() }} + with: + name: JUnit Test Results + path: "**/build/reports/tests/" + - name: Publish JUnit Test Results + uses: EnricoMi/publish-unit-test-result-action@v2 + if: always() + with: + commit: '${{ env.prsha || env.GITHUB_SHA }}' + comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} + files: '**/build/test-results/**/*.xml' + large_files: true diff --git a/CHANGES.md b/CHANGES.md index e8e11e830d14..f092c81f92fb 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -60,7 +60,6 @@ ## Highlights * New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)). -* New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)). 
## I/Os diff --git a/gradle.properties b/gradle.properties index 2cff0b656fa7..04c85957e4eb 100644 --- a/gradle.properties +++ b/gradle.properties @@ -41,6 +41,6 @@ docker_image_default_repo_prefix=beam_ # supported flink versions flink_versions=1.17,1.18,1.19,1.20,2.0 # supported spark versions -spark_versions=3 +spark_versions=3,4 # supported python versions python_versions=3.10,3.11,3.12,3.13,3.14 diff --git a/runners/spark/4/README.md b/runners/spark/4/README.md new file mode 100644 index 000000000000..371a44512657 --- /dev/null +++ b/runners/spark/4/README.md @@ -0,0 +1,73 @@ + +# Apache Beam Spark 4 Runner + +Experimental Beam runner for Apache Spark 4 (batch-only). Built on the shared +`runners/spark` source base via `spark_runner.gradle`'s per-version +source-overrides mechanism: this module contributes the small set of files +under `src/main/java/.../structuredstreaming/` that diverge from the Spark 3 +implementation. See the parent `runners/spark/` module for the bulk of the +runner code. + +## Requirements + +* **Spark 4.0.2** (and other Spark 4.0.x patch releases) +* **Scala 2.13** +* **Java 17** — Spark 4 does not run on earlier JDKs + +## Status + +Batch only. Streaming is tracked in +[#36841](https://github.com/apache/beam/issues/36841). + +## Known issues + +### `StackOverflowError` from `slf4j-jdk14` on the runtime classpath + +Spark 4 ships `org.slf4j:jul-to-slf4j` to route `java.util.logging` records +into SLF4J. If `org.slf4j:slf4j-jdk14` is also resolved at runtime — it routes +the other direction (SLF4J → JUL) — the first log line creates an infinite +loop: + +``` +java.lang.StackOverflowError + at org.slf4j.bridge.SLF4JBridgeHandler.publish(...) + at java.util.logging.Logger.log(...) + at org.slf4j.impl.JDK14LoggerAdapter.log(...) + at org.slf4j.bridge.SLF4JBridgeHandler.publish(...) + ... +``` + +This is the same condition that broke the Spark 3 runner in +[#26985](https://github.com/apache/beam/issues/26985), fixed in +[#27001](https://github.com/apache/beam/pull/27001). + +The shared `spark_runner.gradle` already excludes `slf4j-jdk14` from the +runner module's own `configurations.all`, so in-tree builds are unaffected. +Downstream Gradle consumers that assemble a runtime classpath against +`beam-runners-spark-4` should mirror that exclude: + +```groovy +configurations.all { + exclude group: "org.slf4j", module: "slf4j-jdk14" +} +``` + +For Maven, exclude `org.slf4j:slf4j-jdk14` from any dependency that pulls it +transitively (commonly the Beam SDK harness and several IO connectors). diff --git a/runners/spark/4/build.gradle b/runners/spark/4/build.gradle new file mode 100644 index 000000000000..01fb3680b078 --- /dev/null +++ b/runners/spark/4/build.gradle @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +def basePath = '..' +/* All properties required for loading the Spark build script */ +project.ext { + spark_major = '4' + // Spark 4 version as defined in BeamModulePlugin; requires Scala 2.13 and Java 17 + spark_version = spark4_version + spark_scala_version = '2.13' + archives_base_name = 'beam-runners-spark-4' +} + +// Load the main build script which contains all build logic. +// spark_runner.gradle handles the per-version source-overrides Copy: +// shared base (runners/spark/src/) + previous majors + this module's ./src/ are +// merged into build/source-overrides/src using DuplicatesStrategy.INCLUDE so the +// 11 files under runners/spark/4/src/.../structuredstreaming/ override the +// shared-base versions. +apply from: "$basePath/spark_runner.gradle" + +// Spark 4 always requires Java 17, so unconditionally add the --add-opens flags +// required by Kryo and other libraries that use reflection on JDK internals. +test { + jvmArgs "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED", + "--add-opens=java.base/java.nio=ALL-UNNAMED", + "--add-opens=java.base/java.util=ALL-UNNAMED", + "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED" +} + +// Exclude DStream-based streaming tests from the shared-base copy: the Spark 4 module +// supports only structured streaming (batch) and does not include legacy DStream support. +// Streaming test utilities also depend on kafka.server.KafkaServerStartable which was +// removed in Kafka 2.8.0 (the first Kafka version with a _2.13 artifact). +tasks.named("copyTestSourceOverrides") { + exclude "**/translation/streaming/**" +} + diff --git a/runners/spark/4/job-server/build.gradle b/runners/spark/4/job-server/build.gradle new file mode 100644 index 000000000000..598cf3b4913a --- /dev/null +++ b/runners/spark/4/job-server/build.gradle @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +def basePath = '../../job-server' + +project.ext { + // Look for the source code in the parent module + main_source_dirs = ["$basePath/src/main/java"] + test_source_dirs = ["$basePath/src/test/java"] + main_resources_dirs = ["$basePath/src/main/resources"] + test_resources_dirs = ["$basePath/src/test/resources"] + archives_base_name = 'beam-runners-spark-4-job-server' +} + +// Load the main build script which contains all build logic. 
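+// As in the Spark 3 job-server module, spark_job_server.gradle is expected to assemble these
+// shared sources into the runnable job-server artifact, with the per-version name taken from
+// archives_base_name above.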
+apply from: "$basePath/spark_job_server.gradle" diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java new file mode 100644 index 000000000000..d32dc14eccc0 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java @@ -0,0 +1,332 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.io; + +import static java.util.stream.Collectors.toList; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList; +import static org.apache.beam.sdk.values.WindowedValues.timestampedValueInGlobalWindow; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static scala.collection.JavaConverters.asScalaIterator; + +import java.io.Closeable; +import java.io.IOException; +import java.io.Serializable; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.IntSupplier; +import java.util.function.Supplier; +import javax.annotation.CheckForNull; +import javax.annotation.Nullable; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.spark.InterruptibleIterator; +import org.apache.spark.Partition; +import org.apache.spark.SparkContext; +import org.apache.spark.TaskContext; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.classic.Dataset$; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; 
+import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import scala.Option; +import scala.collection.Iterator; +import scala.reflect.ClassTag; + +public class BoundedDatasetFactory { + private BoundedDatasetFactory() {} + + /** + * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link Table}. + * + *
<p>
Unfortunately tables are expected to return an {@link InternalRow}, requiring serialization. + * This makes this approach at the time being significantly less performant than creating a + * dataset from an RDD. + */ + public static Dataset> createDatasetFromRows( + SparkSession session, + BoundedSource source, + Supplier options, + Encoder> encoder) { + Params params = new Params<>(encoder, options, session.sparkContext().defaultParallelism()); + BeamTable table = new BeamTable<>(source, params); + LogicalPlan logicalPlan = DataSourceV2Relation.create(table, Option.empty(), Option.empty()); + // In Spark 4.0+, Dataset$ moved to org.apache.spark.sql.classic and its ofRows() now + // takes the classic SparkSession subclass. The runtime instance returned by + // SparkSession.builder() is always a classic.SparkSession, so the downcast is safe and + // avoids reflection. + return (Dataset>) + Dataset$.MODULE$ + .ofRows((org.apache.spark.sql.classic.SparkSession) session, logicalPlan) + .as(encoder); + } + + /** + * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link RDD}. + * + *
<p>
This is currently the most efficient approach as it avoid any serialization overhead. + */ + public static Dataset> createDatasetFromRDD( + SparkSession session, + BoundedSource source, + Supplier options, + Encoder> encoder) { + Params params = new Params<>(encoder, options, session.sparkContext().defaultParallelism()); + RDD> rdd = new BoundedRDD<>(session.sparkContext(), source, params); + return session.createDataset(rdd, encoder); + } + + /** An {@link RDD} for a bounded Beam source. */ + private static class BoundedRDD extends RDD> { + final BoundedSource source; + final Params params; + + public BoundedRDD(SparkContext sc, BoundedSource source, Params params) { + super(sc, emptyList(), ClassTag.apply(WindowedValue.class)); + this.source = source; + this.params = params; + } + + @Override + public Iterator> compute(Partition split, TaskContext context) { + return new InterruptibleIterator<>( + context, + asScalaIterator(new SourcePartitionIterator<>((SourcePartition) split, params))); + } + + @Override + public Partition[] getPartitions() { + return SourcePartition.partitionsOf(source, params).toArray(new Partition[0]); + } + } + + /** A Spark {@link Table} for a bounded Beam source supporting batch reads only. */ + private static class BeamTable implements Table, SupportsRead { + final BoundedSource source; + final Params params; + + BeamTable(BoundedSource source, Params params) { + this.source = source; + this.params = params; + } + + public Encoder> getEncoder() { + return params.encoder; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap ignored) { + return () -> + new Scan() { + @Override + public StructType readSchema() { + return params.encoder.schema(); + } + + @Override + public Batch toBatch() { + return new BeamBatch<>(source, params); + } + }; + } + + @Override + public String name() { + return "BeamSource<" + source.getClass().getName() + ">"; + } + + @Override + public StructType schema() { + return params.encoder.schema(); + } + + @Override + public Set capabilities() { + return ImmutableSet.of(TableCapability.BATCH_READ); + } + + private static class BeamBatch implements Batch, Serializable { + final BoundedSource source; + final Params params; + + private BeamBatch(BoundedSource source, Params params) { + this.source = source; + this.params = params; + } + + @Override + public InputPartition[] planInputPartitions() { + return SourcePartition.partitionsOf(source, params).toArray(new InputPartition[0]); + } + + @Override + public PartitionReaderFactory createReaderFactory() { + return p -> new BeamPartitionReader<>(((SourcePartition) p), params); + } + } + + private static class BeamPartitionReader implements PartitionReader { + final SourcePartitionIterator iterator; + final Serializer> serializer; + transient @Nullable InternalRow next; + + BeamPartitionReader(SourcePartition partition, Params params) { + iterator = new SourcePartitionIterator<>(partition, params); + serializer = ((ExpressionEncoder>) params.encoder).createSerializer(); + } + + @Override + public boolean next() throws IOException { + if (iterator.hasNext()) { + next = serializer.apply(iterator.next()); + return true; + } + return false; + } + + @Override + public InternalRow get() { + if (next == null) { + throw new IllegalStateException("Next not available"); + } + return next; + } + + @Override + public void close() throws IOException { + next = null; + iterator.close(); + } + } + } + + /** A Spark partition wrapping the partitioned Beam {@link BoundedSource}. 
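+ * <p>Indexes are assigned from a shared {@code AtomicInteger} while splitting the source, and
+ * the class implements both {@link Partition} and {@link InputPartition} so the same partitions
+ * serve the RDD-based and the DataSourceV2 (table) read paths.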
*/ + private static class SourcePartition implements Partition, InputPartition { + final BoundedSource source; + final int index; + + SourcePartition(BoundedSource source, IntSupplier idxSupplier) { + this.source = source; + this.index = idxSupplier.getAsInt(); + } + + static List> partitionsOf(BoundedSource source, Params params) { + try { + PipelineOptions options = params.options.get(); + long desiredSize = source.getEstimatedSizeBytes(options) / params.numPartitions; + List> split = source.split(desiredSize, options); + IntSupplier idxSupplier = new AtomicInteger(0)::getAndIncrement; + return split.stream().map(s -> new SourcePartition<>(s, idxSupplier)).collect(toList()); + } catch (Exception e) { + throw new RuntimeException( + "Error splitting BoundedSource " + source.getClass().getCanonicalName(), e); + } + } + + @Override + public int index() { + return index; + } + + @Override + public int hashCode() { + return index; + } + } + + /** A partition iterator on a partitioned Beam {@link BoundedSource}. */ + private static class SourcePartitionIterator extends AbstractIterator> + implements Closeable { + BoundedReader reader; + boolean started = false; + + public SourcePartitionIterator(SourcePartition partition, Params params) { + try { + reader = partition.source.createReader(params.options.get()); + } catch (IOException e) { + throw new RuntimeException("Failed to create reader from a BoundedSource.", e); + } + } + + @Override + @SuppressWarnings("nullness") // ok, reader not used any longer + public void close() throws IOException { + if (reader != null) { + try { + reader.close(); + } finally { + reader = null; + } + } + } + + @Override + protected @CheckForNull WindowedValue computeNext() { + try { + if (started ? reader.advance() : start()) { + return timestampedValueInGlobalWindow(reader.getCurrent(), reader.getCurrentTimestamp()); + } else { + close(); + return endOfData(); + } + } catch (IOException e) { + throw new RuntimeException("Failed to start or advance reader.", e); + } + } + + private boolean start() throws IOException { + started = true; + return reader.start(); + } + } + + /** Shared parameters. */ + private static class Params implements Serializable { + final Encoder> encoder; + final Supplier options; + final int numPartitions; + + Params( + Encoder> encoder, Supplier options, int numPartitions) { + checkArgument(numPartitions > 0, "Number of partitions must be greater than zero."); + this.encoder = encoder; + this.options = options; + this.numPartitions = numPartitions; + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java new file mode 100644 index 000000000000..e483c2db0df4 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGlobalGroupBy; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGroupByWindow; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.explodeWindowedKey; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.value; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueKey; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueValue; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.windowedKV; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; + +import java.util.Collection; +import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.expressions.Aggregator; +import scala.Tuple2; +import scala.collection.IterableOnce; + +/** + * Translator for {@link Combine.PerKey} using {@link Dataset#groupByKey} with a Spark {@link + * Aggregator}. + * + *
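+ * <p>Depending on the windowing strategy (see the eligibility checks in {@link
+ * GroupByKeyHelpers}), the translation groups globally by key, by a composite (window, key) key,
+ * or falls back to an {@link Aggregator} over windowed values.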

+ * + * TODOs: + *
  • combine with context (CombineFnWithContext)? + *
  • combine with sideInputs? + *
  • are there other missing features? + */ +class CombinePerKeyTranslatorBatch + extends TransformTranslator< + PCollection>, PCollection>, Combine.PerKey> { + + CombinePerKeyTranslatorBatch() { + super(0.2f); + } + + @Override + public void translate(Combine.PerKey transform, Context cxt) { + WindowingStrategy windowing = cxt.getInput().getWindowingStrategy(); + CombineFn combineFn = (CombineFn) transform.getFn(); + + KvCoder inputCoder = (KvCoder) cxt.getInput().getCoder(); + KvCoder outputCoder = (KvCoder) cxt.getOutput().getCoder(); + + Encoder keyEnc = cxt.keyEncoderOf(inputCoder); + Encoder> inputEnc = cxt.encoderOf(inputCoder); + Encoder>> wvOutputEnc = cxt.windowedEncoder(outputCoder); + Encoder accumEnc = accumEncoder(combineFn, inputCoder.getValueCoder(), cxt); + + final Dataset>> result; + + boolean globalGroupBy = eligibleForGlobalGroupBy(windowing, true); + boolean groupByWindow = eligibleForGroupByWindow(windowing, true); + + if (globalGroupBy || groupByWindow) { + Aggregator, ?, OutT> valueAgg = + Aggregators.value(combineFn, KV::getValue, accumEnc, cxt.valueEncoderOf(outputCoder)); + + if (globalGroupBy) { + // Drop window and group by key globally to run the aggregation (combineFn), afterwards the + // global window is restored + result = + cxt.getDataset(cxt.getInput()) + .groupByKey(valueKey(), keyEnc) + .mapValues(value(), inputEnc) + .agg(valueAgg.toColumn()) + .map(globalKV(), wvOutputEnc); + } else { + Encoder> windowedKeyEnc = + cxt.tupleEncoder(cxt.windowEncoder(), keyEnc); + + // Group by window and key to run the aggregation (combineFn) + result = + cxt.getDataset(cxt.getInput()) + .flatMap(explodeWindowedKey(value()), cxt.tupleEncoder(windowedKeyEnc, inputEnc)) + .groupByKey(fun1(Tuple2::_1), windowedKeyEnc) + .mapValues(fun1(Tuple2::_2), inputEnc) + .agg(valueAgg.toColumn()) + .map(windowedKV(), wvOutputEnc); + } + } else { + // Optimized aggregator for non-merging and session window functions, all others depend on + // windowFn.mergeWindows + Aggregator>, ?, Collection>> aggregator = + Aggregators.windowedValue( + combineFn, + valueValue(), + windowing, + cxt.windowEncoder(), + accumEnc, + cxt.windowedEncoder(outputCoder.getValueCoder())); + result = + cxt.getDataset(cxt.getInput()) + .groupByKey(valueKey(), keyEnc) + .agg(aggregator.toColumn()) + .flatMap(explodeWindows(), wvOutputEnc); + } + + cxt.putDataset(cxt.getOutput(), result); + } + + private static + Fun1>>, IterableOnce>>> + explodeWindows() { + return t -> + ScalaInterop.scalaIterator(t._2).map(wv -> wv.withValue(KV.of(t._1, wv.getValue()))); + } + + private static Fun1, WindowedValue>> globalKV() { + return t -> WindowedValues.valueInGlobalWindow(KV.of(t._1, t._2)); + } + + private Encoder accumEncoder( + CombineFn fn, Coder valueCoder, Context cxt) { + try { + CoderRegistry registry = cxt.getInput().getPipeline().getCoderRegistry(); + return cxt.encoderOf(fn.getAccumulatorCoder(registry, valueCoder)); + } catch (CannotProvideCoderException e) { + throw new RuntimeException(e); + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java new file mode 100644 index 000000000000..f25121e1b478 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple; +import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; +import static org.apache.beam.sdk.transforms.windowing.TimestampCombiner.END_OF_WINDOW; + +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.sdk.values.WindowingStrategy; +import scala.Tuple2; +import scala.collection.IterableOnce; + +/** + * Package private helpers to support translating grouping transforms using `groupByKey` such as + * {@link GroupByKeyTranslatorBatch} or {@link CombinePerKeyTranslatorBatch}. + */ +class GroupByKeyHelpers { + + private GroupByKeyHelpers() {} + + /** + * Checks if it's possible to use an optimized `groupByKey` that also moves the window into the + * key. + * + * @param windowing The windowing strategy + * @param endOfWindowOnly Flag if to limit this optimization to {@link + * TimestampCombiner#END_OF_WINDOW}. + */ + static boolean eligibleForGroupByWindow( + WindowingStrategy windowing, boolean endOfWindowOnly) { + return !windowing.needsMerge() + && (!endOfWindowOnly || windowing.getTimestampCombiner() == END_OF_WINDOW) + && windowing.getWindowFn().windowCoder().consistentWithEquals(); + } + + /** + * Checks if it's possible to use an optimized `groupByKey` for the global window. + * + * @param windowing The windowing strategy + * @param endOfWindowOnly Flag if to limit this optimization to {@link + * TimestampCombiner#END_OF_WINDOW}. + */ + static boolean eligibleForGlobalGroupBy( + WindowingStrategy windowing, boolean endOfWindowOnly) { + return windowing.getWindowFn() instanceof GlobalWindows + && (!endOfWindowOnly || windowing.getTimestampCombiner() == END_OF_WINDOW); + } + + /** + * Explodes a windowed {@link KV} assigned to potentially multiple {@link BoundedWindow}s to a + * traversable of composite keys {@code (BoundedWindow, Key)} and value. 
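+ * <p>For example, a value {@code v} with key {@code k} assigned to windows {@code w1} and
+ * {@code w2} explodes to {@code ((w1, k), v)} and {@code ((w2, k), v)}.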
+ */ + static + Fun1>, IterableOnce, T>>> + explodeWindowedKey(Fun1>, T> valueFn) { + return v -> { + T value = valueFn.apply(v); + K key = v.getValue().getKey(); + return ScalaInterop.scalaIterator(v.getWindows()).map(w -> tuple(tuple(w, key), value)); + }; + } + + static Fun1, V>, WindowedValue>> windowedKV() { + return t -> windowedKV(t._1, t._2); + } + + static WindowedValue> windowedKV(Tuple2 key, V value) { + return WindowedValues.of(KV.of(key._2, value), key._1.maxTimestamp(), key._1, NO_FIRING); + } + + static Fun1, V> value() { + return v -> v.getValue(); + } + + static Fun1>, V> valueValue() { + return v -> v.getValue().getValue(); + } + + static Fun1>, K> valueKey() { + return v -> v.getValue().getKey(); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java new file mode 100644 index 000000000000..7caf06cb38fd --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGlobalGroupBy; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGroupByWindow; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.explodeWindowedKey; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueKey; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueValue; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.windowedKV; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers.toByteArray; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.concat; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun2; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.javaIterator; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; +import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.collect_list; +import static org.apache.spark.sql.functions.explode; +import static org.apache.spark.sql.functions.max; +import static org.apache.spark.sql.functions.min; +import static org.apache.spark.sql.functions.struct; + +import java.io.Serializable; +import org.apache.beam.runners.core.InMemoryStateInternals; +import org.apache.beam.runners.core.ReduceFnRunner; +import org.apache.beam.runners.core.StateInternalsFactory; +import org.apache.beam.runners.core.SystemReduceFn; +import org.apache.beam.runners.spark.SparkCommonPipelineOptions; +import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; +import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.GroupAlsoByWindowViaOutputBufferFn; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.WindowedValue; +import 
org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.TypedColumn; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.checkerframework.checker.nullness.qual.NonNull; +import scala.Tuple2; +import scala.collection.Iterator; +import scala.collection.JavaConverters; +import scala.collection.immutable.List; + +/** + * Translator for {@link GroupByKey} using {@link Dataset#groupByKey} with the built-in aggregation + * function {@code collect_list} when applicable. + * + *

<p>Note: Using {@code collect_list} isn't any worse than using {@link ReduceFnRunner}. In the + * latter case the entire group (iterator) has to be loaded into memory as well. Either way there's + * a risk of OOM errors. When enabling {@link + * SparkCommonPipelineOptions#getPreferGroupByKeyToHandleHugeValues()}, a more memory-sensitive + * iterable is used that can be traversed just once. Attempting to traverse the iterable again will + * throw. + * + *
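+ * <p>The translation chooses between the following strategies: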

      + *
    • When using the default global window, window information is dropped and restored after the + * aggregation. + *
    • For non-merging windows, windows are exploded and moved into a composite key for better + * distribution. Though, to keep the amount of shuffled data low, this is only done if values + * are assigned to a single window or if there are only few keys and distributing data is + * important. After the aggregation, windowed values are restored from the composite key. + *
    • All other cases are implemented using the SDK {@link ReduceFnRunner}. + *
    + */ +class GroupByKeyTranslatorBatch + extends TransformTranslator< + PCollection>, PCollection>>, GroupByKey> { + + /** Literal of binary encoded Pane info. */ + private static final Column PANE_NO_FIRING = lit(toByteArray(NO_FIRING, PaneInfoCoder.of())); + + /** Defaults for value in single global window. */ + private static final List GLOBAL_WINDOW_DETAILS = + windowDetails(lit(new byte[][] {EMPTY_BYTE_ARRAY})); + + GroupByKeyTranslatorBatch() { + super(0.2f); + } + + @Override + public void translate(GroupByKey transform, Context cxt) { + WindowingStrategy windowing = cxt.getInput().getWindowingStrategy(); + TimestampCombiner tsCombiner = windowing.getTimestampCombiner(); + + Dataset>> input = cxt.getDataset(cxt.getInput()); + + KvCoder inputCoder = (KvCoder) cxt.getInput().getCoder(); + KvCoder> outputCoder = (KvCoder>) cxt.getOutput().getCoder(); + + Encoder valueEnc = cxt.valueEncoderOf(inputCoder); + Encoder keyEnc = cxt.keyEncoderOf(inputCoder); + + // In batch we can ignore triggering and allowed lateness parameters + final Dataset>>> result; + + boolean useCollectList = + !cxt.getOptions() + .as(SparkCommonPipelineOptions.class) + .getPreferGroupByKeyToHandleHugeValues(); + if (useCollectList && eligibleForGlobalGroupBy(windowing, false)) { + // Collects all values per key in memory. This might be problematic if there's + // few keys only + // or some highly skewed distribution. + result = + input + .groupBy(col("value.key").as("key")) + .agg(collect_list(col("value.value")).as("values"), timestampAggregator(tsCombiner)) + .select( + inGlobalWindow( + keyValue(col("key").as(keyEnc), col("values").as(iterableEnc(valueEnc))), + windowTimestamp(tsCombiner))); + + } else if (eligibleForGlobalGroupBy(windowing, true)) { + // Produces an iterable that can be traversed exactly once. However, on the plus + // side, data is + // not collected in memory until serialized or done by the user. + result = + cxt.getDataset(cxt.getInput()) + .groupByKey(valueKey(), keyEnc) + .mapValues(valueValue(), cxt.valueEncoderOf(inputCoder)) + .mapGroups(fun2((k, it) -> KV.of(k, iterableOnce(it))), cxt.kvEncoderOf(outputCoder)) + .map(fun1(WindowedValues::valueInGlobalWindow), cxt.windowedEncoder(outputCoder)); + + } else if (useCollectList + && eligibleForGroupByWindow(windowing, false) + && (windowing.getWindowFn().assignsToOneWindow() || transform.fewKeys())) { + // Using the window as part of the key should help to better distribute the + // data. However, if + // values are assigned to multiple windows, more data would be shuffled around. + // If there's few + // keys only, this is still valuable. + // Collects all values per key & window in memory. + result = + input + .select(explode(col("windows")).as("window"), col("value"), col("timestamp")) + .groupBy(col("value.key").as("key"), col("window")) + .agg(collect_list(col("value.value")).as("values"), timestampAggregator(tsCombiner)) + .select( + inSingleWindow( + keyValue(col("key").as(keyEnc), col("values").as(iterableEnc(valueEnc))), + col("window").as(cxt.windowEncoder()), + windowTimestamp(tsCombiner))); + + } else if (eligibleForGroupByWindow(windowing, true) + && (windowing.getWindowFn().assignsToOneWindow() || transform.fewKeys())) { + // Using the window as part of the key should help to better distribute the + // data. However, if + // values are assigned to multiple windows, more data would be shuffled around. + // If there's few + // keys only, this is still valuable. + // Produces an iterable that can be traversed exactly once. 
However, on the plus + // side, data is + // not collected in memory until serialized or done by the user. + Encoder> windowedKeyEnc = + cxt.tupleEncoder(cxt.windowEncoder(), keyEnc); + result = + cxt.getDataset(cxt.getInput()) + .flatMap(explodeWindowedKey(valueValue()), cxt.tupleEncoder(windowedKeyEnc, valueEnc)) + .groupByKey(fun1(t -> t._1()), windowedKeyEnc) + .mapValues(fun1(t -> t._2()), valueEnc) + .mapGroups( + fun2((wKey, it) -> windowedKV(wKey, iterableOnce(it))), + cxt.windowedEncoder(outputCoder)); + + } else { + // Collects all values per key in memory. This might be problematic if there's + // few keys only + // or some highly skewed distribution. + + // FIXME Revisit this case, implementation is far from ideal: + // - iterator traversed at least twice, forcing materialization in memory + + // group by key, then by windows + result = + input + .groupByKey(valueKey(), keyEnc) + .flatMapGroups( + new GroupAlsoByWindowViaOutputBufferFn<>( + windowing, + (SerStateInternalsFactory) key -> InMemoryStateInternals.forKey(key), + SystemReduceFn.buffering(inputCoder.getValueCoder()), + cxt.getOptionsSupplier()), + cxt.windowedEncoder(outputCoder)); + } + + cxt.putDataset(cxt.getOutput(), result); + } + + /** Serializable In-memory state internals factory. */ + private interface SerStateInternalsFactory extends StateInternalsFactory, Serializable {} + + private Encoder> iterableEnc(Encoder enc) { + // safe to use list encoder with collect list + return (Encoder) collectionEncoder(enc); + } + + private static Column[] timestampAggregator(TimestampCombiner tsCombiner) { + if (tsCombiner.equals(TimestampCombiner.END_OF_WINDOW)) { + return new Column[0]; // no aggregation needed + } + Column agg = + tsCombiner.equals(TimestampCombiner.EARLIEST) + ? min(col("timestamp")) + : max(col("timestamp")); + return new Column[] {agg.as("timestamp")}; + } + + private static Column windowTimestamp(TimestampCombiner tsCombiner) { + if (tsCombiner.equals(TimestampCombiner.END_OF_WINDOW)) { + // null will be set to END_OF_WINDOW by the respective deserializer + return litNull(DataTypes.LongType); + } + return col("timestamp"); + } + + /** + * Java {@link Iterable} from Scala {@link Iterator} that can be iterated just once so that we + * don't have to load all data into memory. 
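+ * <p>A second call to {@code iterator()} fails the {@code checkState} below with an
+ * {@code IllegalStateException} once the underlying Scala iterator has been consumed.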
+ */ + private static Iterable iterableOnce(Iterator it) { + return () -> { + checkState(!it.isEmpty(), "Iterator on values can only be consumed once!"); + return javaIterator(it); + }; + } + + private TypedColumn> keyValue(TypedColumn key, TypedColumn value) { + return struct(key.as("key"), value.as("value")).as(kvEncoder(key.encoder(), value.encoder())); + } + + private static TypedColumn> inGlobalWindow( + TypedColumn value, Column ts) { + List fields = concat(timestampedValue(value, ts), GLOBAL_WINDOW_DETAILS); + Encoder> enc = + windowedValueEncoder(value.encoder(), encoderOf(GlobalWindow.class)); + return (TypedColumn>) + struct(JavaConverters.asJavaCollection(fields).toArray(new Column[0])).as(enc); + } + + public static TypedColumn> inSingleWindow( + TypedColumn value, TypedColumn window, Column ts) { + Column windows = org.apache.spark.sql.functions.array(window); + List fields = concat(timestampedValue(value, ts), windowDetails(windows)); + Encoder> enc = windowedValueEncoder(value.encoder(), window.encoder()); + return (TypedColumn>) + struct(JavaConverters.asJavaCollection(fields).toArray(new Column[0])).as(enc); + } + + private static List timestampedValue(Column value, Column ts) { + return seqOf(value.as("value"), ts.as("timestamp")).toList(); + } + + private static List windowDetails(Column windows) { + return seqOf(windows.as("windows"), PANE_NO_FIRING.as("paneInfo")).toList(); + } + + private static Column lit(T t) { + return org.apache.spark.sql.functions.lit(t); + } + + @SuppressWarnings("nullness") // NULL literal + private static Column litNull(DataType dataType) { + return org.apache.spark.sql.functions.lit(null).cast(dataType); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java new file mode 100644 index 000000000000..6565c2a01c63 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java @@ -0,0 +1,333 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.replace; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; + +import java.lang.reflect.Constructor; +import java.util.ArrayList; +import java.util.List; +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal; +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder; +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders; +import org.apache.spark.sql.catalyst.encoders.AgnosticExpressionPathEncoder; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.catalyst.expressions.BoundReference; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.objects.Invoke; +import org.apache.spark.sql.catalyst.expressions.objects.NewInstance; +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import scala.Option; +import scala.collection.Iterator; +import scala.collection.immutable.Seq; +import scala.reflect.ClassTag; + +public class EncoderFactory { + // Resolve the Scala case-class primary constructor (the one with the most parameters). + // Constructor ordering returned by Class.getConstructors() is JVM-defined and not stable + // across Spark versions, so we pick the widest constructor explicitly and then dispatch on + // parameter count below to pick the right argument shape per Spark version. + private static final Constructor STATIC_INVOKE_CONSTRUCTOR = + primaryConstructor(StaticInvoke.class); + + private static final Constructor INVOKE_CONSTRUCTOR = primaryConstructor(Invoke.class); + + private static final Constructor NEW_INSTANCE_CONSTRUCTOR = + primaryConstructor(NewInstance.class); + + @SuppressWarnings("unchecked") + private static Constructor primaryConstructor(Class cls) { + Constructor[] ctors = cls.getConstructors(); + Constructor widest = ctors[0]; + for (int i = 1; i < ctors.length; i++) { + if (ctors[i].getParameterCount() > widest.getParameterCount()) { + widest = ctors[i]; + } + } + return (Constructor) widest; + } + + @SuppressWarnings({"nullness", "unchecked"}) + static ExpressionEncoder create( + Expression serializer, Expression deserializer, Class clazz) { + AgnosticEncoder agnosticEncoder = new BeamAgnosticEncoder<>(serializer, deserializer, clazz); + return ExpressionEncoder.apply(agnosticEncoder, serializer, deserializer); + } + + /** + * An {@link AgnosticEncoder} that implements both {@link AgnosticExpressionPathEncoder} (so that + * {@code SerializerBuildHelper} / {@code DeserializerBuildHelper} delegate to our pre-built + * expressions) and {@link AgnosticEncoders.StructEncoder} (so that {@code + * Dataset.select(TypedColumn)} creates an N-attribute plan instead of a 1-attribute wrapped plan, + * preventing {@code FIELD_NUMBER_MISMATCH} errors). + * + *

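+   * <p>Illustrative sketch of the substitution described below (hypothetical expressions): for a
+   * serializer {@code StaticInvoke(..., seqOf(BoundReference(0)))}, {@code toCatalyst(field)}
+   * rewrites it to {@code StaticInvoke(..., seqOf(field))}.
+   *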
    The {@code toCatalyst} / {@code fromCatalyst} methods substitute the {@code input} + * expression into the pre-built serializer / deserializer via {@code transformUp}, so that when + * this encoder is nested inside a composite encoder (e.g. {@code Encoders.tuple}) the correct + * field-level expression is used in place of the root {@code BoundReference} / {@code + * GetColumnByOrdinal}. + */ + @SuppressWarnings({"nullness", "unchecked", "deprecation"}) + private static final class BeamAgnosticEncoder + implements AgnosticExpressionPathEncoder, AgnosticEncoders.StructEncoder { + + private final Expression serializer; + private final Expression deserializer; + private final Class clazz; + private final Seq encoderFields; + + BeamAgnosticEncoder(Expression serializer, Expression deserializer, Class clazz) { + this.serializer = serializer; + this.deserializer = deserializer; + this.clazz = clazz; + this.encoderFields = buildFields(serializer.dataType()); + } + + private static Seq buildFields(DataType dt) { + if (dt instanceof StructType) { + StructField[] structFields = ((StructType) dt).fields(); + List fields = new ArrayList<>(structFields.length); + for (StructField sf : structFields) { + fields.add( + new AgnosticEncoders.EncoderField( + sf.name(), + new FieldEncoder<>(sf.dataType(), sf.nullable()), + sf.nullable(), + sf.metadata(), + Option.empty(), + Option.empty())); + } + return seqOf(fields.toArray(new AgnosticEncoders.EncoderField[0])); + } else { + // Non-struct: wrap in a single "value" field so StructEncoder sees one field. + return seqOf( + new AgnosticEncoders.EncoderField( + "value", + new FieldEncoder<>(dt, true), + true, + Metadata.empty(), + Option.empty(), + Option.empty())); + } + } + + // --- AgnosticExpressionPathEncoder --- + + @Override + public Expression toCatalyst(Expression input) { + return serializer.transformUp(replace(BoundReference.class, input)); + } + + @Override + public Expression fromCatalyst(Expression input) { + return deserializer.transformUp(replace(GetColumnByOrdinal.class, input)); + } + + // --- AgnosticEncoders.StructEncoder --- + + @Override + public Seq fields() { + return encoderFields; + } + + @Override + public boolean isStruct() { + return true; + } + + /** + * Setter required by the Scala compiler when implementing the {@link + * AgnosticEncoders.StructEncoder} trait from Java. Scala traits with concrete {@code val} + * fields generate a synthetic mangled setter ({@code $_setter__$eq}) that the + * trait's initializer invokes on subclasses. Java cannot declare {@code val} fields, so we + * implement {@link #isStruct()} directly above and accept-but-ignore the trait setter here. The + * mangled name is brittle and tied to Spark's Scala source layout — if Spark removes the {@code + * isStruct} field from {@code StructEncoder}, this method becomes dead code; if Spark renames + * it, compilation will fail and the new mangled name must be substituted. + */ + @Override + public void + org$apache$spark$sql$catalyst$encoders$AgnosticEncoders$StructEncoder$_setter_$isStruct_$eq( + boolean v) { + // no-op: isStruct() is implemented directly above + } + + // --- AgnosticEncoder / Encoder (explicit to resolve default-method ambiguity) --- + + @Override + public boolean isPrimitive() { + return false; + } + + @Override + public StructType schema() { + // Build StructType from fields — mirrors the StructEncoder.schema() default. 
+ List sfs = new ArrayList<>(encoderFields.size()); + Iterator it = encoderFields.iterator(); + while (it.hasNext()) { + sfs.add(it.next().structField()); + } + return new StructType(sfs.toArray(new StructField[0])); + } + + @Override + public DataType dataType() { + return schema(); + } + + @Override + public ClassTag clsTag() { + return (ClassTag) ClassTag.apply(clazz); + } + } + + /** + * Minimal {@link AgnosticEncoder} stub used to carry per-field {@link DataType} metadata inside + * {@link AgnosticEncoders.EncoderField}. The actual serialization / deserialization is handled by + * {@link BeamAgnosticEncoder#toCatalyst} and {@link BeamAgnosticEncoder#fromCatalyst}. + */ + @SuppressWarnings({"nullness", "unchecked"}) + private static final class FieldEncoder implements AgnosticEncoder { + private final DataType fieldDataType; + private final boolean fieldNullable; + + FieldEncoder(DataType dataType, boolean nullable) { + this.fieldDataType = dataType; + this.fieldNullable = nullable; + } + + @Override + public boolean isPrimitive() { + return false; + } + + @Override + public DataType dataType() { + return fieldDataType; + } + + @Override + public StructType schema() { + return new StructType().add("value", fieldDataType, fieldNullable); + } + + @Override + public boolean nullable() { + return fieldNullable; + } + + @Override + public ClassTag clsTag() { + return (ClassTag) ClassTag.apply(Object.class); + } + } + + /** + * Invoke method {@code fun} on Class {@code cls}, immediately propagating {@code null} if any + * input arg is {@code null}. + */ + static Expression invokeIfNotNull(Class cls, String fun, DataType type, Expression... args) { + return invoke(cls, fun, type, true, args); + } + + /** Invoke method {@code fun} on Class {@code cls}. */ + static Expression invoke(Class cls, String fun, DataType type, Expression... args) { + return invoke(cls, fun, type, false, args); + } + + private static Expression invoke( + Class cls, String fun, DataType type, boolean propagateNull, Expression... args) { + try { + // To address breaking interfaces between various versions of Spark, expressions are + // created reflectively. This is fine as it's just needed once to create the query plan. + switch (STATIC_INVOKE_CONSTRUCTOR.getParameterCount()) { + case 6: + // Spark 3.1.x + return STATIC_INVOKE_CONSTRUCTOR.newInstance( + cls, type, fun, seqOf(args), propagateNull, true); + case 7: + // Spark 3.2.0 + return STATIC_INVOKE_CONSTRUCTOR.newInstance( + cls, type, fun, seqOf(args), emptyList(), propagateNull, true); + case 8: + // Spark 3.2.x, 3.3.x + return STATIC_INVOKE_CONSTRUCTOR.newInstance( + cls, type, fun, seqOf(args), emptyList(), propagateNull, true, true); + case 9: + // Spark 4.0.x: added Option> parameter + return STATIC_INVOKE_CONSTRUCTOR.newInstance( + cls, type, fun, seqOf(args), emptyList(), propagateNull, true, true, Option.empty()); + default: + throw new RuntimeException("Unsupported version of Spark"); + } + } catch (IllegalArgumentException | ReflectiveOperationException ex) { + throw new RuntimeException(ex); + } + } + + /** Invoke method {@code fun} on {@code obj} with provided {@code args}. */ + static Expression invoke( + Expression obj, String fun, DataType type, boolean nullable, Expression... args) { + try { + // To address breaking interfaces between various versions of Spark, expressions are + // created reflectively. This is fine as it's just needed once to create the query plan. 
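+      // Dispatch on the arity of the widest Invoke constructor (see primaryConstructor above);
+      // each case corresponds to the constructor shape of the Spark versions noted below.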
+ switch (INVOKE_CONSTRUCTOR.getParameterCount()) { + case 6: + // Spark 3.1.x + return INVOKE_CONSTRUCTOR.newInstance(obj, fun, type, seqOf(args), false, nullable); + case 7: + // Spark 3.2.0 + return INVOKE_CONSTRUCTOR.newInstance( + obj, fun, type, seqOf(args), emptyList(), false, nullable); + case 8: + // Spark 3.2.x, 3.3.x, 4.0.x: Invoke constructor is 8 params across all these versions + return INVOKE_CONSTRUCTOR.newInstance( + obj, fun, type, seqOf(args), emptyList(), false, nullable, true); + default: + throw new RuntimeException("Unsupported version of Spark"); + } + } catch (IllegalArgumentException | ReflectiveOperationException ex) { + throw new RuntimeException(ex); + } + } + + static Expression newInstance(Class cls, DataType type, Expression... args) { + try { + // To address breaking interfaces between various versions of Spark, expressions are + // created reflectively. This is fine as it's just needed once to create the query plan. + switch (NEW_INSTANCE_CONSTRUCTOR.getParameterCount()) { + case 5: + return NEW_INSTANCE_CONSTRUCTOR.newInstance(cls, seqOf(args), true, type, Option.empty()); + case 6: + // Spark 3.2.x, 3.3.x, 4.0.x: added immutable.Seq parameter + return NEW_INSTANCE_CONSTRUCTOR.newInstance( + cls, seqOf(args), emptyList(), true, type, Option.empty()); + default: + throw new RuntimeException("Unsupported version of Spark"); + } + } catch (IllegalArgumentException | ReflectiveOperationException ex) { + throw new RuntimeException(ex); + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java new file mode 100644 index 000000000000..173b4653a19b --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java @@ -0,0 +1,617 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.invoke; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.invokeIfNotNull; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.match; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.replace; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple; +import static org.apache.spark.sql.types.DataTypes.BinaryType; +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.LongType; + +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.catalyst.SerializerBuildHelper; +import org.apache.spark.sql.catalyst.SerializerBuildHelper.MapElementInformation; +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal; +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.catalyst.expressions.BoundReference; +import org.apache.spark.sql.catalyst.expressions.Coalesce; +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct; +import org.apache.spark.sql.catalyst.expressions.EqualTo; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GetStructField; +import org.apache.spark.sql.catalyst.expressions.If; +import org.apache.spark.sql.catalyst.expressions.IsNotNull; +import org.apache.spark.sql.catalyst.expressions.IsNull; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.catalyst.expressions.Literal$; +import org.apache.spark.sql.catalyst.expressions.MapKeys; +import org.apache.spark.sql.catalyst.expressions.MapValues; +import org.apache.spark.sql.catalyst.expressions.objects.MapObjects$; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.MapType; +import 
org.apache.spark.sql.types.ObjectType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.MutablePair; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; +import scala.Option; +import scala.Some; +import scala.Tuple2; +import scala.collection.IndexedSeq; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +/** {@link Encoders} utility class. */ +public class EncoderHelpers { + private static final DataType OBJECT_TYPE = new ObjectType(Object.class); + private static final DataType TUPLE2_TYPE = new ObjectType(Tuple2.class); + private static final DataType WINDOWED_VALUE = new ObjectType(WindowedValue.class); + private static final DataType KV_TYPE = new ObjectType(KV.class); + private static final DataType MUTABLE_PAIR_TYPE = new ObjectType(MutablePair.class); + private static final DataType LIST_TYPE = new ObjectType(List.class); + + // Collections / maps of these types can be (de)serialized without (de)serializing each member + private static final Set> PRIMITIVE_TYPES = + ImmutableSet.of( + Boolean.class, + Byte.class, + Short.class, + Integer.class, + Long.class, + Float.class, + Double.class); + + // Default encoders by class + private static final Map, Encoder> DEFAULT_ENCODERS = new ConcurrentHashMap<>(); + + // Factory for default encoders by class + private static @Nullable Encoder encoderFactory(Class cls) { + if (cls.equals(PaneInfo.class)) { + return paneInfoEncoder(); + } else if (cls.equals(GlobalWindow.class)) { + return binaryEncoder(GlobalWindow.Coder.INSTANCE, false); + } else if (cls.equals(IntervalWindow.class)) { + return binaryEncoder(IntervalWindowCoder.of(), false); + } else if (cls.equals(Instant.class)) { + return instantEncoder(); + } else if (cls.equals(String.class)) { + return Encoders.STRING(); + } else if (cls.equals(Boolean.class)) { + return Encoders.BOOLEAN(); + } else if (cls.equals(Integer.class)) { + return Encoders.INT(); + } else if (cls.equals(Long.class)) { + return Encoders.LONG(); + } else if (cls.equals(Float.class)) { + return Encoders.FLOAT(); + } else if (cls.equals(Double.class)) { + return Encoders.DOUBLE(); + } else if (cls.equals(BigDecimal.class)) { + return Encoders.DECIMAL(); + } else if (cls.equals(byte[].class)) { + return Encoders.BINARY(); + } else if (cls.equals(Byte.class)) { + return Encoders.BYTE(); + } else if (cls.equals(Short.class)) { + return Encoders.SHORT(); + } + return null; + } + + @SuppressWarnings({"nullness", "methodref.return"}) // computeIfAbsent allows null returns + private static @Nullable Encoder getOrCreateDefaultEncoder(Class cls) { + return (Encoder) DEFAULT_ENCODERS.computeIfAbsent(cls, EncoderHelpers::encoderFactory); + } + + /** Gets or creates a default {@link Encoder} for {@link T}. */ + public static Encoder encoderOf(Class cls) { + Encoder enc = getOrCreateDefaultEncoder(cls); + if (enc == null) { + throw new IllegalArgumentException("No default encoder available for class " + cls); + } + return enc; + } + + /** + * Creates a Spark {@link Encoder} for {@link T} of {@link DataTypes#BinaryType BinaryType} + * delegating to a Beam {@link Coder} underneath. + * + *

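+   * <p>Usage sketch (coder taken from the tests in this change, illustrative only):
+   *
+   * <pre>{@code
+   * // falls back to binary encoding; no default Spark encoder exists for List
+   * Encoder<List<String>> enc = encoderFor(ListCoder.of(StringUtf8Coder.of()));
+   * }</pre>
+   *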
+   * <p>Note: For common types, if available, default Spark {@link Encoder}s are used instead.
+   *
+   * @param coder Beam {@link Coder}
+   */
+  public static Encoder encoderFor(Coder coder) {
+    Encoder enc = getOrCreateDefaultEncoder(coder.getEncodedTypeDescriptor().getRawType());
+    return enc != null ? enc : binaryEncoder(coder, true);
+  }
+
+  /**
+   * Creates a Spark {@link Encoder} for {@link T} of {@link StructType} with fields {@code value},
+   * {@code timestamp}, {@code windows} and {@code paneInfo}.
+   *
+   * @param value {@link Encoder} to encode field `{@code value}`.
+   * @param window {@link Encoder} to encode individual windows in field `{@code windows}`
+   */
+  public static Encoder> windowedValueEncoder(
+      Encoder value, Encoder window) {
+    Encoder timestamp = encoderOf(Instant.class);
+    Encoder paneInfo = encoderOf(PaneInfo.class);
+    Encoder> windows = collectionEncoder(window);
+    Expression serializer =
+        serializeWindowedValue(rootRef(WINDOWED_VALUE, true), value, timestamp, windows, paneInfo);
+    Expression deserializer =
+        deserializeWindowedValue(
+            rootCol(serializer.dataType()), value, timestamp, windows, paneInfo);
+    return EncoderFactory.create(serializer, deserializer, WindowedValue.class);
+  }
+
+  /**
+   * Creates a one-of Spark {@link Encoder} of {@link StructType} where each alternative is
+   * represented as a column / field named by its index, each with a separate {@link Encoder}.

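+   * <p>For example (hypothetical encoders): {@code oneOfEncoder(asList(intEnc, stringEnc))}
+   * produces a struct with nullable fields {@code "0"} and {@code "1"}, of which exactly one is
+   * populated per row.
+   *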
    Externally this is represented as tuple {@code (index, data)} where an index corresponds to + * an {@link Encoder} in the provided list. + * + * @param encoders {@link Encoder}s for each alternative. + */ + public static Encoder> oneOfEncoder(List> encoders) { + Expression serializer = serializeOneOf(rootRef(TUPLE2_TYPE, true), encoders); + Expression deserializer = deserializeOneOf(rootCol(serializer.dataType()), encoders); + return EncoderFactory.create(serializer, deserializer, Tuple2.class); + } + + /** + * Creates a Spark {@link Encoder} for {@link KV} of {@link StructType} with fields {@code key} + * and {@code value}. + * + * @param key {@link Encoder} to encode field `{@code key}`. + * @param value {@link Encoder} to encode field `{@code value}` + */ + public static Encoder> kvEncoder(Encoder key, Encoder value) { + Expression serializer = serializeKV(rootRef(KV_TYPE, true), key, value); + Expression deserializer = deserializeKV(rootCol(serializer.dataType()), key, value); + return EncoderFactory.create(serializer, deserializer, KV.class); + } + + /** + * Creates a Spark {@link Encoder} of {@link ArrayType} for Java {@link Collection}s with nullable + * elements. + * + * @param enc {@link Encoder} to encode collection elements + */ + public static Encoder> collectionEncoder(Encoder enc) { + return collectionEncoder(enc, true); + } + + /** + * Creates a Spark {@link Encoder} of {@link ArrayType} for Java {@link Collection}s. + * + * @param enc {@link Encoder} to encode collection elements + * @param nullable Allow nullable collection elements + */ + public static Encoder> collectionEncoder(Encoder enc, boolean nullable) { + DataType type = new ObjectType(Collection.class); + Expression serializer = serializeSeq(rootRef(type, true), enc, nullable); + Expression deserializer = deserializeSeq(rootCol(serializer.dataType()), enc, nullable, true); + return EncoderFactory.create(serializer, deserializer, Collection.class); + } + + /** + * Creates a Spark {@link Encoder} of {@link MapType} that deserializes to {@link MapT}. + * + * @param key {@link Encoder} to encode keys + * @param value {@link Encoder} to encode values + * @param cls Specific class to use, supported are {@link HashMap} and {@link TreeMap} + */ + public static , K, V> Encoder mapEncoder( + Encoder key, Encoder value, Class cls) { + Expression serializer = mapSerializer(rootRef(new ObjectType(cls), true), key, value); + Expression deserializer = mapDeserializer(rootCol(serializer.dataType()), key, value, cls); + return EncoderFactory.create(serializer, deserializer, cls); + } + + /** + * Creates a Spark {@link Encoder} for Spark's {@link MutablePair} of {@link StructType} with + * fields `{@code _1}` and `{@code _2}`. + * + *

+   * <p>This is intended to be used in places such as aggregators.
+   *
+   * @param enc1 {@link Encoder} to encode `{@code _1}`
+   * @param enc2 {@link Encoder} to encode `{@code _2}`
+   */
+  public static Encoder> mutablePairEncoder(
+      Encoder enc1, Encoder enc2) {
+    Expression serializer = serializeMutablePair(rootRef(MUTABLE_PAIR_TYPE, true), enc1, enc2);
+    Expression deserializer = deserializeMutablePair(rootCol(serializer.dataType()), enc1, enc2);
+    return EncoderFactory.create(serializer, deserializer, MutablePair.class);
+  }
+
+  /**
+   * Creates a Spark {@link Encoder} for {@link PaneInfo} of {@link DataTypes#BinaryType
+   * BinaryType}.
+   */
+  private static Encoder paneInfoEncoder() {
+    DataType type = new ObjectType(PaneInfo.class);
+    return EncoderFactory.create(
+        invokeIfNotNull(Utils.class, "paneInfoToBytes", BinaryType, rootRef(type, false)),
+        invokeIfNotNull(Utils.class, "paneInfoFromBytes", type, rootCol(BinaryType)),
+        PaneInfo.class);
+  }
+
+  /**
+   * Creates a Spark {@link Encoder} for Joda {@link Instant} of {@link DataTypes#LongType
+   * LongType}.
+   */
+  private static Encoder instantEncoder() {
+    DataType type = new ObjectType(Instant.class);
+    Expression instant = rootRef(type, true);
+    Expression millis = rootCol(LongType);
+    return EncoderFactory.create(
+        nullSafe(instant, invoke(instant, "getMillis", LongType, false)),
+        nullSafe(millis, invoke(Instant.class, "ofEpochMilli", type, millis)),
+        Instant.class);
+  }
+
+  /**
+   * Creates a Spark {@link Encoder} for {@link T} of {@link DataTypes#BinaryType BinaryType}
+   * delegating to a Beam {@link Coder} underneath.
+   *
+   * @param coder Beam {@link Coder}
+   * @param nullable Whether to allow nullable items
+   */
+  private static Encoder binaryEncoder(Coder coder, boolean nullable) {
+    Literal litCoder = lit(coder, Coder.class);
+    // T could be private, use OBJECT_TYPE for code generation to not risk an IllegalAccessError
+    return EncoderFactory.create(
+        invokeIfNotNull(
+            CoderHelpers.class,
+            "toByteArray",
+            BinaryType,
+            rootRef(OBJECT_TYPE, nullable),
+            litCoder),
+        invokeIfNotNull(
+            CoderHelpers.class, "fromByteArray", OBJECT_TYPE, rootCol(BinaryType), litCoder),
+        coder.getEncodedTypeDescriptor().getRawType());
+  }
+
+  private static Expression serializeWindowedValue(
+      Expression in,
+      Encoder valueEnc,
+      Encoder timestampEnc,
+      Encoder> windowsEnc,
+      Encoder paneEnc) {
+    return serializerObject(
+        in,
+        tuple("value", serializeField(in, valueEnc, "getValue")),
+        tuple("timestamp", serializeField(in, timestampEnc, "getTimestamp")),
+        tuple("windows", serializeField(in, windowsEnc, "getWindows")),
+        tuple("paneInfo", serializeField(in, paneEnc, "getPaneInfo")));
+  }
+
+  private static Expression serializerObject(Expression in, Tuple2...
fields) { + return SerializerBuildHelper.createSerializerForObject(in, seqOf(fields)); + } + + private static Expression deserializeWindowedValue( + Expression in, + Encoder valueEnc, + Encoder timestampEnc, + Encoder> windowsEnc, + Encoder paneEnc) { + Expression value = deserializeField(in, valueEnc, 0, "value"); + Expression windows = deserializeField(in, windowsEnc, 2, "windows"); + Expression timestamp = deserializeField(in, timestampEnc, 1, "timestamp"); + Expression paneInfo = deserializeField(in, paneEnc, 3, "paneInfo"); + // set timestamp to end of window (maxTimestamp) if null + timestamp = + ifNotNull(timestamp, invoke(Utils.class, "maxTimestamp", timestamp.dataType(), windows)); + Expression[] fields = new Expression[] {value, timestamp, windows, paneInfo}; + + return nullSafe(paneInfo, invoke(WindowedValues.class, "of", WINDOWED_VALUE, fields)); + } + + private static Expression serializeMutablePair( + Expression in, Encoder enc1, Encoder enc2) { + return serializerObject( + in, + tuple("_1", serializeField(in, enc1, "_1")), + tuple("_2", serializeField(in, enc2, "_2"))); + } + + private static Expression deserializeMutablePair( + Expression in, Encoder enc1, Encoder enc2) { + Expression field1 = deserializeField(in, enc1, 0, "_1"); + Expression field2 = deserializeField(in, enc2, 1, "_2"); + return invoke(MutablePair.class, "apply", MUTABLE_PAIR_TYPE, field1, field2); + } + + private static Expression serializeKV( + Expression in, Encoder keyEnc, Encoder valueEnc) { + return serializerObject( + in, + tuple("key", serializeField(in, keyEnc, "getKey")), + tuple("value", serializeField(in, valueEnc, "getValue"))); + } + + private static Expression deserializeKV( + Expression in, Encoder keyEnc, Encoder valueEnc) { + Expression key = deserializeField(in, keyEnc, 0, "key"); + Expression value = deserializeField(in, valueEnc, 1, "value"); + return invoke(KV.class, "of", KV_TYPE, key, value); + } + + public static Expression serializeOneOf(Expression in, List> encoders) { + Expression type = invoke(in, "_1", IntegerType, false); + Expression[] args = new Expression[encoders.size() * 2]; + for (int i = 0; i < encoders.size(); i++) { + args[i * 2] = lit(String.valueOf(i)); + args[i * 2 + 1] = serializeOneOfField(in, type, encoders.get(i), i); + } + return new CreateNamedStruct(seqOf(args)); + } + + public static Expression deserializeOneOf(Expression in, List> encoders) { + Expression[] args = new Expression[encoders.size()]; + for (int i = 0; i < encoders.size(); i++) { + args[i] = deserializeOneOfField(in, encoders.get(i), i); + } + return new Coalesce(seqOf(args)); + } + + private static Expression serializeOneOfField( + Expression in, Expression type, Encoder enc, int typeIdx) { + Expression litNull = lit(null, serializedType(enc)); + Expression value = invoke(in, "_2", deserializedType(enc), false); + return new If(new EqualTo(type, lit(typeIdx)), serialize(value, enc), litNull); + } + + private static Expression deserializeOneOfField(Expression in, Encoder enc, int idx) { + GetStructField field = new GetStructField(in, idx, Option.empty()); + Expression litNull = lit(null, TUPLE2_TYPE); + Expression newTuple = + EncoderFactory.newInstance(Tuple2.class, TUPLE2_TYPE, lit(idx), deserialize(field, enc)); + return new If(new IsNull(field), litNull, newTuple); + } + + private static Expression serializeField(Expression in, Encoder enc, String getterName) { + Expression ref = serializer(enc).collect(match(BoundReference.class)).head(); + return serialize(invoke(in, getterName, 
ref.dataType(), ref.nullable()), enc); + } + + private static Expression deserializeField( + Expression in, Encoder enc, int idx, String name) { + return deserialize(new GetStructField(in, idx, new Some<>(name)), enc); + } + + // Note: Currently this doesn't support nullable primitive values + private static Expression mapSerializer(Expression map, Encoder key, Encoder value) { + DataType keyType = deserializedType(key); + DataType valueType = deserializedType(value); + return SerializerBuildHelper.createSerializerForMap( + map, + new MapElementInformation(keyType, false, e -> serialize(e, key)), + new MapElementInformation(valueType, false, e -> serialize(e, value))); + } + + private static , K, V> Expression mapDeserializer( + Expression in, Encoder key, Encoder value, Class cls) { + Preconditions.checkArgument(cls.isAssignableFrom(HashMap.class) || cls.equals(TreeMap.class)); + Expression keys = deserializeSeq(new MapKeys(in), key, false, false); + Expression values = deserializeSeq(new MapValues(in), value, false, false); + String fn = cls.equals(TreeMap.class) ? "toTreeMap" : "toMap"; + return invoke( + Utils.class, fn, new ObjectType(cls), keys, values, mapItemType(key), mapItemType(value)); + } + + // serialized type for primitive types (avoid boxing!), otherwise the deserialized type + private static Literal mapItemType(Encoder enc) { + return lit(isPrimitiveEnc(enc) ? serializedType(enc) : deserializedType(enc), DataType.class); + } + + private static Expression serializeSeq(Expression in, Encoder enc, boolean nullable) { + if (isPrimitiveEnc(enc)) { + Expression array = invoke(in, "toArray", new ObjectType(Object[].class), false); + return SerializerBuildHelper.createSerializerForGenericArray( + array, serializedType(enc), nullable); + } + Expression seq = invoke(Utils.class, "toSeq", new ObjectType(Seq.class), in); + return MapObjects$.MODULE$.apply( + exp -> serialize(exp, enc), seq, deserializedType(enc), nullable, Option.empty()); + } + + private static Expression deserializeSeq( + Expression in, Encoder enc, boolean nullable, boolean exposeAsJava) { + DataType type = serializedType(enc); // input type is the serializer result type + if (isPrimitiveEnc(enc)) { + // Spark may reuse unsafe array data, if directly exposed it must be copied before + return exposeAsJava + ? invoke(Utils.class, "copyToList", LIST_TYPE, in, lit(type, DataType.class)) + : in; + } + Option> optCls = exposeAsJava ? Option.apply(List.class) : Option.empty(); + // MapObjects will always copy + return MapObjects$.MODULE$.apply(exp -> deserialize(exp, enc), in, type, nullable, optCls); + } + + private static boolean isPrimitiveEnc(Encoder enc) { + return PRIMITIVE_TYPES.contains(enc.clsTag().runtimeClass()); + } + + private static Expression serialize(Expression input, Encoder enc) { + return serializer(enc).transformUp(replace(BoundReference.class, input)); + } + + private static Expression deserialize(Expression input, Encoder enc) { + return deserializer(enc).transformUp(replace(GetColumnByOrdinal.class, input)); + } + + /** + * Wraps an {@link Encoder} as an {@link ExpressionEncoder}. In Spark 4.x, built-in encoders (e.g. + * {@code Encoders.INT()}) are {@link AgnosticEncoder} subclasses rather than {@link + * ExpressionEncoder}s, so we convert them on demand. 
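+   *
+   * <p>A minimal sketch (assuming a Spark 4 built-in encoder):
+   *
+   * <pre>{@code
+   * ExpressionEncoder<Integer> expr = toExpressionEncoder(Encoders.INT());
+   * }</pre>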
+ */ + @SuppressWarnings("unchecked") + private static ExpressionEncoder toExpressionEncoder(Encoder enc) { + if (enc instanceof ExpressionEncoder) { + return (ExpressionEncoder) enc; + } else if (enc instanceof AgnosticEncoder) { + return ExpressionEncoder.apply((AgnosticEncoder) enc); + } + throw new IllegalArgumentException("Unsupported encoder type: " + enc.getClass()); + } + + private static Expression serializer(Encoder enc) { + return toExpressionEncoder(enc).objSerializer(); + } + + private static Expression deserializer(Encoder enc) { + return toExpressionEncoder(enc).objDeserializer(); + } + + private static DataType serializedType(Encoder enc) { + return toExpressionEncoder(enc).objSerializer().dataType(); + } + + private static DataType deserializedType(Encoder enc) { + return toExpressionEncoder(enc).objDeserializer().dataType(); + } + + private static Expression rootRef(DataType dt, boolean nullable) { + return new BoundReference(0, dt, nullable); + } + + private static Expression rootCol(DataType dt) { + return new GetColumnByOrdinal(0, dt); + } + + private static Expression nullSafe(Expression in, Expression out) { + return new If(new IsNull(in), lit(null, out.dataType()), out); + } + + private static Expression ifNotNull(Expression expr, Expression otherwise) { + return new If(new IsNotNull(expr), expr, otherwise); + } + + private static Expression lit(T t) { + return Literal$.MODULE$.apply(t); + } + + @SuppressWarnings("nullness") // literal NULL is allowed + private static Expression lit(@Nullable T t, DataType dataType) { + return new Literal(t, dataType); + } + + private static Literal lit(T obj, Class cls) { + return Literal.fromObject(obj, new ObjectType(cls)); + } + + /** Encoder / expression utils that are called from generated code. */ + public static class Utils { + + public static PaneInfo paneInfoFromBytes(byte[] bytes) { + return CoderHelpers.fromByteArray(bytes, PaneInfoCoder.of()); + } + + public static byte[] paneInfoToBytes(PaneInfo paneInfo) { + return CoderHelpers.toByteArray(paneInfo, PaneInfoCoder.of()); + } + + /** The maximum {@code maxTimestamp} across all associated windows. 
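+     *
+     * <p>For example, given two windows with max timestamps {@code t=10} and {@code t=20}
+     * (hypothetical), this returns the instant {@code t=20}.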
*/ + public static Instant maxTimestamp(Iterable windows) { + Instant maxTimestamp = null; + for (BoundedWindow window : windows) { + Instant timestamp = window.maxTimestamp(); + if (maxTimestamp == null || timestamp.isAfter(maxTimestamp)) { + maxTimestamp = timestamp; + } + } + return Preconditions.checkNotNull( + maxTimestamp, "WindowedValue must have at least one window"); + } + + public static List copyToList(ArrayData arrayData, DataType type) { + // Note, this could be optimized for primitive arrays (if elements are not nullable) using + // Ints.asList(arrayData.toIntArray()) and similar + return Arrays.asList(arrayData.toObjectArray(type)); + } + + public static Seq toSeq(ArrayData arrayData) { + return arrayData.toSeq(OBJECT_TYPE); + } + + public static Seq toSeq(Collection col) { + if (col instanceof List) { + return JavaConverters.asScalaBuffer((List) col); + } + return JavaConverters.collectionAsScalaIterable(col).toSeq(); + } + + public static TreeMap toTreeMap( + ArrayData keys, ArrayData values, DataType keyType, DataType valueType) { + return toMap(new TreeMap<>(), keys, values, keyType, valueType); + } + + public static HashMap toMap( + ArrayData keys, ArrayData values, DataType keyType, DataType valueType) { + HashMap map = Maps.newHashMapWithExpectedSize(keys.numElements()); + return toMap(map, keys, values, keyType, valueType); + } + + private static > MapT toMap( + MapT map, ArrayData keys, ArrayData values, DataType keyType, DataType valueType) { + IndexedSeq keysSeq = keys.toSeq(keyType); + IndexedSeq valuesSeq = values.toSeq(valueType); + for (int i = 0; i < keysSeq.size(); i++) { + map.put(keysSeq.apply(i), valuesSeq.apply(i)); + } + return map; + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java new file mode 100644 index 000000000000..175e144d6506 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.utils; + +import java.io.Serializable; +import org.checkerframework.checker.nullness.qual.NonNull; +import scala.Function1; +import scala.Function2; +import scala.PartialFunction; +import scala.Tuple2; +import scala.collection.Iterator; +import scala.collection.JavaConverters; +import scala.collection.Seq; +import scala.collection.immutable.List; +import scala.collection.immutable.Nil$; + +/** Utilities for easier interoperability with the Spark Scala API. 
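+ *
+ * <p>Usage sketch (illustrative only):
+ *
+ * <pre>{@code
+ * scala.collection.immutable.Seq<Integer> seq = ScalaInterop.seqOf(1, 2, 3);
+ * Tuple2<String, Integer> pair = ScalaInterop.tuple("a", 1);
+ * java.util.Iterator<Integer> it = ScalaInterop.javaIterator(seq.iterator());
+ * }</pre>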
*/ +public class ScalaInterop { + private ScalaInterop() {} + + public static scala.collection.immutable.Seq seqOf(T... t) { + return (scala.collection.immutable.Seq) + JavaConverters.asScalaBuffer(java.util.Arrays.asList(t)).toList(); + } + + public static List concat(List a, List b) { + return b.$colon$colon$colon(a); + } + + public static Seq listOf(T t) { + return emptyList().$colon$colon(t); + } + + public static List emptyList() { + return (List) Nil$.MODULE$; + } + + /** Scala {@link Iterator} of Java {@link Iterable}. */ + public static Iterator scalaIterator(Iterable iterable) { + return scalaIterator(iterable.iterator()); + } + + /** Scala {@link Iterator} of Java {@link java.util.Iterator}. */ + public static Iterator scalaIterator(java.util.Iterator it) { + return JavaConverters.asScalaIterator(it); + } + + /** Java {@link java.util.Iterator} of Scala {@link Iterator}. */ + public static java.util.Iterator javaIterator(Iterator it) { + return JavaConverters.asJavaIterator(it); + } + + public static Tuple2 tuple(T1 t1, T2 t2) { + return new Tuple2<>(t1, t2); + } + + public static PartialFunction replace( + Class clazz, T replace) { + return new PartialFunction() { + + @Override + public boolean isDefinedAt(T x) { + return clazz.isAssignableFrom(x.getClass()); + } + + @Override + public T apply(T x) { + return replace; + } + }; + } + + public static PartialFunction match(Class clazz) { + return new PartialFunction() { + + @Override + public boolean isDefinedAt(T x) { + return clazz.isAssignableFrom(x.getClass()); + } + + @Override + public V apply(T x) { + return (V) x; + } + }; + } + + public static Fun1 fun1(Fun1 fun) { + return fun; + } + + public static Fun2 fun2(Fun2 fun) { + return fun; + } + + public interface Fun1 extends Function1, Serializable {} + + public interface Fun2 extends Function2, Serializable {} +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java new file mode 100644 index 000000000000..48c1e645f6ec --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; + +import static java.util.Arrays.asList; +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.toList; +import static java.util.stream.Collectors.toMap; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderFor; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.mapEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.oneOfEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Predicates.notNull; +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.StringType; +import static org.apache.spark.sql.types.DataTypes.createStructField; +import static org.apache.spark.sql.types.DataTypes.createStructType; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +import java.math.BigDecimal; +import java.math.MathContext; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.TreeMap; +import java.util.function.Function; +import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule; +import org.apache.beam.sdk.coders.BigDecimalCoder; +import org.apache.beam.sdk.coders.BigEndianIntegerCoder; +import org.apache.beam.sdk.coders.BigEndianLongCoder; +import org.apache.beam.sdk.coders.BigEndianShortCoder; +import org.apache.beam.sdk.coders.BooleanCoder; +import org.apache.beam.sdk.coders.ByteCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.DelegateCoder; +import org.apache.beam.sdk.coders.DoubleCoder; +import org.apache.beam.sdk.coders.FloatCoder; +import org.apache.beam.sdk.coders.InstantCoder; +import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.types.DecimalType; +import 
org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.joda.time.Instant; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import scala.Tuple2; + +/** Test of the wrapping of Beam Coders as Spark ExpressionEncoders. */ +@RunWith(JUnit4.class) +public class EncoderHelpersTest { + @ClassRule public static SparkSessionRule sessionRule = new SparkSessionRule("local[1]"); + + private static final Encoder windowEnc = + EncoderHelpers.encoderOf(GlobalWindow.class); + + private static final Map, List> BASIC_CASES = + ImmutableMap., List>builder() + .put(BooleanCoder.of(), asList(true, false, null)) + .put(ByteCoder.of(), asList((byte) 1, null)) + .put(BigEndianShortCoder.of(), asList((short) 1, null)) + .put(BigEndianIntegerCoder.of(), asList(1, 2, 3, null)) + .put(VarIntCoder.of(), asList(1, 2, 3, null)) + .put(BigEndianLongCoder.of(), asList(1L, 2L, 3L, null)) + .put(VarLongCoder.of(), asList(1L, 2L, 3L, null)) + .put(FloatCoder.of(), asList((float) 1.0, (float) 2.0, null)) + .put(DoubleCoder.of(), asList(1.0, 2.0, null)) + .put(StringUtf8Coder.of(), asList("1", "2", null)) + .put(BigDecimalCoder.of(), asList(bigDecimalOf(1L), bigDecimalOf(2L), null)) + .put(InstantCoder.of(), asList(Instant.ofEpochMilli(1), null)) + .build(); + + private Dataset createDataset(List data, Encoder encoder) { + Dataset ds = sessionRule.getSession().createDataset(data, encoder); + ds.printSchema(); + return ds; + } + + @Test + public void testBeamEncoderMappings() { + BASIC_CASES.forEach( + (coder, data) -> { + Encoder encoder = encoderFor(coder); + serializeAndDeserialize(data.get(0), (Encoder) encoder); + Dataset dataset = createDataset(data, (Encoder) encoder); + assertThat(dataset.collect(), equalTo(data.toArray())); + }); + } + + @Test + public void testBeamEncoderOfPrivateType() { + // Verify concrete types are not used in coder generation. + // In case of private types this would cause an IllegalAccessError. 
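+    // encoderFor falls back to binaryEncoder here, which generates code against Object rather
+    // than the private class (see the OBJECT_TYPE note in EncoderHelpers.binaryEncoder).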
+ List data = asList(new PrivateString("1"), new PrivateString("2")); + Dataset dataset = createDataset(data, encoderFor(PrivateString.CODER)); + assertThat(dataset.collect(), equalTo(data.toArray())); + } + + @Test + public void testBeamWindowedValueEncoderMappings() { + BASIC_CASES.forEach( + (coder, data) -> { + List> windowed = + Lists.transform(data, WindowedValues::valueInGlobalWindow); + + Encoder encoder = windowedValueEncoder(encoderFor(coder), windowEnc); + serializeAndDeserialize(windowed.get(0), (Encoder) encoder); + + Dataset dataset = createDataset(windowed, (Encoder) encoder); + assertThat(dataset.collect(), equalTo(windowed.toArray())); + }); + } + + @Test + public void testCollectionEncoder() { + BASIC_CASES.forEach( + (coder, data) -> { + Encoder> encoder = collectionEncoder(encoderFor(coder), true); + Collection collection = Collections.unmodifiableCollection(data); + + Dataset> dataset = createDataset(asList(collection), (Encoder) encoder); + assertThat(dataset.head(), equalTo(data)); + }); + } + + private void testMapEncoder(Class cls, Function, Map> decorator) { + BASIC_CASES.forEach( + (coder, data) -> { + Encoder enc = encoderFor(coder); + Encoder> mapEncoder = mapEncoder(enc, enc, (Class) cls); + Map map = + decorator.apply( + data.stream().filter(notNull()).collect(toMap(identity(), identity()))); + + Dataset> dataset = createDataset(asList(map), mapEncoder); + Map head = dataset.head(); + assertThat(head, equalTo(map)); + assertThat(head, instanceOf(cls)); + }); + } + + @Test + public void testMapEncoder() { + testMapEncoder(Map.class, identity()); + } + + @Test + public void testHashMapEncoder() { + testMapEncoder(HashMap.class, identity()); + } + + @Test + public void testTreeMapEncoder() { + testMapEncoder(TreeMap.class, TreeMap::new); + } + + @Test + public void testBeamBinaryEncoder() { + List> data = asList(asList("a1", "a2", "a3"), asList("b1", "b2"), asList("c1")); + + Encoder> encoder = encoderFor(ListCoder.of(StringUtf8Coder.of())); + serializeAndDeserialize(data.get(0), encoder); + + Dataset> dataset = createDataset(data, encoder); + assertThat(dataset.collect(), equalTo(data.toArray())); + } + + @Test + public void testEncoderForKVCoder() { + List> data = + asList(KV.of(1, "value1"), KV.of(null, "value2"), KV.of(3, null)); + + Encoder> encoder = + kvEncoder(encoderFor(VarIntCoder.of()), encoderFor(StringUtf8Coder.of())); + serializeAndDeserialize(data.get(0), encoder); + + Dataset> dataset = createDataset(data, encoder); + + StructType kvSchema = + createStructType( + new StructField[] { + createStructField("key", IntegerType, true), + createStructField("value", StringType, true) + }); + + assertThat(dataset.schema(), equalTo(kvSchema)); + assertThat(dataset.collectAsList(), equalTo(data)); + } + + @Test + public void testOneOffEncoder() { + List> coders = ImmutableList.copyOf(BASIC_CASES.keySet()); + List> encoders = coders.stream().map(EncoderHelpers::encoderFor).collect(toList()); + + // build oneOf tuples of type index and corresponding value + List> data = + BASIC_CASES.entrySet().stream() + .map(e -> tuple(coders.indexOf(e.getKey()), (Object) e.getValue().get(0))) + .collect(toList()); + + // dataset is a sparse dataset with only one column set per row + Dataset> dataset = createDataset(data, oneOfEncoder((List) encoders)); + assertThat(dataset.collectAsList(), equalTo(data)); + } + + // fix scale/precision to system default to compare using equals + private static BigDecimal bigDecimalOf(long l) { + DecimalType type = 
DecimalType.SYSTEM_DEFAULT();
+    return new BigDecimal(l, new MathContext(type.precision())).setScale(type.scale());
+  }
+
+  // test an explicit serialization roundtrip
+  @SuppressWarnings("unchecked")
+  private static void serializeAndDeserialize(T data, Encoder enc) {
+    ExpressionEncoder bound;
+    if (enc instanceof ExpressionEncoder) {
+      bound = (ExpressionEncoder) enc;
+    } else {
+      bound = ExpressionEncoder.apply((AgnosticEncoder) enc);
+    }
+    bound =
+        bound.resolveAndBind(bound.resolveAndBind$default$1(), bound.resolveAndBind$default$2());
+
+    InternalRow row = bound.createSerializer().apply(data);
+    T deserialized = bound.createDeserializer().apply(row);
+
+    assertThat(deserialized, equalTo(data));
+  }
+
+  private static class PrivateString {
+    private static final Coder CODER =
+        DelegateCoder.of(
+            StringUtf8Coder.of(),
+            str -> str.string,
+            PrivateString::new,
+            new TypeDescriptor() {});
+
+    private final String string;
+
+    public PrivateString(String string) {
+      this.string = string;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (!(o instanceof PrivateString)) {
+        return false;
+      }
+      PrivateString that = (PrivateString) o;
+      return Objects.equals(string, that.string);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(string);
+    }
+  }
+}
diff --git a/runners/spark/spark_runner.gradle b/runners/spark/spark_runner.gradle
index 0e77821e533e..273802b82391 100644
--- a/runners/spark/spark_runner.gradle
+++ b/runners/spark/spark_runner.gradle
@@ -256,12 +256,23 @@ dependencies {
     implementation "org.apache.spark:spark-common-utils_$spark_scala_version:$spark_version"
     implementation "org.apache.spark:spark-sql-api_$spark_scala_version:$spark_version"
   }
+  if (isSparkAtLeast("4.0.0")) {
+    // Spark 4 splits the Connect shims out of spark-sql; classes are referenced directly
+    // (e.g. via SparkSession builder paths) so strict dependency analysis requires it as
+    // a declared provided dep. The artifact does not exist for Spark 3.
+    provided "org.apache.spark:spark-connect-shims_$spark_scala_version:$spark_version"
+  }
  permitUnusedDeclared "org.apache.spark:spark-network-common_$spark_scala_version:$spark_version"
  implementation "io.dropwizard.metrics:metrics-core:4.1.1" // version used by Spark 3.1
-  compileOnly "org.scala-lang:scala-library:2.12.15"
-  runtimeOnly library.java.jackson_module_scala_2_12
-  // Force paranamer 2.8 to avoid issues when using Scala 2.12
-  runtimeOnly "com.thoughtworks.paranamer:paranamer:2.8"
+  if (spark_scala_version == '2.13') {
+    compileOnly "org.scala-lang:scala-library:2.13.15"
+    runtimeOnly library.java.jackson_module_scala_2_13
+  } else {
+    compileOnly "org.scala-lang:scala-library:2.12.15"
+    runtimeOnly library.java.jackson_module_scala_2_12
+    // Force paranamer 2.8 to avoid issues when using Scala 2.12
+    runtimeOnly "com.thoughtworks.paranamer:paranamer:2.8"
+  }
  provided "org.apache.hadoop:hadoop-client-api:3.3.1"
  provided library.java.commons_io
  provided library.java.hamcrest
@@ -276,7 +287,9 @@ dependencies {
  testImplementation project(path: ":sdks:java:extensions:avro", configuration: "testRuntimeMigration")
  testImplementation project(":sdks:java:harness")
  testImplementation library.java.avro
-  testImplementation "org.apache.kafka:kafka_$spark_scala_version:2.4.1"
+  // kafka_2.13 artifacts were first published in 2.5.0; use a later version for Scala 2.13
+  def kafka_version = (spark_scala_version == '2.13') ?
'2.8.0' : '2.4.1' + testImplementation "org.apache.kafka:kafka_$spark_scala_version:$kafka_version" testImplementation library.java.kafka_clients testImplementation library.java.junit testImplementation library.java.mockito_core diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java index 68c602ff7f59..75c33b7dc5b5 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java @@ -28,9 +28,10 @@ import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBasedTable; import org.apache.spark.serializer.KryoRegistrator; -import scala.collection.mutable.WrappedArray; +import org.checkerframework.checker.nullness.qual.Nullable; /** * Custom {@link KryoRegistrator}s for Beam's Spark runner needs and registering used class in spark @@ -61,7 +62,18 @@ public void registerClasses(Kryo kryo) { kryo.register(PaneInfo.class); kryo.register(StateAndTimers.class); kryo.register(TupleTag.class); - kryo.register(WrappedArray.ofRef.class); + // Scala 2.12 uses WrappedArray$ofRef, Scala 2.13 renamed it to ArraySeq$ofRef + Class scalaArrayClass = + findFirstAvailableClass( + "scala.collection.mutable.ArraySeq$ofRef", + "scala.collection.mutable.WrappedArray$ofRef"); + if (scalaArrayClass == null) { + throw new IllegalStateException( + "Neither scala.collection.mutable.ArraySeq$ofRef (Scala 2.13) nor " + + "scala.collection.mutable.WrappedArray$ofRef (Scala 2.12) was found on the " + + "classpath. Cannot register Scala wrapped arrays with Kryo."); + } + kryo.register(scalaArrayClass); try { kryo.register( @@ -74,4 +86,16 @@ public void registerClasses(Kryo kryo) { throw new IllegalStateException("Unable to register classes with kryo.", e); } } + + @VisibleForTesting + static @Nullable Class findFirstAvailableClass(String... classNames) { + for (String name : classNames) { + try { + return Class.forName(name); + } catch (ClassNotFoundException ignored) { + // try the next candidate + } + } + return null; + } } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java index 1c7a4c2a2416..be69ee78e51c 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java @@ -522,7 +522,9 @@ JavaDStream>>> groupByKeyAndWindow( Tuple2>>*/ List>>> firedStream = pairDStream.updateStateByKey( - updateFunc, + // Raw cast to AbstractFunction1 suppresses Scala 2.12 (collection.Seq) vs + // Scala 2.13 (immutable.Seq) type difference — safe at runtime due to erasure. 
+ (scala.runtime.AbstractFunction1) updateFunc, pairDStream.defaultPartitioner(pairDStream.defaultPartitioner$default$1()), true, JavaSparkContext$.MODULE$.fakeClassTag()); diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java index 806d838d9bff..9d3419e19473 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java @@ -113,6 +113,7 @@ public MetricResults metrics() { @Override public PipelineResult.State cancel() throws IOException { + pipelineExecution.cancel(true); offerNewState(PipelineResult.State.CANCELLED); return state; } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java index c0d46e77c1d6..d7cdefc929b7 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java @@ -241,7 +241,7 @@ static List> partitionsOf(BoundedSource source, Params try { PipelineOptions options = params.options.get(); long desiredSize = source.getEstimatedSizeBytes(options) / params.numPartitions; - List> split = (List>) source.split(desiredSize, options); + List> split = source.split(desiredSize, options); IntSupplier idxSupplier = new AtomicInteger(0)::getAndIncrement; return split.stream().map(s -> new SourcePartition<>(s, idxSupplier)).collect(toList()); } catch (Exception e) { diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContext.java index 55c4bbaedd3c..b8448567eafc 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContext.java @@ -86,7 +86,10 @@ public static void evaluate(String name, Dataset ds) { ds.write().mode("overwrite").format("noop").save(); LOG.info("Evaluated dataset {} in {}", name, durationSince(startMs)); } catch (RuntimeException e) { - LOG.error("Failed to evaluate dataset {}: {}", name, Throwables.getRootCause(e).getMessage()); + LOG.error( + "Failed to evaluate dataset {}: {}", + name, + String.valueOf(Throwables.getRootCause(e).getMessage())); throw new RuntimeException(e); } } @@ -102,7 +105,10 @@ public static void evaluate(String name, Dataset ds) { LOG.info("Collected dataset {} in {} [size: {}]", name, durationSince(startMs), res.length); return res; } catch (Exception e) { - LOG.error("Failed to collect dataset {}: {}", name, Throwables.getRootCause(e).getMessage()); + LOG.error( + "Failed to collect dataset {}: {}", + name, + String.valueOf(Throwables.getRootCause(e).getMessage())); throw new RuntimeException(e); } } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java index 02c56a8081cf..513ef28a5897 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java @@ -64,7 +64,7 @@ * TODOs: *
• combine with context (CombineFnWithContext)? *
• combine with sideInputs? - *
• other there other missing features? + *
  • are there other missing features? */ class CombinePerKeyTranslatorBatch extends TransformTranslator< diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java index b7c139068d1b..2105fd05d493 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java @@ -21,7 +21,6 @@ import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; import static org.apache.beam.sdk.transforms.windowing.TimestampCombiner.END_OF_WINDOW; -import java.util.Collection; import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop; import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; @@ -80,8 +79,7 @@ static boolean eligibleForGlobalGroupBy( return v -> { T value = valueFn.apply(v); K key = v.getValue().getKey(); - Collection windows = (Collection) v.getWindows(); - return ScalaInterop.scalaIterator(windows).map(w -> tuple(tuple(w, key), value)); + return ScalaInterop.scalaIterator(v.getWindows()).map(w -> tuple(tuple(w, key), value)); }; } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java index 46cda3334822..c55781ff84a8 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java @@ -81,7 +81,7 @@ import scala.collection.immutable.List; /** - * Translator for {@link GroupByKey} using {@link Dataset#groupByKey} with the build-in aggregation + * Translator for {@link GroupByKey} using {@link Dataset#groupByKey} with the built-in aggregation * function {@code collect_list} when applicable. * *

    Note: Using {@code collect_list} isn't any worse than using {@link ReduceFnRunner}. In the diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java index daf8451faac5..29f01c84c02e 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java @@ -48,7 +48,6 @@ import org.apache.beam.sdk.values.WindowedValues; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; @@ -98,7 +97,7 @@ public class EncoderHelpers { private static final DataType LIST_TYPE = new ObjectType(List.class); // Collections / maps of these types can be (de)serialized without (de)serializing each member - private static final Set> PRIMITIV_TYPES = + private static final Set> PRIMITIVE_TYPES = ImmutableSet.of( Boolean.class, Byte.class, @@ -154,7 +153,7 @@ public class EncoderHelpers { public static Encoder encoderOf(Class cls) { Encoder enc = getOrCreateDefaultEncoder(cls); if (enc == null) { - throw new IllegalArgumentException("No default coder available for class " + cls); + throw new IllegalArgumentException("No default encoder available for class " + cls); } return enc; } @@ -481,7 +480,7 @@ private static Expression deserializeSeq( } private static boolean isPrimitiveEnc(Encoder enc) { - return PRIMITIV_TYPES.contains(enc.clsTag().runtimeClass()); + return PRIMITIVE_TYPES.contains(enc.clsTag().runtimeClass()); } private static Expression serialize(Expression input, Encoder enc) { @@ -548,9 +547,17 @@ public static byte[] paneInfoToBytes(PaneInfo paneInfo) { return CoderHelpers.toByteArray(paneInfo, PaneInfoCoder.of()); } - /** The end of the only window (max timestamp). */ - public static Instant maxTimestamp(Iterable windows) { - return Iterables.getOnlyElement(windows).maxTimestamp(); + /** The maximum {@code maxTimestamp} across all associated windows. 
*/ + public static Instant maxTimestamp(Iterable windows) { + Instant maxTimestamp = null; + for (BoundedWindow window : windows) { + Instant timestamp = window.maxTimestamp(); + if (maxTimestamp == null || timestamp.isAfter(maxTimestamp)) { + maxTimestamp = timestamp; + } + } + return Preconditions.checkNotNull( + maxTimestamp, "WindowedValue must have at least one window"); } public static List copyToList(ArrayData arrayData, DataType type) { diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistratorTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistratorTest.java index ddd0e74d1c9e..56eb0dfea5f2 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistratorTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistratorTest.java @@ -17,7 +17,10 @@ */ package org.apache.beam.runners.spark.coders; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import com.esotericsoftware.kryo.Kryo; @@ -73,6 +76,52 @@ public void testDefaultSerializerNotCallingKryo() { } } + /** Unit tests for the {@link SparkRunnerKryoRegistrator#findFirstAvailableClass} helper. */ + public static class FindFirstAvailableClassTest { + + @Test + public void returnsFirstWhenAvailable() { + Class result = + SparkRunnerKryoRegistrator.findFirstAvailableClass( + "java.lang.String", "java.lang.Integer"); + assertSame(String.class, result); + } + + @Test + public void fallsBackWhenFirstMissing() { + Class result = + SparkRunnerKryoRegistrator.findFirstAvailableClass("does.not.Exist", "java.lang.Integer"); + assertSame(Integer.class, result); + } + + @Test + public void returnsNullWhenNoneAvailable() { + Class result = + SparkRunnerKryoRegistrator.findFirstAvailableClass("does.not.Exist1", "does.not.Exist2"); + assertNull(result); + } + + @Test + public void returnsNullForEmptyInput() { + assertNull(SparkRunnerKryoRegistrator.findFirstAvailableClass()); + } + + @Test + public void resolvesScalaWrappedArrayClassOnRealClasspath() { + // On any supported Scala version (2.12 ArraySeq$ofRef does not exist; 2.13 it does), at + // least one of the two wrapped-array class names must resolve. This is the production call + // the registrator makes. + Class result = + SparkRunnerKryoRegistrator.findFirstAvailableClass( + "scala.collection.mutable.ArraySeq$ofRef", + "scala.collection.mutable.WrappedArray$ofRef"); + assertEquals( + "expected one of the Scala wrapped-array classes to be on the classpath", + true, + result != null); + } + } + // Hide TestKryoRegistrator from the Enclosed JUnit runner interface Others { class TestKryoRegistrator extends SparkRunnerKryoRegistrator { diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContextTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContextTest.java new file mode 100644 index 000000000000..d1669e45c182 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContextTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation; + +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +import org.apache.spark.sql.Dataset; +import org.junit.Test; + +/** + * Unit tests for the static error-path branches of {@link EvaluationContext}. The happy-path + * branches are covered end-to-end by the structured-streaming translation tests; these tests + * specifically exercise the {@code catch} blocks that wrap and rethrow underlying Spark failures. + */ +public class EvaluationContextTest { + + @Test + public void evaluateWrapsAndRethrowsRuntimeException() { + @SuppressWarnings("unchecked") + Dataset ds = mock(Dataset.class); + RuntimeException underlying = new RuntimeException("boom"); + doThrow(underlying).when(ds).write(); + + RuntimeException thrown = + assertThrows(RuntimeException.class, () -> EvaluationContext.evaluate("test-ds", ds)); + assertSame(underlying, thrown.getCause()); + } + + @Test + public void evaluateHandlesNullExceptionMessage() { + // Reproduces the original NPE motivation for the String.valueOf wrap: a RuntimeException + // whose root cause carries a null message must not crash the error logger. 
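The guarded call reduces to plain JDK behavior; a minimal sketch follows, assuming nothing beyond Throwable#getMessage and String.valueOf semantics:

public class NullMessageSketch {
  public static void main(String[] args) {
    Throwable rootCause = new RuntimeException((String) null);
    // getMessage() is null here; String.valueOf turns it into the literal "null",
    // so the log-formatting path never dereferences a null message.
    System.out.println(String.valueOf(rootCause.getMessage())); // prints: null
  }
}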
+ @SuppressWarnings("unchecked") + Dataset ds = mock(Dataset.class); + RuntimeException underlying = new RuntimeException((String) null); + doThrow(underlying).when(ds).write(); + + RuntimeException thrown = + assertThrows(RuntimeException.class, () -> EvaluationContext.evaluate("test-ds", ds)); + assertSame(underlying, thrown.getCause()); + } + + @Test + public void collectWrapsAndRethrowsException() { + @SuppressWarnings("unchecked") + Dataset ds = mock(Dataset.class); + RuntimeException underlying = new RuntimeException("boom"); + doThrow(underlying).when(ds).collect(); + + RuntimeException thrown = + assertThrows(RuntimeException.class, () -> EvaluationContext.collect("test-ds", ds)); + assertSame(underlying, thrown.getCause()); + } + + @Test + public void collectHandlesNullExceptionMessage() { + @SuppressWarnings("unchecked") + Dataset ds = mock(Dataset.class); + RuntimeException underlying = new RuntimeException((String) null); + doThrow(underlying).when(ds).collect(); + + RuntimeException thrown = + assertThrows(RuntimeException.class, () -> EvaluationContext.collect("test-ds", ds)); + assertSame(underlying, thrown.getCause()); + } +} diff --git a/settings.gradle.kts b/settings.gradle.kts index 4080206bb542..131e601a16c1 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -150,6 +150,8 @@ include(":runners:prism:java") include(":runners:spark:3") include(":runners:spark:3:job-server") include(":runners:spark:3:job-server:container") +include(":runners:spark:4") +include(":runners:spark:4:job-server") include(":sdks:go") include(":sdks:go:container") include(":sdks:go:examples")
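Stepping back to the SparkRunnerKryoRegistrator change above: the sketch below shows the classpath-probing pattern it introduces, with the Kryo registration itself omitted. The probe logic and the two Scala class names are copied from the hunks; the class and method names here are illustrative, not part of the patch.

// Sketch: resolve the first class name present on the classpath, as the
// registrator does for Scala 2.13's ArraySeq$ofRef vs Scala 2.12's WrappedArray$ofRef.
public class ClassProbeSketch {
  static Class<?> findFirstAvailableClass(String... classNames) {
    for (String name : classNames) {
      try {
        return Class.forName(name);
      } catch (ClassNotFoundException ignored) {
        // try the next candidate
      }
    }
    return null;
  }

  public static void main(String[] args) {
    Class<?> scalaArrayClass =
        findFirstAvailableClass(
            "scala.collection.mutable.ArraySeq$ofRef",      // Scala 2.13
            "scala.collection.mutable.WrappedArray$ofRef"); // Scala 2.12
    // On a plain JVM without Scala, both lookups fail and null is returned,
    // which is the condition the registrator turns into an IllegalStateException.
    System.out.println(scalaArrayClass);
  }
}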