diff --git a/udf/worker/README.md b/udf/worker/README.md index fa27430b62b62..b843c430d0e04 100644 --- a/udf/worker/README.md +++ b/udf/worker/README.md @@ -5,44 +5,162 @@ Package structure for the UDF worker framework described in ## Overview -Spark processes a UDF by first obtaining a **WorkerDispatcher** from the worker -specification (plus context such as security scope). The dispatcher manages the -actual worker processes behind the scenes -- pooling, reuse, and termination are -all invisible to Spark. +Spark processes a UDF by obtaining a **WorkerDispatcher** from a worker +specification. The dispatcher manages workers behind the scenes. From +the dispatcher, Spark gets a **WorkerSession** -- one per UDF invocation -- +with an Iterator-to-Iterator `process` API that streams input batches +through the worker and returns result batches. -From the dispatcher, Spark gets a **WorkerSession**, which represents one single -UDF execution and can carry per-execution state. A WorkerSession is not 1-to-1 -mapped to an actual worker -- multiple sessions may share the same underlying -worker when it is reused. Worker reuse is managed by each dispatcher -implementation based on the worker specification. +``` +UDFWorkerSpecification -- how to create and configure workers + | + v +WorkerDispatcher -- manages workers, creates sessions + | + v +WorkerSession -- one UDF execution + | 1. session.init(InitMessage(payload, inputSchema, outputSchema)) + | 2. val results = session.process(inputBatches) + | 3. session.close() +``` + +How workers are created depends on the dispatcher implementation. The +framework currently provides **direct worker creation** (local OS +processes) and is designed for future **indirect creation** (via a +provisioning service or daemon). ## Sub-packages ``` udf/worker/ -├── proto/ Protobuf definition of the worker specification -│ (UDFWorkerSpecification). -│ WorkerSpecification -- typed Scala wrapper around the protobuf spec. 
-└── core/ Engine-side APIs (all @Experimental): - WorkerDispatcher -- manages workers for one spec; creates sessions. - WorkerSession -- represents one single UDF execution. - WorkerSecurityScope -- security boundary for connection pooling. +├── proto/ +│ worker_spec.proto -- UDFWorkerSpecification protobuf (+ generated Java classes) +│ common.proto -- shared enums (UDFWorkerDataFormat, etc.) +│ +└── core/ -- abstract interfaces + WorkerDispatcher.scala -- creates sessions, manages worker lifecycle + WorkerSession.scala -- per-UDF init/process/cancel/close + InitMessage + WorkerConnection.scala -- transport channel abstraction + WorkerSecurityScope.scala -- security boundary for worker pooling + │ + └── direct/ -- "direct" creation: local OS processes + DirectWorkerDispatcher.scala -- spawns processes, env lifecycle + DirectWorkerProcess.scala -- OS process + connection + UDS socket + DirectWorkerSession.scala -- session backed by a direct process +``` + +The `core/` package defines abstract interfaces that are independent of how +workers are created. The `core/direct/` sub-package implements "direct" +worker creation where Spark spawns local OS processes. Future packages +(e.g., `core/indirect/`) can implement alternative creation modes such as +obtaining workers from a provisioning service or daemon. + +### Direct worker creation + +`DirectWorkerDispatcher` spawns worker processes locally. On the first +session, it runs the optional environment lifecycle callables from the +`UDFWorkerSpecification`: + +- **`environmentVerification`** -- checks if the environment is ready + (exit 0 = ready). When it succeeds, installation is skipped. +- **`installation`** -- prepares the environment (installs runtime, + dependencies, worker binaries). Only runs when verification is absent + or fails. +- **`environmentCleanup`** -- runs after the dispatcher is closed or on + JVM shutdown to clean up temporary resources. 
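The verify-then-install decision described above can be sketched as follows. This is a minimal model with hypothetical helper names (`ensureEnvironment`, `runCallable` stand-ins as plain functions returning exit codes); the real dispatcher executes each `ProcessCallable`'s command as an OS process and inspects its exit status:

```scala
// Sketch of the environment lifecycle decision (hypothetical names).
// A callable is modeled as a function returning its process exit code;
// 0 means success, anything else means failure.
def ensureEnvironment(
    verification: Option[() => Int],   // environmentVerification: 0 = ready
    installation: Option[() => Int]): Unit = {
  // Environment is ready only if verification is present and exits 0.
  val ready = verification.exists(v => v() == 0)
  if (!ready) {
    // Installation only runs when verification is absent or failed.
    installation.foreach { install =>
      if (install() != 0) {
        throw new RuntimeException("environment installation failed")
      }
    }
  }
}
```

Note that a missing `environmentVerification` means installation always runs, so providing a cheap verification command is how a spec opts into skipping repeated installs.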
+ +Environment setup runs **once per dispatcher** (not per session). +Workers are terminated via SIGTERM/SIGKILL when the dispatcher is closed. + +## Basic usage (Scala) + +```scala +import org.apache.spark.udf.worker.{ + DirectWorker, ProcessCallable, UDFProtoCommunicationPattern, + UDFWorkerDataFormat, UDFWorkerProperties, UDFWorkerSpecification, + UnixDomainSocket, WorkerCapabilities, WorkerConnectionSpec, WorkerEnvironment} +import org.apache.spark.udf.worker.core._ + +// 1. Define a worker spec (direct creation mode). +val spec = UDFWorkerSpecification.newBuilder() + .setEnvironment(WorkerEnvironment.newBuilder() + .setEnvironmentVerification(ProcessCallable.newBuilder() + .addCommand("python").addCommand("-c").addCommand("import my_udf_worker").build()) + .setInstallation(ProcessCallable.newBuilder() + .addCommand("pip").addCommand("install").addCommand("my_udf_worker").build()) + .build()) + .setCapabilities(WorkerCapabilities.newBuilder() + .addSupportedDataFormats(UDFWorkerDataFormat.ARROW) + .addSupportedCommunicationPatterns( + UDFProtoCommunicationPattern.BIDIRECTIONAL_STREAMING) + .build()) + .setDirect(DirectWorker.newBuilder() + .setRunner(ProcessCallable.newBuilder() + .addCommand("python").addCommand("-m").addCommand("my_udf_worker").build()) + .setProperties(UDFWorkerProperties.newBuilder() + .setConnection(WorkerConnectionSpec.newBuilder() + .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) + .build()) + .build()) + .build() + +// 2. Create a dispatcher. Use a protocol-specific subclass of +// DirectWorkerDispatcher (e.g., gRPC over UDS). +val dispatcher: WorkerDispatcher = ... + +// 3. Create a session for one UDF execution. +val session = dispatcher.createSession(securityScope = None) +try { + // 4. Initialize with the serialized function and schemas. + session.init(InitMessage( + functionPayload = serializedFunction, + inputSchema = arrowInputSchema, + outputSchema = arrowOutputSchema)) + + // 5. 
Process data -- Iterator in, Iterator out. + val results: Iterator[Array[Byte]] = + session.process(inputBatches) + + // Consume results lazily. + results.foreach(processResultBatch) +} finally { + session.close() +} + +// 6. Shut down all workers. +dispatcher.close() ``` ## Build SBT: ``` -build/sbt "udf-worker-core/compile" -build/sbt "udf-worker-core/test" +build/sbt "udf-worker-proto/compile" "udf-worker-core/compile" ``` Maven: ``` -./build/mvn -pl udf/worker/proto,udf/worker/core -am compile -./build/mvn -pl udf/worker/proto,udf/worker/core -am test +build/mvn compile -pl udf/worker/proto,udf/worker/core -am ``` +## Test + +SBT: +``` +build/sbt "udf-worker-core/test" +``` + +## Current status + +This is the **first MVP** providing the core abstraction layer and the +direct worker dispatcher. +The following are left as TODOs: + +- **Connection pooling** -- reuse workers across sessions +- **Security scope isolation** -- partition pools by `WorkerSecurityScope` +- **Indirect worker creation** -- obtain workers from a service or daemon +- **Protocol-specific implementations** -- e.g., gRPC over UDS + ## Design references * [SPIP Language-agnostic UDF Protocol for Spark](https://docs.google.com/document/d/19Whzq127QxVt2Luk0EClgaDtcpBsFUp67NcVdKKyPF8/edit?tab=t.0) diff --git a/udf/worker/proto/src/main/scala/org/apache/spark/udf/worker/WorkerSpecification.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UnixSocketWorkerConnection.scala similarity index 57% rename from udf/worker/proto/src/main/scala/org/apache/spark/udf/worker/WorkerSpecification.scala rename to udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UnixSocketWorkerConnection.scala index e25b99b69990c..b3b40d16e7443 100644 --- a/udf/worker/proto/src/main/scala/org/apache/spark/udf/worker/WorkerSpecification.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UnixSocketWorkerConnection.scala @@ -14,14 +14,28 @@ * See the License for the 
specific language governing permissions and * limitations under the License. */ -package org.apache.spark.udf.worker +package org.apache.spark.udf.worker.core + +import java.io.File import org.apache.spark.annotation.Experimental /** * :: Experimental :: - * Typed Scala wrapper around the protobuf [[UDFWorkerSpecification]]. + * A [[WorkerConnection]] over a Unix domain socket. Owns the socket + * path and removes the socket file on [[close]]. Subclasses provide the + * protocol-specific channel (e.g. gRPC over UDS) and may override + * [[close]] to add transport-level shutdown -- they should call + * `super.close()` to ensure the socket file is removed. + * + * [[close]] is idempotent: deleting an already-removed file is a no-op. */ @Experimental -class WorkerSpecification(val proto: UDFWorkerSpecification) { +abstract class UnixSocketWorkerConnection(val socketPath: String) + extends WorkerConnection { + + override def close(): Unit = { + val f = new File(socketPath) + if (f.exists()) f.delete() + } } diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerConnection.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerConnection.scala new file mode 100644 index 0000000000000..82b2fff8df585 --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerConnection.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * A transport-level connection to a running UDF worker process. + * + * A [[WorkerConnection]] represents the communication channel between the + * Spark engine and a single worker process (e.g., a gRPC channel over a + * Unix domain socket, or a raw TCP socket). It is owned by a worker + * process wrapper (e.g., [[direct.DirectWorkerProcess]]) and shared + * across all [[WorkerSession]]s that use that process. + * + * One connection, many sessions: the worker exposes a single server-side + * endpoint that all sessions share. For gRPC, per-session work lives on + * multiplexed streams over this channel. + * + * Implementations expose only lifecycle. Data transmission happens at + * the [[WorkerSession]] level -- this class is solely about whether the + * channel is open. + * + * '''Relationship to other classes (direct creation mode):''' + * {{{ + * DirectWorkerProcess 1 --- 1 WorkerConnection (transport over UDS) + * DirectWorkerProcess 1 --- * WorkerSession (UDF executions) + * }}} + */ +@Experimental +abstract class WorkerConnection extends AutoCloseable { + /** Returns true if the underlying transport channel is still usable. 
*/ + def isActive: Boolean +} diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerDispatcher.scala index 58fabbaea00df..008cfc2993a09 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerDispatcher.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerDispatcher.scala @@ -17,11 +17,11 @@ package org.apache.spark.udf.worker.core import org.apache.spark.annotation.Experimental -import org.apache.spark.udf.worker.WorkerSpecification +import org.apache.spark.udf.worker.UDFWorkerSpecification /** * :: Experimental :: - * Manages workers for a single [[WorkerSpecification]] and hides worker details from Spark. + * Manages workers for a single [[UDFWorkerSpecification]] and hides worker details from Spark. * * A [[WorkerDispatcher]] is created from a worker specification (plus context such * as security scope). It owns the underlying worker processes and connections, @@ -31,7 +31,7 @@ import org.apache.spark.udf.worker.WorkerSpecification @Experimental trait WorkerDispatcher extends AutoCloseable { - def workerSpec: WorkerSpecification + def workerSpec: UDFWorkerSpecification /** * Creates a [[WorkerSession]] that maps to one single UDF execution. diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala new file mode 100644 index 0000000000000..a8f135f688908 --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Minimal logging surface used by the udf/worker framework. + * + * The framework deliberately does not depend on SLF4J (or any other + * concrete logging backend) so callers can embed it without dragging a + * specific logger onto the classpath. Embedders should supply an + * adapter that forwards to their preferred backend (Spark's `Logging` + * trait, SLF4J, java.util.logging, etc.). + * + * Only the methods actually used by the framework are exposed. + * Messages are passed by-name so the formatting cost is avoided when + * the backend decides to drop the event. + */ +@Experimental +trait WorkerLogger { + def warn(msg: => String): Unit + def warn(msg: => String, t: Throwable): Unit + def debug(msg: => String): Unit + def debug(msg: => String, t: Throwable): Unit +} + +object WorkerLogger { + /** Discards all messages. Default for callers that don't wire up logging. 
*/ + val NoOp: WorkerLogger = new WorkerLogger { + override def warn(msg: => String): Unit = () + override def warn(msg: => String, t: Throwable): Unit = () + override def debug(msg: => String): Unit = () + override def debug(msg: => String, t: Throwable): Unit = () + } +} diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala index 83c392a895b66..f4c4091688c94 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala @@ -16,22 +16,133 @@ */ package org.apache.spark.udf.worker.core +import java.util.concurrent.atomic.AtomicBoolean + import org.apache.spark.annotation.Experimental /** * :: Experimental :: - * Represents one single UDF execution. + * Carries all information needed to initialize a UDF execution on a worker. + * + * This message is passed to [[WorkerSession#init]] and contains the function + * definition, schemas, and any additional configuration. * - * A [[WorkerSession]] is obtained from [[WorkerDispatcher#createSession]] and - * can carry per-execution state for that UDF invocation. Implementations may - * add concrete data-processing methods and lifecycle hooks as needed. + * Placeholder: will be replaced by a generated proto message once the + * UDF wire protocol lands. Do not rely on case-class equality -- + * `Array[Byte]` fields compare by reference. * - * A WorkerSession is not 1-to-1 mapped to an actual worker process. Multiple - * WorkerSessions may be backed by the same worker when the worker is reused. - * Worker reuse and pooling are managed by each [[WorkerDispatcher]] - * implementation based on the [[WorkerSpecification]]. 
+ * @param functionPayload serialized function (e.g., pickled Python, JVM bytes) + * @param inputSchema serialized input schema (e.g., Arrow schema bytes) + * @param outputSchema serialized output schema (e.g., Arrow schema bytes) + * @param properties additional key-value configuration. Can carry + * protocol-specific or engine-specific metadata that + * does not yet have a dedicated field. + */ +@Experimental +case class InitMessage( + functionPayload: Array[Byte], + inputSchema: Array[Byte], + outputSchema: Array[Byte], + properties: Map[String, String] = Map.empty) + +/** + * :: Experimental :: + * One UDF execution on a worker -- the main interface Spark uses to run UDFs. + * + * A [[WorkerSession]] is the '''per-UDF-invocation''' handle that Spark + * obtains from [[WorkerDispatcher#createSession]]. It carries the full + * init / data-stream / finish lifecycle for a single UDF evaluation. + * + * A [[WorkerSession]] does ''not'' own the underlying worker or its + * transport channel -- those are managed by the [[WorkerDispatcher]]. + * Multiple sessions may share the same worker when the worker supports + * concurrency. + * + * '''Usage:''' + * {{{ + * val session = dispatcher.createSession(securityScope = None) + * try { + * session.init(InitMessage(functionPayload, inputSchema, outputSchema)) + * val results = session.process(inputBatches) + * results.foreach(handleBatch) + * } finally { + * session.close() + * } + * }}} + * + * '''Lifecycle:''' + * - [[init]] must be called exactly once before [[process]]. + * - [[process]] must be called at most once per session. + * - [[close]] must always be called (use try-finally). + * - [[cancel]] may be called at any time to abort execution. + * + * The lifecycle is enforced here: [[init]] and [[process]] are `final` + * and delegate to [[doInit]] / [[doProcess]] after AtomicBoolean guards. + * Subclasses implement the protocol-specific work and do not re-check + * the contract. 
*/ @Experimental abstract class WorkerSession extends AutoCloseable { - override def close(): Unit = {} + + private val initialized = new AtomicBoolean(false) + private val processed = new AtomicBoolean(false) + + /** + * Initializes the UDF execution. Must be called exactly once before + * [[process]]. + * + * Throws `IllegalStateException` if called more than once. + * + * @param message the initialization parameters including the serialized + * function, input/output schemas, and configuration. + */ + final def init(message: InitMessage): Unit = { + if (!initialized.compareAndSet(false, true)) { + throw new IllegalStateException("init has already been called on this session") + } + doInit(message) + } + + /** + * Processes input data through the worker and returns results. + * + * Follows Spark's Iterator-to-Iterator pattern: input batches are streamed + * to the worker, and result batches are lazily pulled from the returned + * iterator. The session sends a Finish signal to the worker when the input + * iterator is exhausted. + * + * Must be called after [[init]] and at most once per session. + * Throws `IllegalStateException` if called before [[init]] or more than once. + * + * @param input iterator of raw input data batches (e.g., Arrow IPC) + * @return iterator of raw result data batches + */ + final def process(input: Iterator[Array[Byte]]): Iterator[Array[Byte]] = { + if (!initialized.get()) { + throw new IllegalStateException("process called before init") + } + if (!processed.compareAndSet(false, true)) { + throw new IllegalStateException("process has already been called on this session") + } + doProcess(input) + } + + /** Subclass hook for [[init]]. Called once, after the guard. */ + protected def doInit(message: InitMessage): Unit + + /** Subclass hook for [[process]]. Called at most once, after the guard. */ + protected def doProcess(input: Iterator[Array[Byte]]): Iterator[Array[Byte]] + + /** + * Requests cancellation of the current UDF execution. 
+ * + * '''Thread-safety:''' implementations must allow [[cancel]] to be called + * from a thread different from the one driving [[process]] (typically a + * task interruption thread). It may be invoked at any point after + * [[init]] and should be a no-op if execution has already finished. + */ + def cancel(): Unit + + /** Closes this session and releases resources. */ + override def close(): Unit } diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala new file mode 100644 index 0000000000000..8da0354187e4f --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.udf.worker.core.direct + +import java.io.File +import java.nio.file.{Files, Path} +import java.nio.file.attribute.PosixFilePermissions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.udf.worker.UDFWorkerSpecification +import org.apache.spark.udf.worker.core.{UnixSocketWorkerConnection, WorkerLogger} +import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.SOCKET_POLL_INTERVAL_MS + +/** + * :: Experimental :: + * A [[DirectWorkerDispatcher]] using Unix domain sockets as the worker + * transport. Allocates a private 0700 socket directory at construction; + * each worker is given a UDS path inside it. + * + * Concrete subclasses implement [[createConnection]] (with a UDS protocol + * of choice) and [[createSessionForWorker]]. + */ +@Experimental +abstract class DirectUnixSocketWorkerDispatcher( + workerSpec: UDFWorkerSpecification, + logger: WorkerLogger = WorkerLogger.NoOp) + extends DirectWorkerDispatcher(workerSpec, logger) { + + // Removed explicitly in closeTransport(). deleteOnExit is avoided because + // the JDK retains the path for the JVM lifetime, which leaks in + // long-lived drivers. + private val socketDir: Path = createPrivateTempDirectory() + + override protected def newEndpointAddress(workerId: String): String = + socketDir.resolve(s"worker-$workerId.sock").toString + + override protected def waitForReady( + address: String, + process: Process, + outputFile: File): Unit = { + val file = new File(address) + // At least one poll so very small initTimeouts don't trip a premature + // timeout before the worker has any chance to create the socket. 
+ val maxAttempts = math.max(1, (initTimeoutMs / SOCKET_POLL_INTERVAL_MS).toInt) + var attempts = 0 + while (!file.exists() && attempts < maxAttempts) { + if (!process.isAlive) throwWorkerExitedBeforeSocket(process, address, outputFile) + Thread.sleep(SOCKET_POLL_INTERVAL_MS) + attempts += 1 + } + if (!file.exists()) { + if (process.isAlive) { + DirectWorkerDispatcher.destroyForciblyAndReap( + process, logger, s"init timeout $address") + val tail = readOutputTail(outputFile) + throw new DirectWorkerTimeoutException( + s"Worker did not create socket at $address within ${initTimeoutMs}ms\n$tail") + } else { + // Worker exited after the last poll without creating the socket; + // prefer the exit-code message over the ambiguous "did not create". + throwWorkerExitedBeforeSocket(process, address, outputFile) + } + } + } + + override protected def cleanupEndpointAddress(address: String): Unit = { + Files.deleteIfExists(new File(address).toPath) + } + + override protected def closeTransport(): Unit = { + val dir = socketDir.toFile + if (dir.exists()) { + val remaining = dir.listFiles() + if (remaining != null) remaining.foreach(_.delete()) + dir.delete() + } + } + + override protected def validateTransportSupport(): Unit = { + val props = workerSpec.getDirect.getProperties + require(props.hasConnection, + "DirectWorker.properties.connection must be set") + val conn = props.getConnection + require(conn.hasUnixDomainSocket, + "DirectUnixSocketWorkerDispatcher requires UNIX domain socket transport, " + + s"got ${conn.getTransportCase}") + } + + override protected def createConnection(address: String): UnixSocketWorkerConnection + + private def throwWorkerExitedBeforeSocket( + process: Process, + address: String, + outputFile: File): Nothing = { + val tail = readOutputTail(outputFile) + throw new DirectWorkerException( + s"Worker exited with code ${process.exitValue()} " + + s"before creating socket at $address\n$tail") + } + + /** + * Creates a temp directory with owner-only 
permissions (0700 on POSIX). + * On non-POSIX filesystems falls back to best-effort `File.setXxx`, + * which is TOCTOU-racy and weaker; a WARN surfaces if the platform + * refuses the setters. + */ + private def createPrivateTempDirectory(): Path = { + val attr = PosixFilePermissions.asFileAttribute( + PosixFilePermissions.fromString("rwx------")) + try { + Files.createTempDirectory("spark-udf-worker", attr) + } catch { + case _: UnsupportedOperationException => + val dir = Files.createTempDirectory("spark-udf-worker") + val f = dir.toFile + // `&` (non-short-circuiting) so every setter is attempted even if + // an earlier one refused. + val applied = + f.setReadable(false, false) & f.setWritable(false, false) & + f.setExecutable(false, false) & f.setReadable(true, true) & + f.setWritable(true, true) & f.setExecutable(true, true) + if (!applied) { + logger.warn( + s"Could not fully restrict permissions on $dir; socket " + + s"directory may be accessible to other local users on this " + + s"filesystem") + } + dir + } + } +} diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala new file mode 100644 index 0000000000000..afaf23791d80f --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala @@ -0,0 +1,532 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core.direct + +import java.io.{BufferedReader, File, FileInputStream, InputStreamReader} +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Path} +import java.util.UUID +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.mutable.{Queue => MQueue} +import scala.jdk.CollectionConverters._ +import scala.util.control.NonFatal + +import org.apache.spark.annotation.Experimental +import org.apache.spark.udf.worker.{ProcessCallable, UDFWorkerSpecification} +import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerDispatcher, + WorkerLogger, WorkerSecurityScope, WorkerSession} +import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableResult, + DEFAULT_CALLABLE_TIMEOUT_MS, DEFAULT_GRACEFUL_TIMEOUT_MS, DEFAULT_INIT_TIMEOUT_MS, + ENGINE_MAX_TIMEOUT_MS, EnvironmentState, MAX_OUTPUT_SCAN_BYTES, + PROCESS_OUTPUT_TAIL_LINES} + +/** + * :: Experimental :: + * A [[WorkerDispatcher]] that creates workers by spawning local OS processes + * ("direct" creation mode from the worker specification). + * + * On the first [[createSession]], the dispatcher ensures the environment is + * ready (verify / install) and registers the cleanup hook. Each session + * currently gets a fresh worker that is terminated when the session closes + * (the single-reference case of the future pooling policy). 
+ * + * Subclasses implement [[createConnection]] and [[createSessionForWorker]] + * to provide protocol-specific behavior (e.g., gRPC, raw sockets). + * + * For workers obtained through a provisioning service or daemon (indirect + * creation), see the `indirect` package (TODO). + * + * @param workerSpec worker specification (proto) + * @param logger [[WorkerLogger]] used for dispatcher-internal messages. + * The framework does not depend on any concrete logging + * backend; callers should pass an adapter that forwards + * to their preferred logger (Spark's `Logging` trait, + * SLF4J, etc.). Defaults to [[WorkerLogger.NoOp]]. + */ +@Experimental +abstract class DirectWorkerDispatcher( + override val workerSpec: UDFWorkerSpecification, + protected val logger: WorkerLogger = WorkerLogger.NoOp) + extends WorkerDispatcher { + + // TODO: Connection pooling -- reuse idle workers across sessions. + // TODO: Security scope isolation -- partition pool by WorkerSecurityScope. + + validateTransportSupport() + validateEnvironmentCallables() + + /** + * Maximum time to wait for a setup/verify/cleanup callable to finish. + * Subclasses may override this to accommodate slow installation steps + * (e.g., a large dependency install). Defaults to 120 seconds. + */ + protected def callableTimeoutMs: Long = DEFAULT_CALLABLE_TIMEOUT_MS + + // Proto-provided timeouts are clamped to ENGINE_MAX_TIMEOUT_MS. The + // dispatcher-internal callableTimeoutMs above is subclass-controlled and + // not subject to the cap. + // Package-private for test access. 
+ private[core] val initTimeoutMs: Long = { + val props = workerSpec.getDirect.getProperties + val raw = if (props.hasInitializationTimeoutMs && props.getInitializationTimeoutMs > 0) { + props.getInitializationTimeoutMs.toLong + } else { + DEFAULT_INIT_TIMEOUT_MS + } + clampTimeout("initialization_timeout_ms", raw) + } + + private val gracefulTimeoutMs: Long = { + val props = workerSpec.getDirect.getProperties + val raw = if (props.hasGracefulTerminationTimeoutMs && + props.getGracefulTerminationTimeoutMs > 0) { + props.getGracefulTerminationTimeoutMs.toLong + } else { + DEFAULT_GRACEFUL_TIMEOUT_MS + } + clampTimeout("graceful_termination_timeout_ms", raw) + } + + private def clampTimeout(field: String, raw: Long): Long = { + if (raw > ENGINE_MAX_TIMEOUT_MS) { + logger.warn( + s"Worker-provided $field=${raw}ms exceeds engine maximum " + + s"${ENGINE_MAX_TIMEOUT_MS}ms; using ${ENGINE_MAX_TIMEOUT_MS}ms instead") + ENGINE_MAX_TIMEOUT_MS + } else { + raw + } + } + + private[this] val workers = new ConcurrentHashMap[String, DirectWorkerProcess]() + private[this] val closed = new AtomicBoolean(false) + + @volatile private var environmentState: EnvironmentState = EnvironmentState.Pending + private val environmentLock = new Object + private[this] var cleanupHook: Option[Thread] = None + + /** + * Allocates a fresh endpoint address for a new worker. The string is + * passed to the worker binary as `--connection
`. + */ + protected def newEndpointAddress(workerId: String): String + + /** + * Waits for the worker process to be ready to accept connections at + * `address`. Throws [[DirectWorkerTimeoutException]] on timeout, or + * [[DirectWorkerException]] if the process exits early. + */ + protected def waitForReady( + address: String, + process: Process, + outputFile: File): Unit + + /** + * Best-effort per-endpoint cleanup, called from the spawn-failure path + * before any [[WorkerArtifacts]] / [[WorkerConnection]] exists. + */ + protected def cleanupEndpointAddress(address: String): Unit + + /** + * Cleans up dispatcher-level transport state (e.g., a UDS socket + * directory). Called from [[close]]. + */ + protected def closeTransport(): Unit + + /** + * Validates the worker spec's transport choice. Subclasses declare + * which transports they support. Called from the base constructor; + * implementations must only read base-class state (`workerSpec`). + */ + protected def validateTransportSupport(): Unit + + /** Creates a protocol-specific connection to a worker at the given address. */ + protected def createConnection(address: String): WorkerConnection + + /** Creates a protocol-specific session for the given worker. */ + protected def createSessionForWorker(worker: DirectWorkerProcess): WorkerSession + + override def createSession( + securityScope: Option[WorkerSecurityScope]): WorkerSession = { + require(securityScope.isEmpty, + "securityScope is not supported yet; pass None until pooling lands") + if (closed.get()) throwClosed() + ensureEnvironmentReady() + val worker = spawnWorker() + // Acquire before publish: a concurrent close() iterating `workers` must + // not tear down this worker before we hand it to the caller. + worker.acquireSession() + workers.put(worker.id, worker) + // Re-check for close() that ran concurrently. Releasing fires the + // ref-count callback, which removes and tears down the worker. 
+ if (closed.get()) { + worker.releaseSession() + throwClosed() + } + try { + createSessionForWorker(worker) + } catch { + case e: InterruptedException => + Thread.currentThread().interrupt() + worker.releaseSession() + throw e + case NonFatal(e) => + worker.releaseSession() + throw e + } + } + + /** + * Invoked when a worker's last session closes. Terminates the worker + * today; future pooling can reuse it here instead. Safe to call after + * dispatcher close -- the worker's own CAS-idempotent close makes a + * second teardown a no-op. + */ + private def releaseWorker(worker: DirectWorkerProcess): Unit = { + workers.remove(worker.id) + try { + worker.close() + } catch { + case NonFatal(e) => + logger.warn(s"Error closing worker ${worker.id}", e) + } + } + + private def throwClosed(): Nothing = + throw new IllegalStateException("Dispatcher is closed") + + /** + * Terminates tracked workers, removes the socket directory, and runs + * environment cleanup. Idempotent via CAS. Does not drain in-flight + * createSession calls -- a worker spawned racing with close tears + * itself down through the ref-count callback, which may outlive this + * method. + */ + override def close(): Unit = { + if (!closed.compareAndSet(false, true)) { + return + } + // TODO: close workers in parallel -- today shutdown is serialised at + // N * gracefulTimeoutMs worst case. + workers.values().iterator().asScala.foreach { w => + try { + w.close() + } catch { + case NonFatal(e) => + logger.warn(s"Error closing worker ${w.id}", e) + } + } + workers.clear() + try closeTransport() catch { + case NonFatal(e) => + logger.warn("Error cleaning up transport state", e) + } + deregisterEnvironmentCleanupHook() + runEnvironmentCleanup() + } + + // -- Environment lifecycle ------------------------------------------------- + + // TODO: distinguish retriable vs permanent environment failures. 
+ private def ensureEnvironmentReady(): Unit = { + environmentLock.synchronized { + environmentState match { + case EnvironmentState.Ready | EnvironmentState.CleanedUp => + case EnvironmentState.Failed(msg) => + throw new DirectWorkerException(s"Environment setup previously failed: $msg") + case EnvironmentState.Pending => + val env = workerSpec.getEnvironment + // Register up front so a partially-successful install still gets + // torn down at JVM shutdown if dispatcher.close is never called. + // No-op when environment_cleanup is not configured. + registerEnvironmentCleanupHook() + val verified = env.hasEnvironmentVerification && + runCallable(env.getEnvironmentVerification).exitCode == 0 + if (!verified && env.hasInstallation) { + // Treat any install failure (timeout or non-zero exit) as + // permanent. A partially-completed install can leave files on + // disk that a retry would race with; retry policy belongs in + // the future predicate (see TODO above). + val result = try { + runCallable(env.getInstallation) + } catch { + case e: DirectWorkerException => + environmentState = EnvironmentState.Failed( + s"installation failed: ${e.getMessage}") + throw e + } + if (result.exitCode != 0) { + val detail = s"exit code ${result.exitCode}\n${result.outputTail}" + environmentState = EnvironmentState.Failed(detail) + throw new DirectWorkerException( + s"Environment installation failed with $detail") + } + } + environmentState = EnvironmentState.Ready + } + } + } + + // TODO: share one JVM shutdown hook across all dispatchers in the + // process. Each live dispatcher is retained by the JVM until shutdown. + + /** Registers the JVM shutdown hook that runs the cleanup callable. 
*/ + private def registerEnvironmentCleanupHook(): Unit = { + if (!Thread.holdsLock(environmentLock)) { + throw new IllegalStateException( + "registerEnvironmentCleanupHook must be called while holding environmentLock") + } + if (cleanupHook.isDefined) return + if (workerSpec.getEnvironment.hasEnvironmentCleanup) { + val hook = new Thread(() => runEnvironmentCleanup(), "udf-env-cleanup") + cleanupHook = Some(hook) + // scalastyle:off runtimeaddshutdownhook + Runtime.getRuntime.addShutdownHook(hook) + // scalastyle:on runtimeaddshutdownhook + } + } + + private def deregisterEnvironmentCleanupHook(): Unit = { + environmentLock.synchronized { + cleanupHook.foreach { hook => + try { + Runtime.getRuntime.removeShutdownHook(hook) + } catch { + case _: IllegalStateException => // JVM already shutting down + } + cleanupHook = None + } + } + } + + private def runEnvironmentCleanup(): Unit = { + environmentLock.synchronized { + environmentState match { + case EnvironmentState.CleanedUp => + case _ => + if (workerSpec.getEnvironment.hasEnvironmentCleanup) { + try { + val result = runCallable(workerSpec.getEnvironment.getEnvironmentCleanup) + if (result.exitCode != 0) { + logger.warn(s"Environment cleanup exited with code ${result.exitCode}" + + s"\n${result.outputTail}") + } + } catch { + case NonFatal(e) => logger.warn("Environment cleanup failed", e) + } + } + environmentState = EnvironmentState.CleanedUp + } + } + } + + // -- Process helpers ------------------------------------------------------- + + /** + * Runs a [[ProcessCallable]] synchronously and returns the result. + * Always throws on timeout; callers check `exitCode` for non-timeout failures. 
+ */ + private[core] def runCallable(callable: ProcessCallable): CallableResult = { + val cmd = (callable.getCommandList.asScala ++ callable.getArgumentsList.asScala).toSeq + require(cmd.nonEmpty, + "ProcessCallable must have at least one entry in command or arguments") + val outputFile = Files.createTempFile("udf-callable-", ".log") + try { + val process = launchProcess( + cmd, callable.getEnvironmentVariablesMap.asScala.toMap, outputFile.toFile) + val timeoutMs = callableTimeoutMs + if (!process.waitFor(timeoutMs, TimeUnit.MILLISECONDS)) { + DirectWorkerDispatcher.destroyForciblyAndReap( + process, logger, s"callable timeout: ${cmd.head}") + val tail = readOutputTail(outputFile.toFile) + throw new DirectWorkerTimeoutException( + s"Callable timed out after ${timeoutMs}ms: " + + s"${cmd.mkString(" ")}\n$tail") + } + val tail = readOutputTail(outputFile.toFile) + CallableResult(process.exitValue(), tail) + } finally { + Files.deleteIfExists(outputFile) + } + } + + private def spawnWorker(): DirectWorkerProcess = { + val runner = workerSpec.getDirect.getRunner + val baseCmd = (runner.getCommandList.asScala ++ runner.getArgumentsList.asScala).toSeq + require(baseCmd.nonEmpty, + "DirectWorker.runner must have at least one entry in command or arguments") + val workerId = UUID.randomUUID().toString + val address = newEndpointAddress(workerId) + // Proto contract: the engine must pass --id and --connection. 
+ val cmd = baseCmd ++ Seq("--id", workerId, "--connection", address) + val env = runner.getEnvironmentVariablesMap.asScala.toMap + val outputFile = Files.createTempFile("udf-worker-", ".log") + val process = launchProcess(cmd, env, outputFile.toFile) + + try { + waitForReady(address, process, outputFile.toFile) + val connection = createConnection(address) + val artifacts = new WorkerArtifacts(process, connection, outputFile, logger) + new DirectWorkerProcess( + workerId, artifacts, gracefulTimeoutMs, logger, + onLastSessionReleased = releaseWorker) + } catch { + case e: InterruptedException => + Thread.currentThread().interrupt() + cleanupRawSpawn(process, address, outputFile) + throw e + case NonFatal(e) => + cleanupRawSpawn(process, address, outputFile) + throw e + } + } + + // Pre-WorkerArtifacts cleanup: the connection has not been built yet, + // so we have no bundle to close(). Each step is independent. + private def cleanupRawSpawn(p: Process, address: String, outputFile: Path): Unit = { + DirectWorkerDispatcher.destroyForciblyAndReap(p, logger, "failed spawn") + try cleanupEndpointAddress(address) catch { + case NonFatal(e) => + logger.debug(s"Failed to clean up endpoint address $address", e) + } + try Files.deleteIfExists(outputFile) catch { + case NonFatal(e) => + logger.debug(s"Failed to clean up worker output file $outputFile", e) + } + } + + /** + * Starts an OS process. stdout and stderr are merged and redirected to the + * given file so that output can be read back for error reporting. + */ + private def launchProcess( + command: Seq[String], + env: Map[String, String], + outputFile: File): Process = { + val builder = new ProcessBuilder(command: _*) + env.foreach { case (k, v) => builder.environment().put(k, v) } + builder.redirectErrorStream(true) + builder.redirectOutput(outputFile) + builder.start() + } + + // Bounded scan so a runaway worker that writes gigabytes of output does + // not OOM the caller during error reporting. 
+  protected def readOutputTail(file: File): String = {
+    if (!file.exists() || file.length() == 0) return ""
+    val fileLen = file.length()
+    val startPos = math.max(0L, fileLen - MAX_OUTPUT_SCAN_BYTES)
+    val fis = new FileInputStream(file)
+    try {
+      if (startPos > 0) fis.getChannel.position(startPos)
+      val reader = new BufferedReader(
+        new InputStreamReader(fis, StandardCharsets.UTF_8))
+      // Discard the first (partial) line after seeking into the middle.
+      if (startPos > 0) reader.readLine()
+      val buffer = new MQueue[String]()
+      var line = reader.readLine()
+      while (line != null) {
+        if (buffer.size >= PROCESS_OUTPUT_TAIL_LINES) buffer.dequeue()
+        buffer.enqueue(line)
+        line = reader.readLine()
+      }
+      if (buffer.isEmpty) ""
+      else "Process output (last lines):\n" + buffer.mkString("\n")
+    } catch {
+      case NonFatal(e) =>
+        logger.debug(s"Failed to read process output from $file", e)
+        ""
+    } finally {
+      fis.close()
+    }
+  }
+
+  // -- Spec validation -------------------------------------------------------
+
+  // Verification exists to short-circuit installation when the environment
+  // is already prepared, so requiring installation alongside verification
+  // catches user errors at spec-validation time.
+  private def validateEnvironmentCallables(): Unit = {
+    val env = workerSpec.getEnvironment
+    require(!env.hasEnvironmentVerification || env.hasInstallation,
+      "WorkerEnvironment.environment_verification requires installation to be set")
+  }
+}
+
+private[direct] object DirectWorkerDispatcher {
+  private[direct] val SOCKET_POLL_INTERVAL_MS = 100L
+  private[direct] val DEFAULT_INIT_TIMEOUT_MS = 10000L
+  private[direct] val DEFAULT_CALLABLE_TIMEOUT_MS = 120000L
+  private[direct] val DEFAULT_GRACEFUL_TIMEOUT_MS = 5000L
+  // Engine-side cap on proto-provided worker timeouts. The defaults above
+  // must stay at or under this cap so the clamp only fires on
+  // user-provided values.
+ private[direct] val ENGINE_MAX_TIMEOUT_MS = 30000L + require(DEFAULT_INIT_TIMEOUT_MS <= ENGINE_MAX_TIMEOUT_MS && + DEFAULT_GRACEFUL_TIMEOUT_MS <= ENGINE_MAX_TIMEOUT_MS, + "default timeouts must not exceed ENGINE_MAX_TIMEOUT_MS") + private[direct] val PROCESS_OUTPUT_TAIL_LINES = 50 + private[direct] val MAX_OUTPUT_SCAN_BYTES = 1024L * 1024L // 1 MiB + // 5s bounds the wait for the kernel to reap a SIGKILL'd child. SIGKILL + // is unblockable, so exceeding this usually means the process is stuck + // in uninterruptible I/O (D-state) and further waiting will not help. + private[direct] val SIGKILL_REAP_TIMEOUT_MS = 5000L + + /** + * SIGKILL `process` and wait up to [[SIGKILL_REAP_TIMEOUT_MS]] for the + * kernel to reap it. `destroyForcibly()` alone returns before the child + * is reaped, which leaks a zombie until JVM exit. On reap-timeout logs + * a warning; on interrupt re-raises the interrupt and returns. + * + * @param context short tag included in the timeout warning so operators + * can correlate a stuck child with its source. + */ + private[direct] def destroyForciblyAndReap( + process: Process, + logger: WorkerLogger, + context: String = ""): Unit = { + if (!process.isAlive) return + process.destroyForcibly() + val reaped = try { + process.waitFor(SIGKILL_REAP_TIMEOUT_MS, TimeUnit.MILLISECONDS) + } catch { + case _: InterruptedException => + Thread.currentThread().interrupt() + return + } + if (!reaped && process.isAlive) { + val suffix = if (context.nonEmpty) s" [$context]" else "" + logger.warn( + s"Process ${process.pid()}$suffix still alive ${SIGKILL_REAP_TIMEOUT_MS}ms " + + s"after SIGKILL; leaving behind as zombie " + + s"(likely stuck in uninterruptible kernel state)") + } + } + + /** Result of running a [[ProcessCallable]]. 
*/ + private[core] case class CallableResult(exitCode: Int, outputTail: String) + + private[direct] sealed trait EnvironmentState + private[direct] object EnvironmentState { + case object Pending extends EnvironmentState + case object Ready extends EnvironmentState + case class Failed(detail: String) extends EnvironmentState + case object CleanedUp extends EnvironmentState + } +} diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerException.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerException.scala new file mode 100644 index 0000000000000..b0ece15eae38f --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerException.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core.direct + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Thrown by [[DirectWorkerDispatcher]] for runtime failures: worker + * spawn problems, environment setup or cleanup failures, callable + * timeouts, and socket-establishment timeouts. 
+ * + * Distinguished from `IllegalArgumentException` (bad spec) and + * `IllegalStateException` (using a closed dispatcher), which indicate + * programming errors. Catching this type lets callers handle runtime + * failures specifically without catching every `RuntimeException`. + */ +@Experimental +class DirectWorkerException(message: String, cause: Throwable = null) + extends RuntimeException(message, cause) + +/** + * :: Experimental :: + * A [[DirectWorkerException]] caused specifically by a timeout: a worker + * that did not bind its socket within `initialization_timeout_ms`, or a + * setup callable (verify / install / cleanup) that exceeded + * `callableTimeoutMs`. Exposed as a distinct type so callers can choose + * different retry / escalation paths for timeouts vs other failures. + */ +@Experimental +class DirectWorkerTimeoutException(message: String, cause: Throwable = null) + extends DirectWorkerException(message, cause) diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala new file mode 100644 index 0000000000000..f4b5c1df63193 --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core.direct + +import java.nio.file.{Files, Path} +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} + +import scala.util.control.NonFatal + +import org.apache.spark.annotation.Experimental +import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerLogger} + +/** + * :: Experimental :: + * A locally-spawned OS process running a UDF worker, together with its + * transport connection. Wraps a [[WorkerArtifacts]] bundle (process + + * connection + output log) plus a session ref-count scaffolding for + * future pooling -- today one process per session. + * + * Closing sends SIGTERM, waits up to [[gracefulTimeoutMs]], then + * delegates connection close + forced kill + file cleanup to + * [[WorkerArtifacts.close]]. + * + * @param id stable worker identifier (UUID passed to the binary as `--id`). + * @param artifacts process + connection + output-log, disposed together. + * @param gracefulTimeoutMs wait after SIGTERM before escalating to SIGKILL. + * @param logger [[WorkerLogger]] for process-level messages. + * @param onLastSessionReleased fires when the ref-count hits 0. Runs on + * the thread calling [[releaseSession]]. May fire more than once + * across a worker's lifetime; a concurrent `acquireSession` can + * re-increment the count before the callback returns, so pooling + * dispatchers must arbitrate reuse themselves. 
+ */ +@Experimental +class DirectWorkerProcess( + val id: String, + private[direct] val artifacts: WorkerArtifacts, + val gracefulTimeoutMs: Long, + protected val logger: WorkerLogger = WorkerLogger.NoOp, + private[direct] val onLastSessionReleased: DirectWorkerProcess => Unit = _ => ()) + extends AutoCloseable { + + // TODO: idle-timeout tracking and concurrent session capacity. + + private val activeSessionCount = new AtomicInteger(0) + private val closed = new AtomicBoolean(false) + + /** The OS process handle for this worker. */ + def process: Process = artifacts.process + + /** The transport connection for this worker. */ + def connection: WorkerConnection = artifacts.connection + + /** Path to the merged stdout/stderr log for this worker. */ + def outputFile: Path = artifacts.outputFile + + /** Number of sessions currently using this worker. */ + def activeSessions: Int = activeSessionCount.get() + + /** Increments the active session count. */ + def acquireSession(): Unit = activeSessionCount.incrementAndGet() + + /** + * Decrements the active session count. Fires [[onLastSessionReleased]] + * on the 0-transition. A negative count indicates an unbalanced + * acquire/release; we log and reset to 0 rather than silently mask it. + */ + def releaseSession(): Unit = { + val c = activeSessionCount.decrementAndGet() + if (c < 0) { + logger.warn( + s"releaseSession called without a matching acquireSession (count=$c)") + activeSessionCount.set(0) + } else if (c == 0) { + // Swallow callback errors so session.close cannot throw. + try onLastSessionReleased(this) catch { + case NonFatal(e) => + logger.warn(s"onLastSessionReleased callback failed for worker $id", e) + } + } + } + + /** Returns true if the OS process is running and the connection is usable. 
*/ + def isAlive: Boolean = process.isAlive && connection.isActive + + /** + * Sends SIGTERM, waits up to [[gracefulTimeoutMs]] for the worker to + * exit, then disposes artifacts (connection close + SIGKILL + file + * cleanup). Idempotent via CAS. + */ + override def close(): Unit = { + if (!closed.compareAndSet(false, true)) return + + if (process.isAlive) { + process.destroy() // SIGTERM + try { + // Ignore the return value: artifacts.close() SIGKILLs if still + // alive and no-ops if already dead. + process.waitFor(gracefulTimeoutMs, TimeUnit.MILLISECONDS) + } catch { + case _: InterruptedException => + Thread.currentThread().interrupt() + } + } + + artifacts.close() + } +} + +/** + * Closeable bundle of per-worker OS resources: the child [[Process]], its + * transport [[WorkerConnection]], and its merged stdout/stderr log. + * [[close]] runs connection close (which for UDS removes the socket + * file), then SIGKILL-reaps the process, then deletes the output log. + * Graceful SIGTERM is the higher layer's responsibility (see + * [[DirectWorkerProcess#close]]). + */ +private[direct] final class WorkerArtifacts( + val process: Process, + val connection: WorkerConnection, + val outputFile: Path, + private[this] val logger: WorkerLogger) extends AutoCloseable { + + private[this] val closed = new AtomicBoolean(false) + + /** + * Idempotently closes the connection (transport teardown + any + * transport-specific cleanup such as deleting a UDS socket file), + * SIGKILL-reaps the process, and deletes the output log. Each step + * is guarded so a failure in one does not skip the next. 
+ */ + override def close(): Unit = { + if (!closed.compareAndSet(false, true)) return + + try connection.close() catch { + case NonFatal(e) => + logger.warn("Error closing worker connection", e) + } + + DirectWorkerDispatcher.destroyForciblyAndReap(process, logger, "worker artifacts") + + try Files.deleteIfExists(outputFile) catch { + case NonFatal(e) => + logger.warn(s"Error cleaning up worker output file $outputFile", e) + } + } +} diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala new file mode 100644 index 0000000000000..7cdc5329350e3 --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core.direct + +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.annotation.Experimental +import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerSession} + +/** + * :: Experimental :: + * A [[WorkerSession]] backed by a locally-spawned [[DirectWorkerProcess]]. 
+ * + * This is the session type returned by [[DirectWorkerDispatcher]]. It ties + * the session lifecycle to the worker's ref-count: the dispatcher increments + * the count before construction, and [[close]] decrements it, so the + * dispatcher knows when a worker process is idle and can be terminated or + * reused. + * + * Subclasses implement the protocol-specific data transmission + * ([[init]], [[process]], [[cancel]]). + * + * @param workerProcess the direct worker process backing this session. + * Internal to the `core` package and test code -- the + * worker handle is a dispatcher implementation detail, + * not part of the public WorkerSession API. + */ +@Experimental +abstract class DirectWorkerSession( + private[core] val workerProcess: DirectWorkerProcess) extends WorkerSession { + + private val released = new AtomicBoolean(false) + + /** The connection to the worker for this session. */ + def connection: WorkerConnection = workerProcess.connection + + override def close(): Unit = { + if (released.compareAndSet(false, true)) { + workerProcess.releaseSession() + } + } +} diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala new file mode 100644 index 0000000000000..60f5e2211b702 --- /dev/null +++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala @@ -0,0 +1,981 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core + +import java.io.File +import java.nio.file.{Files, Path} +import java.nio.file.attribute.PosixFileAttributeView + +import scala.jdk.CollectionConverters._ + +// scalastyle:off funsuite +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.spark.udf.worker.{ + DirectWorker, LocalTcpConnection, ProcessCallable, UDFWorkerProperties, + UDFWorkerSpecification, UnixDomainSocket, WorkerConnectionSpec, + WorkerEnvironment} +import org.apache.spark.udf.worker.core.direct.{DirectUnixSocketWorkerDispatcher, + DirectWorkerException, DirectWorkerProcess, DirectWorkerSession, + DirectWorkerTimeoutException} + +/** + * A [[WorkerConnection]] test implementation that considers the connection + * active as long as the socket file exists on disk. Inherits socket-file + * deletion from [[UnixSocketWorkerConnection.close]]. + */ +class SocketFileConnection(socketPath: String) + extends UnixSocketWorkerConnection(socketPath) { + override def isActive: Boolean = new File(socketPath).exists() +} + +/** + * A stub [[DirectWorkerSession]] for process-lifecycle tests that don't + * need actual data transmission. + * + * TODO: [[cancel]] is a no-op here. Once a concrete [[DirectWorkerSession]] + * with real data-plane wiring lands, add tests exercising cancel() in + * particular: cancel from a different thread than process(), cancel + * after process() has returned, and cancel before init (should be a + * no-op). 
The thread-safety contract is tracked in the docstring on
+ *       [[org.apache.spark.udf.worker.core.WorkerSession.cancel]].
+ */
+class StubWorkerSession(
+    workerProcess: DirectWorkerProcess) extends DirectWorkerSession(workerProcess) {
+
+  override protected def doInit(message: InitMessage): Unit = {}
+
+  override protected def doProcess(
+      input: Iterator[Array[Byte]]): Iterator[Array[Byte]] =
+    Iterator.empty
+
+  override def cancel(): Unit = {}
+}
+
+/**
+ * A [[DirectUnixSocketWorkerDispatcher]] subclass for testing that uses
+ * a socket-file connection and stub sessions instead of a real protocol
+ * implementation.
+ */
+class TestDirectWorkerDispatcher(spec: UDFWorkerSpecification)
+  extends DirectUnixSocketWorkerDispatcher(spec) {
+
+  override protected def createConnection(
+      socketPath: String): UnixSocketWorkerConnection =
+    new SocketFileConnection(socketPath)
+
+  override protected def createSessionForWorker(
+      worker: DirectWorkerProcess): WorkerSession =
+    new StubWorkerSession(worker)
+}
+
+/**
+ * Tests for [[DirectWorkerDispatcher]] process lifecycle: spawning workers
+ * and terminating them on close.
+ */ +class DirectWorkerDispatcherSuite + extends AnyFunSuite with BeforeAndAfterEach { +// scalastyle:on funsuite + + private val echoWorkerScript = + """ + |#!/bin/bash + |SOCKET_PATH="" + |while [[ $# -gt 0 ]]; do + | case "$1" in + | --connection) SOCKET_PATH="$2"; shift 2 ;; + | *) shift ;; + | esac + |done + |cleanup() { rm -f "$SOCKET_PATH"; exit 0; } + |trap cleanup SIGTERM + |touch "$SOCKET_PATH" + |while true; do sleep 1; done + """.stripMargin.trim + + private def defaultRunner: ProcessCallable = ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c").addCommand(echoWorkerScript).addCommand("--") + .build() + + private def udsProperties: UDFWorkerProperties = UDFWorkerProperties.newBuilder() + .setConnection(WorkerConnectionSpec.newBuilder() + .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance) + .build()) + .build() + + private def directWorker(runner: ProcessCallable): DirectWorker = + DirectWorker.newBuilder().setRunner(runner).setProperties(udsProperties).build() + + private def specWithRunner(runner: ProcessCallable): UDFWorkerSpecification = + UDFWorkerSpecification.newBuilder() + .setDirect(directWorker(runner)) + .build() + + private def specWithEnv( + runner: ProcessCallable = defaultRunner, + env: WorkerEnvironment): UDFWorkerSpecification = + UDFWorkerSpecification.newBuilder() + .setEnvironment(env) + .setDirect(directWorker(runner)) + .build() + + private var dispatcher: TestDirectWorkerDispatcher = _ + + override def afterEach(): Unit = { + if (dispatcher != null) { + dispatcher.close() + dispatcher = null + } + super.afterEach() + } + + // Narrow the publicly-typed WorkerSession returned by `createSession` back + // down to StubWorkerSession in one place, with a descriptive failure if + // the cast is ever wrong, so individual tests don't scatter `asInstanceOf` + // (which would throw ClassCastException rather than a useful message). 
+ private def createStubSession(): StubWorkerSession = + dispatcher.createSession(None) match { + case stub: StubWorkerSession => stub + case other => fail( + s"Expected StubWorkerSession, got ${other.getClass.getSimpleName}") + } + + // The whole suite uses UDS as the only transport, so reaching past the + // generic WorkerConnection abstraction to read the socket path is fine. + private def udsPath(w: DirectWorkerProcess): String = w.connection match { + case uds: UnixSocketWorkerConnection => uds.socketPath + case other => fail( + s"Expected UnixSocketWorkerConnection, got ${other.getClass.getSimpleName}") + } + + test("creates a worker and session") { + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) + + val session = createStubSession() + val worker = session.workerProcess + + assert(worker.isAlive, "worker should be alive after creation") + assert(worker.activeSessions == 1, "should have 1 active session") + assert(new File(udsPath(worker)).exists(), "socket file should exist") + + session.close() + assert(worker.activeSessions == 0, "should have 0 sessions after close") + } + + test("concurrent createSession calls produce distinct workers") { + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) + + val threads = 8 + val sessions = new java.util.concurrent.ConcurrentLinkedQueue[StubWorkerSession]() + val startGate = new java.util.concurrent.CountDownLatch(1) + val doneGate = new java.util.concurrent.CountDownLatch(threads) + val errors = new java.util.concurrent.ConcurrentLinkedQueue[Throwable]() + + (1 to threads).foreach { _ => + new Thread(() => { + try { + startGate.await() + sessions.add(createStubSession()) + } catch { + case t: Throwable => errors.add(t) + } finally { + doneGate.countDown() + } + }).start() + } + startGate.countDown() + assert(doneGate.await(30, java.util.concurrent.TimeUnit.SECONDS), + "createSession threads did not finish in time") + + assert(errors.isEmpty, + s"unexpected errors 
during concurrent createSession: ${errors.toArray.mkString(", ")}") + assert(sessions.size == threads, "expected one session per thread") + + val sessionList = sessions.asScala.toList + val workerObjects = sessionList.map(_.workerProcess) + assert(workerObjects.distinct.length == threads, + "each session should have its own DirectWorkerProcess") + // Object-identity is not sufficient on its own: a future regression + // that accidentally shared underlying transport resources could still + // hand out distinct DirectWorkerProcess wrappers pointing at the same + // socket. Verify socket paths are unique too. + val socketPaths = workerObjects.map(udsPath) + assert(socketPaths.distinct.length == threads, + s"each worker should have its own socket path, got $socketPaths") + + sessionList.foreach(_.close()) + } + + test("close shuts down all workers via SIGTERM") { + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) + + val session1 = createStubSession() + val session2 = createStubSession() + + val worker1 = session1.workerProcess + val worker2 = session2.workerProcess + + session1.close() + session2.close() + dispatcher.close() + dispatcher = null + + assert(!worker1.process.isAlive, "worker1 should be terminated") + assert(!worker2.process.isAlive, "worker2 should be terminated") + } + + test("close escalates to SIGKILL when worker ignores SIGTERM") { + // The worker traps SIGTERM so the graceful stop is ineffective; the + // dispatcher must escalate to SIGKILL via destroyForciblyAndReap. + // Using a short gracefulTimeoutMs (500ms) keeps the test bounded: + // max close time is gracefulTimeoutMs + SIGKILL_REAP_TIMEOUT_MS. 
+ val sigtermIgnoringScript = + """ + |#!/bin/bash + |SOCKET_PATH="" + |while [[ $# -gt 0 ]]; do + | case "$1" in + | --connection) SOCKET_PATH="$2"; shift 2 ;; + | *) shift ;; + | esac + |done + |touch "$SOCKET_PATH" + |trap '' SIGTERM + |while true; do sleep 1; done + """.stripMargin.trim + val runner = ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c").addCommand(sigtermIgnoringScript).addCommand("--") + .build() + val shortGracefulProps = UDFWorkerProperties.newBuilder() + .setConnection(WorkerConnectionSpec.newBuilder() + .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) + .setGracefulTerminationTimeoutMs(500) + .build() + val spec = UDFWorkerSpecification.newBuilder() + .setDirect(DirectWorker.newBuilder() + .setRunner(runner).setProperties(shortGracefulProps).build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(spec) + + val session = createStubSession() + val worker = session.workerProcess + assert(worker.process.isAlive, "worker should be alive before close") + + val closeStart = System.nanoTime() + session.close() + val closeElapsedMs = (System.nanoTime() - closeStart) / 1000000L + + assert(!worker.process.isAlive, + s"worker should have been SIGKILLed after ignoring SIGTERM (took ${closeElapsedMs}ms)") + assert(closeElapsedMs >= 500L, + s"close should have waited for gracefulTimeoutMs before escalating, " + + s"took ${closeElapsedMs}ms") + } + + test("closing a session terminates its worker") { + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) + + val session = createStubSession() + val worker = session.workerProcess + val socketFile = new File(udsPath(worker)) + + assert(worker.process.isAlive, "worker should be alive before session close") + assert(socketFile.exists(), "socket file should exist before session close") + + session.close() + + // The session-close path is synchronous: SIGTERM is sent and the process + // is reaped before `close` returns. 
+ assert(!worker.process.isAlive, + "worker process should be terminated when the session closes") + assert(!socketFile.exists(), + "socket file should be cleaned up when the session closes") + } + + test("concurrent session.close and dispatcher.close do not double-close the worker") { + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) + + val sessions = (1 to 4).map(_ => createStubSession()) + val workers = sessions.map(_.workerProcess) + + val barrier = new java.util.concurrent.CyclicBarrier(sessions.size + 1) + val errors = new java.util.concurrent.ConcurrentLinkedQueue[Throwable]() + + val sessionThreads = sessions.map { s => + val t = new Thread(() => { + try { + barrier.await() + s.close() + } catch { + case t: Throwable => errors.add(t) + } + }) + t.start() + t + } + + val dispatcherThread = new Thread(() => { + try { + barrier.await() + dispatcher.close() + } catch { + case t: Throwable => errors.add(t) + } + }) + dispatcherThread.start() + + sessionThreads.foreach(_.join(30000)) + dispatcherThread.join(30000) + dispatcher = null + + assert(errors.isEmpty, + s"unexpected errors during concurrent close: ${errors.toArray.mkString(", ")}") + workers.foreach { w => + assert(!w.process.isAlive, + s"worker at ${udsPath(w)} should be terminated after concurrent close") + } + } + + test("close racing with in-flight createSession does not leak the worker") { + // The acquire-before-publish + post-publish closed re-check pattern in + // createSession is designed for this race: thread A is mid-spawn when + // thread B calls close(). Thread A must either throw IllegalStateException + // (post-publish check caught the close) or receive a session whose worker + // is reaped by close()'s iteration. No orphan process or socket file + // should remain in either case. 
+ val readyLatch = new java.util.concurrent.CountDownLatch(1) + val releaseLatch = new java.util.concurrent.CountDownLatch(1) + val capturedWorkers = + new java.util.concurrent.ConcurrentLinkedQueue[DirectWorkerProcess]() + val racing = new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) { + override protected def createConnection( + socketPath: String): UnixSocketWorkerConnection = + new SocketFileConnection(socketPath) + override protected def createSessionForWorker( + worker: DirectWorkerProcess): WorkerSession = { + capturedWorkers.add(worker) + readyLatch.countDown() + // Block here so dispatcher.close() runs while createSession is in + // flight. Use a generous wait so a slow CI doesn't time out. + if (!releaseLatch.await(30, java.util.concurrent.TimeUnit.SECONDS)) { + fail("releaseLatch never fired -- test orchestration broken") + } + new StubWorkerSession(worker) + } + } + try { + val outcome = + new java.util.concurrent.atomic.AtomicReference[Either[Throwable, WorkerSession]]() + val createThread = new Thread(() => { + try { + val s = racing.createSession(None) + outcome.set(Right(s)) + } catch { + case t: Throwable => outcome.set(Left(t)) + } + }, "createSession-racer") + createThread.start() + + // Wait for thread A to have published the worker and entered the + // blocking override. + assert(readyLatch.await(10, java.util.concurrent.TimeUnit.SECONDS), + "createSession thread never reached createSessionForWorker") + + val closeThread = new Thread(() => racing.close(), "close-racer") + closeThread.start() + // Give close() time to flip `closed` and iterate workers. + Thread.sleep(200) + + // Now release the in-flight createSession. 
+ releaseLatch.countDown() + + createThread.join(10000) + closeThread.join(10000) + assert(!createThread.isAlive, "createSession thread did not finish") + assert(!closeThread.isAlive, "close thread did not finish") + + val captured = capturedWorkers.toArray(Array.empty[DirectWorkerProcess]) + assert(captured.length == 1, + s"expected exactly one worker spawned, got ${captured.length}") + val worker = captured(0) + + outcome.get() match { + case Left(e: IllegalStateException) => + // Contractually allowed, but unreachable with this orchestration: + // readyLatch only fires after createSession has cleared both + // `closed` checks, so B's close cannot flip `closed` in time for + // A to observe it. Kept defensive so a future internal change + // that introduces a new window is still covered. + assert(e.getMessage.contains("closed"), + s"expected dispatcher-closed error, got: ${e.getMessage}") + case Left(other) => + fail(s"unexpected exception from racing createSession: $other") + case Right(_) => + // close() iterated the published worker and tore it down; the + // returned session points at a worker that should now be dead. + } + + // Whichever path won, the worker must not still be running and the + // socket file must be gone. + val deadline = System.currentTimeMillis() + 5000 + while (worker.process.isAlive && System.currentTimeMillis() < deadline) { + Thread.sleep(50) + } + val sockPath = udsPath(worker) + assert(!worker.process.isAlive, + s"worker process should be terminated after close, still alive at $sockPath") + assert(!new java.io.File(sockPath).exists(), + s"socket file $sockPath should have been removed") + } finally { + releaseLatch.countDown() + racing.close() + } + } + + test("worker-provided graceful timeout is capped at the engine-side maximum") { + // The proto documents an engine-configurable maximum (fixed at 30s today). + // A 60s spec value should be clamped down. 
+ val oversizedProps = UDFWorkerProperties.newBuilder() + .setConnection(WorkerConnectionSpec.newBuilder() + .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) + .setGracefulTerminationTimeoutMs(60000) + .build() + val spec = UDFWorkerSpecification.newBuilder() + .setDirect(DirectWorker.newBuilder() + .setRunner(defaultRunner).setProperties(oversizedProps).build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(spec) + + val session = createStubSession() + assert(session.workerProcess.gracefulTimeoutMs == 30000L, + s"graceful timeout should be capped at 30000ms, " + + s"got ${session.workerProcess.gracefulTimeoutMs}") + session.close() + } + + test("worker-provided init timeout is capped at the engine-side maximum") { + val oversizedProps = UDFWorkerProperties.newBuilder() + .setConnection(WorkerConnectionSpec.newBuilder() + .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) + .setInitializationTimeoutMs(60000) + .build() + val spec = UDFWorkerSpecification.newBuilder() + .setDirect(DirectWorker.newBuilder() + .setRunner(defaultRunner).setProperties(oversizedProps).build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(spec) + + assert(dispatcher.initTimeoutMs == 30000L, + s"init timeout should be capped at 30000ms, got ${dispatcher.initTimeoutMs}") + } + + test("createSession after close is rejected") { + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) + dispatcher.close() + + val ex = intercept[IllegalStateException] { + dispatcher.createSession(None) + } + assert(ex.getMessage.contains("closed"), + s"expected dispatcher-closed error, got: ${ex.getMessage}") + dispatcher = null + } + + test("socket directory is owner-only (0700) on POSIX") { + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) + // Drive one createSession so a worker (and therefore the socket dir) is + // observable via the UDS connection's path. 
+ val session = createStubSession() + val socketDir: Path = new File(udsPath(session.workerProcess)).toPath.getParent + session.close() + + val view = Files.getFileAttributeView(socketDir, classOf[PosixFileAttributeView]) + // Skip explicitly on non-POSIX filesystems rather than silently pass, + // so a CI environment without POSIX attributes is visible in the + // test report instead of giving false confidence. + assume(view != null, s"POSIX file attributes required to check $socketDir") + val perms = view.readAttributes().permissions().asScala.toSet + val expected = java.nio.file.attribute.PosixFilePermissions + .fromString("rwx------").asScala.toSet + assert(perms == expected, + s"socket directory $socketDir should be 0700, got ${perms.mkString(",")}") + } + + test("socket directory is removed after dispatcher.close") { + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) + val session = createStubSession() + val socketDir = new File(udsPath(session.workerProcess)).toPath.getParent.toFile + assert(socketDir.exists(), + s"socket directory $socketDir should exist while a session is open") + session.close() + + dispatcher.close() + dispatcher = null + + assert(!socketDir.exists(), + s"socket directory $socketDir should be removed after dispatcher.close") + } + + // -- Error-path tests ------------------------------------------------------- + + test("worker is cleaned up when createSessionForWorker throws") { + // A dispatcher whose createSessionForWorker always throws. The spawned + // worker must be terminated rather than leaked until dispatcher.close(). 
+ var capturedWorker: DirectWorkerProcess = null + val failingDispatcher = + new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) { + override protected def createConnection( + socketPath: String): UnixSocketWorkerConnection = + new SocketFileConnection(socketPath) + override protected def createSessionForWorker( + worker: DirectWorkerProcess): WorkerSession = { + capturedWorker = worker + throw new RuntimeException("session creation failed") + } + } + + try { + val ex = intercept[RuntimeException] { + failingDispatcher.createSession(None) + } + assert(ex.getMessage.contains("session creation failed")) + assert(capturedWorker != null, "worker should have been spawned before the failure") + assert(!capturedWorker.process.isAlive, + "worker process should have been terminated after session creation failed") + assert(capturedWorker.activeSessions == 0, + "worker session count should be released after failure") + } finally { + failingDispatcher.close() + } + } + + test("DirectWorker without a connection is rejected") { + val badSpec = UDFWorkerSpecification.newBuilder() + .setDirect(DirectWorker.newBuilder().setRunner(defaultRunner).build()) + .build() + val ex = intercept[IllegalArgumentException] { + new TestDirectWorkerDispatcher(badSpec) + } + assert(ex.getMessage.contains("connection must be set"), + s"expected missing-connection error, got: ${ex.getMessage}") + } + + test("DirectWorker with non-UDS transport is rejected") { + val tcpProperties = UDFWorkerProperties.newBuilder() + .setConnection(WorkerConnectionSpec.newBuilder() + .setTcp(LocalTcpConnection.getDefaultInstance).build()) + .build() + val badSpec = UDFWorkerSpecification.newBuilder() + .setDirect(DirectWorker.newBuilder() + .setRunner(defaultRunner).setProperties(tcpProperties).build()) + .build() + val ex = intercept[IllegalArgumentException] { + new TestDirectWorkerDispatcher(badSpec) + } + assert(ex.getMessage.contains("UNIX domain socket"), + s"expected UDS-only error, got: 
${ex.getMessage}") + } + + test("socket file is cleaned up when createConnection throws") { + val capturedSocketPaths = new java.util.concurrent.ConcurrentLinkedQueue[String]() + val failingDispatcher = + new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) { + override protected def createConnection( + socketPath: String): UnixSocketWorkerConnection = { + capturedSocketPaths.add(socketPath) + throw new RuntimeException("connection creation failed") + } + override protected def createSessionForWorker( + worker: DirectWorkerProcess): WorkerSession = + new StubWorkerSession(worker) + } + try { + val ex = intercept[RuntimeException] { + failingDispatcher.createSession(None) + } + assert(ex.getMessage.contains("connection creation failed")) + assert(capturedSocketPaths.size == 1, "createConnection should have been called once") + val socketPath = capturedSocketPaths.peek() + assert(!new File(socketPath).exists(), + s"socket file $socketPath should have been cleaned up") + } finally { + failingDispatcher.close() + } + } + + test("empty ProcessCallable command is rejected with a clear error") { + val emptyRunner = ProcessCallable.newBuilder().build() + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(emptyRunner)) + val ex = intercept[IllegalArgumentException] { + dispatcher.createSession(None) + } + assert(ex.getMessage.contains("at least one entry"), + s"expected explicit empty-command error, got: ${ex.getMessage}") + } + + test("spawnWorker fails when worker process exits immediately") { + val runner = ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand("echo 'fatal: bad config' >&2; exit 42").addCommand("--") + .build() + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(runner)) + + val ex = intercept[RuntimeException] { + dispatcher.createSession(None) + } + assert(ex.getMessage.contains("exited with code 42"), + s"expected early-exit error, got: ${ex.getMessage}") + 
assert(ex.getMessage.contains("fatal: bad config"), + s"expected process output in error, got: ${ex.getMessage}") + } + + test("spawnWorker times out when worker stays alive but never creates socket") { + // Distinct from the "process exits immediately" case: here the worker + // process is healthy but simply doesn't bind the socket, so the + // dispatcher must time out and SIGKILL-reap it rather than wait forever. + val hangingRunner = ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand("while true; do sleep 1; done").addCommand("--") + .build() + val shortInitProps = UDFWorkerProperties.newBuilder() + .setConnection(WorkerConnectionSpec.newBuilder() + .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) + .setInitializationTimeoutMs(500) + .build() + val spec = UDFWorkerSpecification.newBuilder() + .setDirect(DirectWorker.newBuilder() + .setRunner(hangingRunner).setProperties(shortInitProps).build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(spec) + + val ex = intercept[DirectWorkerTimeoutException] { + dispatcher.createSession(None) + } + assert(ex.getMessage.contains("did not create socket"), + s"expected init-timeout error, got: ${ex.getMessage}") + assert(ex.getMessage.contains("500ms"), + s"expected timeout value in error, got: ${ex.getMessage}") + } + + // -- Environment lifecycle tests ------------------------------------------- + + test("skips installation when verification succeeds") { + val markerFile = Files.createTempFile("env-install-marker", ".txt").toFile + markerFile.delete() + + val env = WorkerEnvironment.newBuilder() + .setEnvironmentVerification(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c").addCommand("exit 0").build()) + .setInstallation(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand(s"touch ${markerFile.getAbsolutePath}").build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(specWithEnv(env = env)) + + val session = 
dispatcher.createSession(None) + session.close() + + assert(!markerFile.exists(), + "installation should not run when verification succeeds") + } + + test("runs installation when verification fails") { + val markerFile = Files.createTempFile("env-install-marker", ".txt").toFile + markerFile.delete() + + val env = WorkerEnvironment.newBuilder() + .setEnvironmentVerification(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c").addCommand("exit 1").build()) + .setInstallation(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand(s"touch ${markerFile.getAbsolutePath}").build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(specWithEnv(env = env)) + + val session = dispatcher.createSession(None) + session.close() + + assert(markerFile.exists(), + "installation should run when verification fails") + markerFile.delete() + } + + test("runs installation when no verification callable is provided") { + val markerFile = Files.createTempFile("env-install-marker", ".txt").toFile + markerFile.delete() + + val env = WorkerEnvironment.newBuilder() + .setInstallation(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand(s"touch ${markerFile.getAbsolutePath}").build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(specWithEnv(env = env)) + + val session = dispatcher.createSession(None) + session.close() + + assert(markerFile.exists(), + "installation should run when no verification is defined") + markerFile.delete() + } + + test("installation failure throws with process output and prevents worker creation") { + val env = WorkerEnvironment.newBuilder() + .setInstallation(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand("echo 'missing dependency: libfoo' >&2; exit 7").build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(specWithEnv(env = env)) + + val ex = intercept[RuntimeException] { + dispatcher.createSession(None) + } + 
assert(ex.getMessage.contains("exit code 7"), + s"expected installation failure, got: ${ex.getMessage}") + assert(ex.getMessage.contains("missing dependency: libfoo"), + s"expected process output in error, got: ${ex.getMessage}") + } + + test("installation that exceeds callableTimeoutMs is killed and reported") { + // Installation sleeps longer than callableTimeoutMs; the dispatcher + // must SIGKILL-reap it and surface a "Callable timed out" error + // rather than hang the caller. + val slowInstall = ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand("sleep 30").build() + val env = WorkerEnvironment.newBuilder().setInstallation(slowInstall).build() + val shortTimeoutDispatcher = + new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env)) { + override protected def callableTimeoutMs: Long = 500L + override protected def createConnection( + socketPath: String): UnixSocketWorkerConnection = + new SocketFileConnection(socketPath) + override protected def createSessionForWorker( + worker: DirectWorkerProcess): WorkerSession = + new StubWorkerSession(worker) + } + try { + val ex = intercept[DirectWorkerTimeoutException] { + shortTimeoutDispatcher.createSession(None) + } + assert(ex.getMessage.contains("Callable timed out"), + s"expected callable-timeout error, got: ${ex.getMessage}") + assert(ex.getMessage.contains("500ms"), + s"expected timeout value in error, got: ${ex.getMessage}") + } finally { + shortTimeoutDispatcher.close() + } + } + + test("environment setup runs only once across multiple sessions") { + val counterFile = Files.createTempFile("env-counter", ".txt").toFile + counterFile.delete() + + val env = WorkerEnvironment.newBuilder() + .setInstallation(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand(s"echo invoked >> ${counterFile.getAbsolutePath}").build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(specWithEnv(env = env)) + + val s1 = dispatcher.createSession(None); 
s1.close() + val s2 = dispatcher.createSession(None); s2.close() + + val src = scala.io.Source.fromFile(counterFile) + val lines = try src.getLines().toList finally src.close() + assert(lines.size == 1, + s"installation should run exactly once, but ran ${lines.size} time(s)") + counterFile.delete() + } + + test("concurrent createSession still installs exactly once") { + // The sequential single-install test above cannot catch a missing + // lock around ensureEnvironmentReady. Race many createSession calls + // with an install script that takes long enough for the threads to + // queue on environmentLock, then verify it still ran exactly once. + val counterFile = Files.createTempFile("env-concurrent-install", ".txt").toFile + counterFile.delete() + + val env = WorkerEnvironment.newBuilder() + .setInstallation(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand( + s"sleep 0.2; echo invoked >> ${counterFile.getAbsolutePath}").build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(specWithEnv(env = env)) + + val threads = 4 + val startGate = new java.util.concurrent.CountDownLatch(1) + val doneGate = new java.util.concurrent.CountDownLatch(threads) + val sessions = new java.util.concurrent.ConcurrentLinkedQueue[WorkerSession]() + val errors = new java.util.concurrent.ConcurrentLinkedQueue[Throwable]() + + (1 to threads).foreach { _ => + new Thread(() => { + try { + startGate.await() + sessions.add(dispatcher.createSession(None)) + } catch { + case t: Throwable => errors.add(t) + } finally { + doneGate.countDown() + } + }).start() + } + startGate.countDown() + assert(doneGate.await(30, java.util.concurrent.TimeUnit.SECONDS), + "createSession threads did not finish in time") + assert(errors.isEmpty, + s"unexpected errors during concurrent createSession: ${errors.toArray.mkString(", ")}") + + val src = scala.io.Source.fromFile(counterFile) + val lines = try src.getLines().toList finally src.close() + assert(lines.size == 1, + 
s"installation should run exactly once under concurrent createSession, " + + s"but ran ${lines.size} time(s)") + + sessions.asScala.foreach(_.close()) + counterFile.delete() + } + + test("failed environment setup is not retried on subsequent createSession") { + val counterFile = Files.createTempFile("env-failed-counter", ".txt").toFile + counterFile.delete() + + // Installation script appends a line every time it runs, then always + // fails. The first createSession should run it; the second should be + // rejected immediately without re-running. + val env = WorkerEnvironment.newBuilder() + .setInstallation(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand( + s"echo invoked >> ${counterFile.getAbsolutePath}; exit 1").build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(specWithEnv(env = env)) + + val first = intercept[RuntimeException] { dispatcher.createSession(None) } + assert(first.getMessage.contains("installation failed"), + s"expected first-attempt installation failure, got: ${first.getMessage}") + + val second = intercept[RuntimeException] { dispatcher.createSession(None) } + assert(second.getMessage.contains("previously failed"), + s"expected cached failure on retry, got: ${second.getMessage}") + + val src = scala.io.Source.fromFile(counterFile) + val lines = try src.getLines().toList finally src.close() + assert(lines.size == 1, + s"installation should run only once across failed retries, got ${lines.size}") + counterFile.delete() + } + + test("installation timeout transitions to Failed and is not retried") { + val counterFile = Files.createTempFile("env-timeout-counter", ".txt").toFile + counterFile.delete() + + // Install appends to a counter file, then sleeps past callableTimeoutMs + // so runCallable times out. The dispatcher must mark the env Failed + // and reject the next createSession without re-running install. 
+    val env = WorkerEnvironment.newBuilder()
+      .setInstallation(ProcessCallable.newBuilder()
+        .addCommand("bash").addCommand("-c")
+        .addCommand(
+          s"echo invoked >> ${counterFile.getAbsolutePath}; sleep 30").build())
+      .build()
+    val timeoutDispatcher =
+      new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env)) {
+        override protected def callableTimeoutMs: Long = 500L
+        override protected def createConnection(
+            socketPath: String): UnixSocketWorkerConnection =
+          new SocketFileConnection(socketPath)
+        override protected def createSessionForWorker(
+            worker: DirectWorkerProcess): WorkerSession =
+          new StubWorkerSession(worker)
+      }
+    try {
+      val first = intercept[DirectWorkerTimeoutException] {
+        timeoutDispatcher.createSession(None)
+      }
+      assert(first.getMessage.contains("Callable timed out"),
+        s"expected callable-timeout error, got: ${first.getMessage}")
+
+      val second = intercept[DirectWorkerException] {
+        timeoutDispatcher.createSession(None)
+      }
+      assert(second.getMessage.contains("previously failed"),
+        s"expected cached failure on retry, got: ${second.getMessage}")
+
+      val src = scala.io.Source.fromFile(counterFile)
+      val lines = try src.getLines().toList finally src.close()
+      assert(lines.size == 1,
+        s"installation should run only once across timed-out retries, got ${lines.size}")
+    } finally {
+      timeoutDispatcher.close()
+      counterFile.delete()
+    }
+  }
+
+  test("non-None securityScope is rejected until pooling lands") {
+    dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner))
+    // Stub scope: all instances compare equal, consistent with the constant
+    // hashCode. (A singleton-type test like isInstanceOf[this.type] would
+    // reduce to reference equality, making the override a no-op.)
+    val scope = new WorkerSecurityScope {
+      override def equals(obj: Any): Boolean = obj.isInstanceOf[WorkerSecurityScope]
+      override def hashCode(): Int = 0
+    }
+    val ex = intercept[IllegalArgumentException] {
+      dispatcher.createSession(Some(scope))
+    }
+    assert(ex.getMessage.contains("not supported yet"),
+      s"expected unsupported-scope error, got: ${ex.getMessage}")
+  }
+
+  test("verification without installation is rejected") {
+    val env = WorkerEnvironment.newBuilder()
+      .setEnvironmentVerification(ProcessCallable.newBuilder()
+        .addCommand("bash").addCommand("-c").addCommand("exit 0").build())
+      .build()
+    val ex = intercept[IllegalArgumentException] {
+      new TestDirectWorkerDispatcher(specWithEnv(env = env))
+    }
+    assert(ex.getMessage.contains("installation"),
+      s"expected installation-required error, got: ${ex.getMessage}")
+  }
+
+  test("cleanup runs on dispatcher close") {
+    val cleanupMarker = Files.createTempFile("env-cleanup-marker", ".txt").toFile
+    cleanupMarker.delete()
+
+    val env = WorkerEnvironment.newBuilder()
+      .setEnvironmentCleanup(ProcessCallable.newBuilder()
+        .addCommand("bash").addCommand("-c")
+        .addCommand(s"touch ${cleanupMarker.getAbsolutePath}").build())
+      .build()
+    dispatcher = new TestDirectWorkerDispatcher(specWithEnv(env = env))
+
+    val session = dispatcher.createSession(None)
+    session.close()
+
+    assert(!cleanupMarker.exists(),
+      "cleanup should not run until dispatcher is closed")
+
+    dispatcher.close()
+    dispatcher = null
+
+    assert(cleanupMarker.exists(),
+      "cleanup should run when dispatcher is closed")
+    cleanupMarker.delete()
+  }
+}
diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/WorkerAbstractionSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/WorkerAbstractionSuite.scala
deleted file mode 100644
index 42f53af07424a..0000000000000
--- a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/WorkerAbstractionSuite.scala
+++ /dev/null
@@ -1,25 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.udf.worker.core
-
-import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
-
-class WorkerAbstractionSuite
-  extends AnyFunSuite { // scalastyle:ignore funsuite
-
-  test("dummy") {}
-}
diff --git a/udf/worker/proto/src/main/protobuf/common.proto b/udf/worker/proto/src/main/protobuf/common.proto
index 9c50cdd7a7e4b..ee032def73efe 100644
--- a/udf/worker/proto/src/main/protobuf/common.proto
+++ b/udf/worker/proto/src/main/protobuf/common.proto
@@ -32,6 +32,13 @@ enum UDFWorkerDataFormat {
 }
 
 // The UDF execution type/shape.
+//
+// BIDIRECTIONAL_STREAMING is the only pattern supported by the engine for
+// now. It may be possible to express all UDF types (scalar, mapPartitions,
+// and eventually UDAF/UDTF/streaming) on top of this single pattern by
+// framing their phases as messages on the stream, but that is a design
+// question worth revisiting as additional UDF types are added -- for
+// example, aggregation may prefer a multi-round or specialized pattern.
 enum UDFProtoCommunicationPattern {
   UDF_PROTO_COMMUNICATION_PATTERN_UNSPECIFIED = 0;
diff --git a/udf/worker/proto/src/main/protobuf/worker_spec.proto b/udf/worker/proto/src/main/protobuf/worker_spec.proto
index f2eacf2b3ce35..83dac4f962e5f 100644
--- a/udf/worker/proto/src/main/protobuf/worker_spec.proto
+++ b/udf/worker/proto/src/main/protobuf/worker_spec.proto
@@ -140,13 +140,17 @@
   // Whether multiple, concurrent UDF
   // connections are supported by this worker
   // (for example via multi-threading).
-  // 
+  //
   // In the first implementation of the engine-side
   // worker specification, this property will not be used.
-  // 
+  //
   // Usage of this property can be enabled in the future if the
   // engine implements more advanced resource management (TBD).
   //
+  // TODO: wire this into planning/scheduling -- SPIP worker-spec §2.4
+  // "Parallelism" describes the intended use (e.g., multiplex tasks onto
+  // a single worker vs. spawn multiple workers per executor).
+  //
   // (Optional)
   optional bool supports_concurrent_udfs = 3;
@@ -190,25 +194,31 @@
   // (Optional)
   optional int32 graceful_termination_timeout_ms = 2;
 
-  // The connection this [[DirectWorker]] supports. Note that a single
-  // connection is sufficient to run multiple UDFs and (gRPC) services.
+  // A [[DirectWorker]] exposes one server-side connection endpoint (a
+  // UDS path or a TCP port) that all sessions on the worker share.
+  // Multi-connection workers (e.g., separate data and control channels)
+  // are not supported in this release.
+  //
+  // On [[DirectWorker]] creation, connection information
+  // is passed to the callable as a string parameter.
+  // The string format depends on the [[WorkerConnectionSpec]]:
   //
-  // On [[DirectWorker]] creation, connection information
-  // is passed to the callable as a string parameter.
-  // The string format depends on the [[WorkerConnection]]:
-  //
   // For example, when using TCP, the callable argument will be:
   //   --connection PORT
   // Here is a concrete example
   //   --connection 8080
-  // 
+  //
   // For the format of each specific transport type, see the comments below.
   //
   // (Required)
-  WorkerConnection connection = 3;
+  WorkerConnectionSpec connection = 3;
 }
 
-message WorkerConnection {
+// Describes one connection (transport endpoint) that a [[DirectWorker]]
+// exposes. This is a configuration message -- the live transport object
+// used by the engine at runtime is the Scala abstraction
+// `org.apache.spark.udf.worker.core.WorkerConnection`.
+message WorkerConnectionSpec {
   // (Required)
   oneof transport {
     UnixDomainSocket unix_domain_socket = 1;
@@ -275,7 +285,7 @@ message ProcessCallable {
   //
   //   --connection
   // The value of the connection argument is a string with
-  // engine-assinged connection parameters. See [[UDFWorkerProperties]]
+  // engine-assigned connection parameters. See [[UDFWorkerProperties]]
   // for details.
   //
   // (Optional)