Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions crates/sdk-core/src/core_tests/workers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1209,3 +1209,116 @@ async fn nexus_start_operation_failure_converts_to_legacy_for_old_server(
worker.shutdown().await;
worker.finalize_shutdown().await;
}

/// Checks that calling `initiate_shutdown` issues the `ShutdownWorker` RPC, which the server
/// needs in order to flush outstanding long polls. Were the RPC deferred, graceful poll
/// shutdown would deadlock: the SDK would sit waiting for polls the server never completes.
#[tokio::test]
async fn graceful_shutdown_sends_shutdown_worker_rpc_during_initiate() {
    use std::sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    };
    use temporalio_common::protos::temporal::api::{
        namespace::v1::{NamespaceInfo, namespace_info::Capabilities},
        workflowservice::v1::DescribeNamespaceResponse,
    };
    use tokio::sync::Notify;

    // Flag flipped inside the mocked shutdown_worker RPC so we can assert it actually ran.
    let rpc_fired = Arc::new(AtomicBool::new(false));
    let rpc_fired_in_mock = rpc_fired.clone();
    // Stand-in for the server: polls stay pending until the ShutdownWorker RPC arrives, at
    // which point they are released (mirrors the server answering held polls with empty
    // responses once it has been told the worker is shutting down).
    let release_polls = Arc::new(Notify::new());
    let release_polls_in_mock = release_polls.clone();

    let mut mock_client = MockWorkerClient::new();
    mock_client.expect_is_mock().returning(|| true);
    mock_client
        .expect_capabilities()
        .returning(|| Some(*DEFAULT_TEST_CAPABILITIES));
    mock_client
        .expect_workers()
        .returning(|| DEFAULT_WORKERS_REGISTRY.clone());
    mock_client
        .expect_sdk_name_and_version()
        .returning(|| ("test-core".to_string(), "0.0.0".to_string()));
    mock_client
        .expect_identity()
        .returning(|| "test-identity".to_string());
    mock_client
        .expect_worker_grouping_key()
        .returning(Uuid::new_v4);
    mock_client
        .expect_worker_instance_key()
        .returning(Uuid::new_v4);
    mock_client
        .expect_set_heartbeat_client_fields()
        .returning(|heartbeat| {
            heartbeat.sdk_name = "test-core".to_string();
            heartbeat.sdk_version = "0.0.0".to_string();
            heartbeat.worker_identity = "test-identity".to_string();
            heartbeat.heartbeat_time = Some(std::time::SystemTime::now().into());
        });
    // Advertise the worker_poll_complete_on_shutdown capability so that validate()
    // turns graceful poll shutdown on.
    mock_client.expect_describe_namespace().returning(move || {
        Ok(DescribeNamespaceResponse {
            namespace_info: Some(NamespaceInfo {
                capabilities: Some(Capabilities {
                    worker_poll_complete_on_shutdown: true,
                    ..Capabilities::default()
                }),
                ..NamespaceInfo::default()
            }),
            ..DescribeNamespaceResponse::default()
        })
    });
    // The mocked ShutdownWorker RPC records that it fired, then unblocks the pending polls.
    mock_client
        .expect_shutdown_worker()
        .returning(move |_, _, _, _| {
            rpc_fired_in_mock.store(true, Ordering::SeqCst);
            release_polls_in_mock.notify_waiters();
            Ok(ShutdownWorkerResponse {})
        });
    mock_client
        .expect_complete_workflow_task()
        .returning(|_| Ok(RespondWorkflowTaskCompletedResponse::default()));

    // Each poll waits on the notifier before yielding, simulating a server that holds the
    // long poll open until it has seen the ShutdownWorker signal.
    let stream = stream::unfold(release_polls.clone(), |notifier| async move {
        notifier.notified().await;
        let resp = PollWorkflowTaskQueueResponse::default().try_into().unwrap();
        Some((Ok(resp), notifier))
    });

    let mock_inputs = MockWorkerInputs::new(stream.boxed());
    let worker = mock_worker(MocksHolder::from_mock_worker(mock_client, mock_inputs));

    // validate() consults describe_namespace and sets graceful_poll_shutdown = true.
    worker.validate().await.unwrap();

    // initiate_shutdown must send the ShutdownWorker RPC, which in turn releases the poll.
    let poll_fut = worker.poll_workflow_activation();
    let (poll_result, _) = tokio::time::timeout(Duration::from_secs(5), async {
        tokio::join!(poll_fut, async { worker.initiate_shutdown() })
    })
    .await
    .expect("Shutdown should complete within 5s -- if it hangs, the ShutdownWorker RPC was not sent during initiate_shutdown");

    assert_matches!(poll_result.unwrap_err(), PollError::ShutDown);
    assert!(
        rpc_fired.load(Ordering::SeqCst),
        "ShutdownWorker RPC must be called during initiate_shutdown"
    );

    worker.finalize_shutdown().await;
}
95 changes: 95 additions & 0 deletions crates/sdk-core/src/pollers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ where
return match state.poller.poll().await {
Some(Ok((task, permit))) => {
if task == Default::default() {
if state.poller_was_shutdown {
// Server sent an empty response after we initiated
// shutdown — this is the graceful shutdown signal.
return None;
}
// We get the default proto in the event that the long poll
// times out.
debug!("Poll {} task timeout", T::task_name());
Expand Down Expand Up @@ -276,3 +281,93 @@ pub(crate) fn new_nexus_task_poller(
)
.into_stream()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        abstractions::tests::fixed_size_permit_dealer, pollers::MockPermittedPollBuffer,
        test_help::mock_poller, worker::ActivitySlotKind,
    };
    use futures_util::{StreamExt, pin_mut};
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    /// Verify that empty responses after shutdown are not treated as poll timeout and retried
    /// indefinitely
    #[tokio::test]
    async fn empty_response_after_shutdown_terminates_stream() {
        // Counts underlying poller invocations so a retry loop is detectable.
        let poll_count = Arc::new(AtomicUsize::new(0));
        let poll_count_clone = poll_count.clone();

        let mut mock_poller = mock_poller();
        // Always yield the default (empty) response. Before shutdown this is indistinguishable
        // from a long-poll timeout; after shutdown it is the server's graceful-completion
        // signal and must terminate the stream.
        mock_poller.expect_poll().returning(move || {
            poll_count_clone.fetch_add(1, Ordering::SeqCst);
            Some(Ok(PollActivityTaskQueueResponse::default()))
        });

        let sem = Arc::new(fixed_size_permit_dealer::<ActivitySlotKind>(10));
        let shutdown_token = CancellationToken::new();

        let stream = new_activity_task_poller(
            Box::new(MockPermittedPollBuffer::new(sem, mock_poller)),
            MetricsContext::no_op(),
            shutdown_token.clone(),
        );
        pin_mut!(stream);

        // Cancel before polling so the first empty response arrives post-shutdown.
        shutdown_token.cancel();

        let result = tokio::time::timeout(std::time::Duration::from_secs(2), stream.next()).await;
        assert!(
            result.is_ok(),
            "Stream should terminate promptly after shutdown, not hang"
        );
        assert!(
            result.unwrap().is_none(),
            "Stream should return None (terminated) on empty response after shutdown"
        );

        // Generous bound: if empty responses were still treated as timeouts and retried,
        // the poller would be hit many more times before the 2s timeout above fired.
        let total = poll_count.load(Ordering::SeqCst);
        assert!(
            total < 5,
            "Expected stream to terminate quickly, but poller was called {total} times"
        );
    }

    /// Verify that, absent shutdown, empty responses are skipped (treated as poll timeouts)
    /// rather than terminating the stream.
    #[tokio::test]
    async fn empty_response_before_shutdown_retries() {
        let mut mock_poller = mock_poller();
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = call_count.clone();
        // First two polls yield empty (timeout-like) responses; the third ends the poller.
        mock_poller.expect_poll().returning(move || {
            let n = call_count_clone.fetch_add(1, Ordering::SeqCst);
            if n < 2 {
                Some(Ok(PollActivityTaskQueueResponse::default()))
            } else {
                None
            }
        });

        let sem = Arc::new(fixed_size_permit_dealer::<ActivitySlotKind>(10));
        let shutdown_token = CancellationToken::new();

        let stream = new_activity_task_poller(
            Box::new(MockPermittedPollBuffer::new(sem, mock_poller)),
            MetricsContext::no_op(),
            shutdown_token,
        );
        pin_mut!(stream);

        // Without shutdown, empty responses should be skipped and the stream terminates
        // only when the poller returns None.
        let result = stream.next().await;
        assert!(
            result.is_none(),
            "Stream should end when poller returns None"
        );
        // Three underlying polls: two empty retries plus the final None.
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }
}
13 changes: 12 additions & 1 deletion crates/sdk-core/src/pollers/poll_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,18 @@ where
let shutdown_clone = shutdown.clone();

let r = if graceful_shutdown.load(Ordering::Relaxed) {
pf(timeout_override).await
// TEMP FIX: Give the server a reasonable window to
// complete the poll after ShutdownWorker. Fall back
// to cancelling the poll if it takes too long, to
// avoid a 60s hang due to a server-side race
// (temporalio/temporal#9545).
let graceful_interruptor = shutdown_clone
.cancelled()
.then(|_| tokio::time::sleep(Duration::from_secs(5)));
tokio::select! {
r = pf(timeout_override) => r,
_ = graceful_interruptor => return,
}
} else {
let poll_interruptor = shutdown.cancelled().then(|_| async move {
if let Some(w) = poll_shutdown_interrupt_wait {
Expand Down
100 changes: 61 additions & 39 deletions crates/sdk-core/src/worker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,8 @@ pub struct Worker {
/// Set during validate() when the namespace has the poller_autoscaling capability,
/// enabling scale-down on poll timeout even without an explicit scaling decision.
poller_autoscaling: Arc<AtomicBool>,
/// Handle for the spawned ShutdownWorker RPC task, awaited during shutdown.
shutdown_rpc_handle: std::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
}

struct AllPermitsTracker {
Expand Down Expand Up @@ -927,6 +929,7 @@ impl Worker {
status: worker_status,
graceful_poll_shutdown,
poller_autoscaling,
shutdown_rpc_handle: std::sync::Mutex::new(None),
})
}

Expand All @@ -944,43 +947,16 @@ impl Worker {
/// [Worker::finalize_shutdown].
pub async fn shutdown(&self) {
self.initiate_shutdown();
{
*self.status.write() = WorkerStatus::ShuttingDown;
}
let heartbeat = self
.client_worker_registrator
.heartbeat_manager
.as_ref()
.map(|hm| hm.heartbeat_callback.clone()());
let sticky_name = self
.workflows
.as_ref()
.and_then(|wf| wf.get_sticky_queue_name())
.unwrap_or_default();
// This is a best effort call and we can still shutdown the worker if it fails
let task_queue_types = self.config.task_types.to_task_queue_types();
match self
.client
.shutdown_worker(
sticky_name,
self.config.task_queue.clone(),
task_queue_types,
heartbeat,
)
.await
{
Err(err)
if !matches!(
err.code(),
tonic::Code::Unimplemented | tonic::Code::Unavailable
) =>
{
warn!(
"shutdown_worker rpc errored during worker shutdown: {:?}",
err
);
}
_ => {}

// Ensure the ShutdownWorker RPC completes before waiting for polls to drain,
// otherwise graceful poll shutdown deadlocks.
let handle = self
.shutdown_rpc_handle
.lock()
.ok()
.and_then(|mut g| g.take());
if let Some(handle) = handle {
let _ = handle.await;
}

// We need to wait for all local activities to finish so no more workflow task heartbeats
Expand Down Expand Up @@ -1375,8 +1351,11 @@ impl Worker {
&self.config
}

/// Initiate shutdown. See [Worker::shutdown], this is just a sync version that starts the
/// process. You can then wait on `shutdown` or [Worker::finalize_shutdown].
/// Initiate shutdown, including sending the `ShutdownWorker` RPC so the server can complete
/// in-flight polls. This must be awaited before waiting for polls to drain, otherwise
/// graceful poll shutdown deadlocks.
///
/// You can then wait on `shutdown` or [Worker::finalize_shutdown].
pub fn initiate_shutdown(&self) {
if !self.shutdown_token.is_cancelled() {
info!(
Expand All @@ -1385,6 +1364,7 @@ impl Worker {
"Initiated shutdown",
);
}
let already_initiated_shutdown = self.shutdown_token.is_cancelled();
self.shutdown_token.cancel();
{
*self.status.write() = WorkerStatus::ShuttingDown;
Expand Down Expand Up @@ -1419,6 +1399,48 @@ impl Worker {
la_mgr.workflows_have_shutdown();
}
}

if already_initiated_shutdown {
return;
}

// Spawn the ShutdownWorker RPC so the server can complete in-flight polls.
// The handle is stored and awaited in shutdown() to ensure completion.
let client = self.client.clone();
let sticky_name = self
.workflows
.as_ref()
.and_then(|wf| wf.get_sticky_queue_name())
.unwrap_or_default();
let task_queue = self.config.task_queue.clone();
let task_queue_types = self.config.task_types.to_task_queue_types();
let heartbeat = self
.client_worker_registrator
.heartbeat_manager
.as_ref()
.map(|hm| hm.heartbeat_callback.clone()());
let handle = tokio::spawn(async move {
match client
.shutdown_worker(sticky_name, task_queue, task_queue_types, heartbeat)
.await
{
Err(err)
if !matches!(
err.code(),
tonic::Code::Unimplemented | tonic::Code::Unavailable
) =>
{
warn!(
"shutdown_worker rpc errored during worker shutdown: {:?}",
err
);
}
_ => {}
}
});
if let Ok(mut guard) = self.shutdown_rpc_handle.lock() {
*guard = Some(handle);
}
}

/// Unique identifier for this worker instance.
Expand Down
Loading
Loading