Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1050,7 +1050,9 @@ const struct TemporalCoreByteArray *temporal_core_worker_record_activity_heartbe
void temporal_core_worker_request_workflow_eviction(struct TemporalCoreWorker *worker,
struct TemporalCoreByteArrayRef run_id);

void temporal_core_worker_initiate_shutdown(struct TemporalCoreWorker *worker);
void temporal_core_worker_initiate_shutdown(struct TemporalCoreWorker *worker,
void *user_data,
TemporalCoreWorkerCallback callback);

void temporal_core_worker_finalize_shutdown(struct TemporalCoreWorker *worker,
void *user_data,
Expand Down
15 changes: 13 additions & 2 deletions crates/sdk-core-c-bridge/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -928,9 +928,20 @@ pub extern "C" fn temporal_core_worker_request_workflow_eviction(
}

/// C entry point that begins shutdown of `worker` and reports completion through `callback`.
///
/// NOTE(review): this span is a rendered diff hunk (13 additions, 2 deletions), so the previous
/// and the new version of the function appear interleaved. The one-line signature and the
/// direct `initiate_shutdown()` call appear to be the two removed lines; the three-parameter
/// signature and the spawned task appear to be the added ones.
///
/// New behavior as shown here: a task is spawned on the runtime's tokio handle, that task
/// awaits `initiate_shutdown()` on the core worker, and then calls `callback` with the
/// caller-supplied `user_data` and a null second argument. The function itself returns right
/// after spawning, so `callback` runs from the spawned task rather than from the calling frame.
///
/// # Panics
///
/// Panics through `unwrap()` when `worker.worker` holds no value (presumably an `Option` that
/// is emptied once the worker has been finalized -- TODO confirm).
#[unsafe(no_mangle)]
pub extern "C" fn temporal_core_worker_initiate_shutdown(worker: *mut Worker) {
pub extern "C" fn temporal_core_worker_initiate_shutdown(
worker: *mut Worker,
user_data: *mut libc::c_void,
callback: WorkerCallback,
) {
// SAFETY: NOTE(review) -- nothing here checks the pointer; this assumes the caller passes a
// non-null pointer to a live `Worker`. Confirm against the C API contract.
let worker = unsafe { &*worker };
worker.worker.as_ref().unwrap().initiate_shutdown();
// Take an owned clone so the spawned task does not borrow from the raw `worker` pointer.
let core_worker = worker.worker.as_ref().unwrap().clone();
// Wrapped presumably so the raw pointer can be moved into the spawned task -- TODO confirm
// what `UserDataHandle` guarantees.
let user_data = UserDataHandle(user_data);
worker.runtime.core.tokio_handle().spawn(async move {
core_worker.initiate_shutdown().await;
// SAFETY: NOTE(review) -- assumes `callback` is a valid function pointer that accepts a null
// second argument and that `user_data` is still valid at this point; verify with callers.
unsafe {
callback(user_data.into(), std::ptr::null());
}
});
}

#[unsafe(no_mangle)]
Expand Down
4 changes: 2 additions & 2 deletions crates/sdk-core/src/core_tests/activity_tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ async fn eviction_completion_during_shutdown_does_not_panic() {
// Cancel the shutdown token and enqueue BumpStream into the local channel.
// The stream will process BumpStream, see shutdown_done()=true (because
// ignore_evicts_on_shutdown skips the eviction activation), and exit.
core.initiate_shutdown();
core.initiate_shutdown().await;

// Complete the eviction. Its WFActCompleteMsg is queued AFTER BumpStream
// (same FIFO channel), so the stream exits before processing it — dropping
Expand Down Expand Up @@ -1082,7 +1082,7 @@ async fn graceful_shutdown(#[values(true, false)] at_max_outstanding: bool) {
let _2 = worker.poll_activity_task().await.unwrap();
let _3 = worker.poll_activity_task().await.unwrap();

worker.initiate_shutdown();
worker.initiate_shutdown().await;
let expected_tts = HashSet::from([vec![1], vec![2], vec![3]]);
let mut seen_tts = HashSet::new();
for _ in 1..=3 {
Expand Down
135 changes: 124 additions & 11 deletions crates/sdk-core/src/core_tests/workers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ async fn test_task_type_combinations_unified(
}
}

worker.initiate_shutdown();
worker.initiate_shutdown().await;
if enable_workflows {
assert_matches!(
worker.poll_workflow_activation().await.unwrap_err(),
Expand Down Expand Up @@ -647,7 +647,7 @@ async fn nexus_request_deadline_missing_header() {
.complete_nexus_task(create_test_nexus_completion(nexus_task.task_token()))
.await
.unwrap();
worker.initiate_shutdown();
worker.initiate_shutdown().await;
assert_matches!(
worker.poll_nexus_task().await.unwrap_err(),
PollError::ShutDown
Expand Down Expand Up @@ -707,7 +707,7 @@ async fn nexus_request_deadline_valid_header() {
.complete_nexus_task(create_test_nexus_completion(nexus_task.task_token()))
.await
.unwrap();
worker.initiate_shutdown();
worker.initiate_shutdown().await;
assert_matches!(
worker.poll_nexus_task().await.unwrap_err(),
PollError::ShutDown
Expand Down Expand Up @@ -753,7 +753,7 @@ async fn nexus_request_deadline_invalid_header() {
.complete_nexus_task(create_test_nexus_completion(nexus_task.task_token()))
.await
.unwrap();
worker.initiate_shutdown();
worker.initiate_shutdown().await;
assert_matches!(
worker.poll_nexus_task().await.unwrap_err(),
PollError::ShutDown
Expand Down Expand Up @@ -814,7 +814,7 @@ async fn nexus_task_completion_with_failure_status() {
.await
.unwrap();

worker.initiate_shutdown();
worker.initiate_shutdown().await;
assert_matches!(
worker.poll_nexus_task().await.unwrap_err(),
PollError::ShutDown
Expand Down Expand Up @@ -876,7 +876,7 @@ async fn nexus_task_completion_with_failure_converts_to_legacy_for_old_server()
.await
.unwrap();

worker.initiate_shutdown();
worker.initiate_shutdown().await;
assert_matches!(
worker.poll_nexus_task().await.unwrap_err(),
PollError::ShutDown
Expand Down Expand Up @@ -926,7 +926,7 @@ async fn nexus_task_completion_with_failure_status_missing_handler_info_fails(
Err(CompleteNexusError::MalformedNexusCompletion { reason }) if reason.contains("NexusHandlerFailureInfo")
);

worker.initiate_shutdown();
worker.initiate_shutdown().await;
assert_matches!(
worker.poll_nexus_task().await.unwrap_err(),
PollError::ShutDown
Expand Down Expand Up @@ -1006,7 +1006,7 @@ async fn nexus_start_operation_failure_with_application_failure_info() {
.await
.unwrap();

worker.initiate_shutdown();
worker.initiate_shutdown().await;
assert_matches!(
worker.poll_nexus_task().await.unwrap_err(),
PollError::ShutDown
Expand Down Expand Up @@ -1066,7 +1066,7 @@ async fn nexus_start_operation_failure_with_canceled_failure_info() {
.await
.unwrap();

worker.initiate_shutdown();
worker.initiate_shutdown().await;
assert_matches!(
worker.poll_nexus_task().await.unwrap_err(),
PollError::ShutDown
Expand Down Expand Up @@ -1117,7 +1117,7 @@ async fn nexus_start_operation_failure_with_invalid_failure_info(
if reason.contains("ApplicationFailureInfo") && reason.contains("CanceledFailureInfo")
);

worker.initiate_shutdown();
worker.initiate_shutdown().await;
assert_matches!(
worker.poll_nexus_task().await.unwrap_err(),
PollError::ShutDown
Expand Down Expand Up @@ -1201,11 +1201,124 @@ async fn nexus_start_operation_failure_converts_to_legacy_for_old_server(
.await
.unwrap();

worker.initiate_shutdown();
worker.initiate_shutdown().await;
assert_matches!(
worker.poll_nexus_task().await.unwrap_err(),
PollError::ShutDown
);
worker.shutdown().await;
worker.finalize_shutdown().await;
}

/// Verifies that `initiate_shutdown` sends the `ShutdownWorker` RPC so that the server can
/// complete in-flight polls. Without this, graceful poll shutdown deadlocks: the SDK waits for
/// polls to drain, but the server was never told to flush them.
#[tokio::test]
async fn graceful_shutdown_sends_shutdown_worker_rpc_during_initiate() {
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering},
};
use temporalio_common::protos::temporal::api::{
namespace::v1::{NamespaceInfo, namespace_info::Capabilities},
workflowservice::v1::DescribeNamespaceResponse,
};
use tokio::sync::Notify;

let shutdown_rpc_called = Arc::new(AtomicBool::new(false));
let shutdown_rpc_called_clone = shutdown_rpc_called.clone();
// When the shutdown_worker RPC fires, it signals polls to complete (simulating server
// behavior where ShutdownWorker causes the server to return empty poll responses).
let poll_releaser = Arc::new(Notify::new());
let poll_releaser_for_rpc = poll_releaser.clone();

let mut mock_client = MockWorkerClient::new();
mock_client
.expect_capabilities()
.returning(|| Some(*DEFAULT_TEST_CAPABILITIES));
mock_client
.expect_workers()
.returning(|| DEFAULT_WORKERS_REGISTRY.clone());
mock_client.expect_is_mock().returning(|| true);
mock_client
.expect_sdk_name_and_version()
.returning(|| ("test-core".to_string(), "0.0.0".to_string()));
mock_client
.expect_identity()
.returning(|| "test-identity".to_string());
mock_client
.expect_worker_grouping_key()
.returning(Uuid::new_v4);
mock_client
.expect_worker_instance_key()
.returning(Uuid::new_v4);
mock_client
.expect_set_heartbeat_client_fields()
.returning(|hb| {
hb.sdk_name = "test-core".to_string();
hb.sdk_version = "0.0.0".to_string();
hb.worker_identity = "test-identity".to_string();
hb.heartbeat_time = Some(std::time::SystemTime::now().into());
});
// Return the worker_poll_complete_on_shutdown capability so validate() enables graceful mode
mock_client.expect_describe_namespace().returning(move || {
Ok(DescribeNamespaceResponse {
namespace_info: Some(NamespaceInfo {
capabilities: Some(Capabilities {
worker_poll_complete_on_shutdown: true,
..Capabilities::default()
}),
..NamespaceInfo::default()
}),
..DescribeNamespaceResponse::default()
})
});
// When shutdown_worker RPC is called, mark it and release polls
mock_client
.expect_shutdown_worker()
.returning(move |_, _, _, _| {
shutdown_rpc_called_clone.store(true, Ordering::SeqCst);
poll_releaser_for_rpc.notify_waiters();
Ok(ShutdownWorkerResponse {})
});
mock_client
.expect_complete_workflow_task()
.returning(|_| Ok(RespondWorkflowTaskCompletedResponse::default()));

// Polls block until shutdown_worker RPC releases them (simulating server holding polls
// open until it receives the ShutdownWorker signal)
let poll_releaser_for_stream = poll_releaser.clone();
let stream = stream::unfold(poll_releaser_for_stream, |releaser| async move {
releaser.notified().await;
Some((
Ok(PollWorkflowTaskQueueResponse::default().try_into().unwrap()),
releaser,
))
});

let mw = MockWorkerInputs::new(stream.boxed());
let worker = mock_worker(MocksHolder::from_mock_worker(mock_client, mw));

// validate() reads describe_namespace and sets graceful_poll_shutdown = true
worker.validate().await.unwrap();

let poll_fut = worker.poll_workflow_activation();
let shutdown_fut = async {
// initiate_shutdown must send the ShutdownWorker RPC, which releases the polls
worker.initiate_shutdown().await;
};

let (poll_result, _) = tokio::time::timeout(Duration::from_secs(5), async {
tokio::join!(poll_fut, shutdown_fut)
})
.await
.expect("Shutdown should complete within 5s -- if it hangs, the ShutdownWorker RPC was not sent during initiate_shutdown");

assert_matches!(poll_result.unwrap_err(), PollError::ShutDown);
assert!(
shutdown_rpc_called.load(Ordering::SeqCst),
"ShutdownWorker RPC must be called during initiate_shutdown"
);

worker.finalize_shutdown().await;
}
8 changes: 4 additions & 4 deletions crates/sdk-core/src/core_tests/workflow_tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@ async fn lots_of_workflows() {
let completed_count = Arc::new(Semaphore::new(0));
let killer = async {
let _ = completed_count.acquire_many(total_wfs).await.unwrap();
worker.initiate_shutdown();
worker.initiate_shutdown().await;
};
let poller = fanout_tasks(5, |_| {
let completed_count = completed_count.clone();
Expand Down Expand Up @@ -2576,7 +2576,7 @@ async fn _do_post_terminal_commands_test(
.await
.unwrap();

core.initiate_shutdown();
core.initiate_shutdown().await;
let act = core.poll_workflow_activation().await;
assert_matches!(act.unwrap_err(), PollError::ShutDown);
core.shutdown().await;
Expand Down Expand Up @@ -2701,7 +2701,7 @@ async fn poller_wont_run_ahead_of_task_slots() {
let ender = async {
time::sleep(Duration::from_millis(300)).await;
// initiate shutdown, then complete open tasks
worker.initiate_shutdown();
worker.initiate_shutdown().await;
for t in tasks {
worker
.complete_workflow_activation(WorkflowActivationCompletion::empty(t.run_id))
Expand Down Expand Up @@ -2909,7 +2909,7 @@ async fn slot_provider_cant_hand_out_more_permits_than_cache_size() {
let ender = async {
time::sleep(Duration::from_millis(300)).await;
// initiate shutdown, then complete open tasks
worker.initiate_shutdown();
worker.initiate_shutdown().await;
for t in tasks {
worker
.complete_workflow_activation(WorkflowActivationCompletion::empty(t.run_id))
Expand Down
Loading
Loading