Skip to content

Commit 9408022

Browse files
authored
Allow external storage to run concurrently and separate from codecs (#1394)
1 parent 05971a8 commit 9408022

8 files changed

Lines changed: 124 additions & 54 deletions

File tree

temporalio/bridge/worker.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ async def decode_activation(
303303
activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation,
304304
data_converter: temporalio.converter.DataConverter,
305305
decode_headers: bool,
306-
concurrency_limit: int,
306+
storage_concurrency_limit: int,
307307
) -> temporalio.converter._extstore.StorageOperationMetrics:
308308
"""Decode all payloads in the activation.
309309
@@ -315,27 +315,48 @@ async def decode_activation(
315315
await CommandAwarePayloadVisitor(
316316
skip_search_attributes=True,
317317
skip_headers=not decode_headers,
318-
concurrency_limit=concurrency_limit,
319-
).visit(_Visitor(data_converter._decode_payload_sequence), activation)
318+
concurrency_limit=storage_concurrency_limit,
319+
).visit(
320+
_Visitor(data_converter._external_retrieve_payload_sequence), activation
321+
)
322+
323+
await CommandAwarePayloadVisitor(
324+
skip_search_attributes=True,
325+
skip_headers=not decode_headers,
326+
).visit(_Visitor(data_converter._decode_payload_sequence), activation)
327+
320328
return metrics
321329

322330

323331
async def encode_completion(
324332
completion: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion,
325333
data_converter: temporalio.converter.DataConverter,
326334
encode_headers: bool,
327-
concurrency_limit: int,
335+
storage_concurrency_limit: int,
328336
) -> temporalio.converter._extstore.StorageOperationMetrics:
329337
"""Encode all payloads in the completion.
330338
331339
Returns:
332340
Metrics from any external storage store operations that occurred.
333341
"""
342+
await CommandAwarePayloadVisitor(
343+
skip_search_attributes=True,
344+
skip_headers=not encode_headers,
345+
).visit(_Visitor(data_converter._encode_payload_sequence), completion)
346+
347+
async def _store_and_validate(
348+
payloads: Sequence[Payload],
349+
) -> list[Payload]:
350+
stored = await data_converter._external_store_payload_sequence(payloads)
351+
data_converter._validate_payload_limits(stored)
352+
return stored
353+
334354
metrics = temporalio.converter._extstore.StorageOperationMetrics()
335355
with metrics.track():
336356
await CommandAwarePayloadVisitor(
337357
skip_search_attributes=True,
338358
skip_headers=not encode_headers,
339-
concurrency_limit=concurrency_limit,
340-
).visit(_Visitor(data_converter._encode_payload_sequence), completion)
359+
concurrency_limit=storage_concurrency_limit,
360+
).visit(_Visitor(_store_and_validate), completion)
361+
341362
return metrics

temporalio/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9185,7 +9185,7 @@ async def _apply_headers(
91859185
return
91869186
if encode_headers:
91879187
for payload in source.values():
9188-
payload.CopyFrom(await data_converter._encode_payload(payload))
9188+
payload.CopyFrom(await data_converter._transform_outbound_payload(payload))
91899189
temporalio.common._apply_headers(source, dest)
91909190

91919191

temporalio/converter/_data_converter.py

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ async def encode(
111111
"""
112112
payloads = self.payload_converter.to_payloads(values)
113113
payloads = await self._encode_payload_sequence(payloads)
114+
payloads = await self._external_store_payload_sequence(payloads)
115+
self._validate_payload_limits(payloads)
114116
return payloads
115117

116118
async def decode(
@@ -128,6 +130,7 @@ async def decode(
128130
Returns:
129131
Decoded and converted values.
130132
"""
133+
payloads = await self._external_retrieve_payload_sequence(payloads)
131134
payloads = await self._decode_payload_sequence(payloads)
132135
return self.payload_converter.from_payloads(payloads, type_hints)
133136

@@ -156,13 +159,13 @@ async def encode_failure(
156159
) -> None:
157160
"""Convert and encode failure."""
158161
self.failure_converter.to_failure(exception, self.payload_converter, failure)
159-
await _apply_to_failure_payloads(failure, self._encode_payloads)
162+
await _apply_to_failure_payloads(failure, self._transform_outbound_payloads)
160163

161164
async def decode_failure(
162165
self, failure: temporalio.api.failure.v1.Failure
163166
) -> BaseException:
164167
"""Decode and convert failure."""
165-
await _apply_to_failure_payloads(failure, self._decode_payloads)
168+
await _apply_to_failure_payloads(failure, self._transform_inbound_payloads)
166169
return self.failure_converter.from_failure(failure, self.payload_converter)
167170

168171
def with_context(self, context: SerializationContext) -> Self:
@@ -250,7 +253,7 @@ async def _encode_memo_existing(
250253
"[TMPRL1103] Attempted to upload memo with size that exceeded the warning limit.",
251254
)
252255

253-
async def _encode_payload(
256+
async def _transform_outbound_payload(
254257
self, payload: temporalio.api.common.v1.Payload
255258
) -> temporalio.api.common.v1.Payload:
256259
if self.payload_codec:
@@ -260,27 +263,16 @@ async def _encode_payload(
260263
self._validate_payload_limits([payload])
261264
return payload
262265

263-
async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads):
266+
async def _transform_outbound_payloads(
267+
self, payloads: temporalio.api.common.v1.Payloads
268+
):
264269
if self.payload_codec:
265270
await self.payload_codec.encode_wrapper(payloads)
266271
if self.external_storage:
267272
await self.external_storage._store_payloads(payloads)
268273
self._validate_payload_limits(payloads.payloads)
269274

270-
async def _encode_payload_sequence(
271-
self, payloads: Sequence[temporalio.api.common.v1.Payload]
272-
) -> list[temporalio.api.common.v1.Payload]:
273-
encoded_payloads = list(payloads)
274-
if self.payload_codec:
275-
encoded_payloads = await self.payload_codec.encode(encoded_payloads)
276-
if self.external_storage:
277-
encoded_payloads = await self.external_storage._store_payload_sequence(
278-
encoded_payloads
279-
)
280-
self._validate_payload_limits(encoded_payloads)
281-
return encoded_payloads
282-
283-
async def _decode_payload(
275+
async def _transform_inbound_payload(
284276
self, payload: temporalio.api.common.v1.Payload
285277
) -> temporalio.api.common.v1.Payload:
286278
if self.external_storage:
@@ -289,7 +281,9 @@ async def _decode_payload(
289281
payload = (await self.payload_codec.decode([payload]))[0]
290282
return payload
291283

292-
async def _decode_payloads(self, payloads: temporalio.api.common.v1.Payloads):
284+
async def _transform_inbound_payloads(
285+
self, payloads: temporalio.api.common.v1.Payloads
286+
):
293287
if self.external_storage:
294288
await self.external_storage._retrieve_payloads(payloads)
295289
else:
@@ -304,23 +298,51 @@ async def _decode_payloads(self, payloads: temporalio.api.common.v1.Payloads):
304298
if self.payload_codec:
305299
await self.payload_codec.decode_wrapper(payloads)
306300

307-
async def _decode_payload_sequence(
301+
async def _encode_payload_sequence(
308302
self, payloads: Sequence[temporalio.api.common.v1.Payload]
309303
) -> list[temporalio.api.common.v1.Payload]:
310-
decoded_payloads = list(payloads)
304+
"""Codec encode only."""
305+
encoded_payloads = list(payloads)
306+
if self.payload_codec:
307+
encoded_payloads = await self.payload_codec.encode(encoded_payloads)
308+
return encoded_payloads
309+
310+
async def _external_store_payload_sequence(
311+
self, payloads: Sequence[temporalio.api.common.v1.Payload]
312+
) -> list[temporalio.api.common.v1.Payload]:
313+
"""External storage store, then validate payload limits."""
314+
stored_payloads = list(payloads)
315+
if self.external_storage:
316+
stored_payloads = await self.external_storage._store_payload_sequence(
317+
stored_payloads
318+
)
319+
return stored_payloads
320+
321+
async def _external_retrieve_payload_sequence(
322+
self, payloads: Sequence[temporalio.api.common.v1.Payload]
323+
) -> list[temporalio.api.common.v1.Payload]:
324+
"""External storage retrieve only."""
325+
retrieved_payloads = list(payloads)
311326
if self.external_storage:
312-
decoded_payloads = await self.external_storage._retrieve_payload_sequence(
313-
decoded_payloads
327+
retrieved_payloads = await self.external_storage._retrieve_payload_sequence(
328+
retrieved_payloads
314329
)
315330
else:
316331
if any(
317332
p.metadata.get("encoding") == _REFERENCE_ENCODING
318-
for p in decoded_payloads
333+
for p in retrieved_payloads
319334
):
320335
warnings.warn(
321336
"[TMPRL1105] Detected externally stored payload(s) but external storage is not configured.",
322337
StorageWarning,
323338
)
339+
return retrieved_payloads
340+
341+
async def _decode_payload_sequence(
342+
self, payloads: Sequence[temporalio.api.common.v1.Payload]
343+
) -> list[temporalio.api.common.v1.Payload]:
344+
"""Codec decode only."""
345+
decoded_payloads = list(payloads)
324346
if self.payload_codec:
325347
decoded_payloads = await self.payload_codec.decode(decoded_payloads)
326348
return decoded_payloads

temporalio/worker/_activity.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,9 @@ async def _execute_activity(
631631

632632
if self._encode_headers:
633633
for payload in start.header_fields.values():
634-
payload.CopyFrom(await data_converter._decode_payload(payload))
634+
payload.CopyFrom(
635+
await data_converter._transform_inbound_payload(payload)
636+
)
635637

636638
running_activity.info = info
637639
input = ExecuteActivityInput(

temporalio/worker/_replayer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def on_eviction_hook(
268268
"header_codec_behavior", HeaderCodecBehavior.NO_CODEC
269269
)
270270
!= HeaderCodecBehavior.NO_CODEC,
271-
max_workflow_task_payload_concurrency=1,
271+
max_workflow_task_external_storage_concurrency=1,
272272
)
273273
external_storage = data_converter.external_storage
274274
storage_driver_types = (

temporalio/worker/_worker.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@
3636
from ._nexus import _NexusWorker
3737
from ._plugin import Plugin
3838
from ._tuning import WorkerTuner
39-
from ._workflow import _DEFAULT_WORKFLOW_TASK_PAYLOAD_CONCURRENCY, _WorkflowWorker
39+
from ._workflow import (
40+
_DEFAULT_WORKFLOW_TASK_EXTERNAL_STORAGE_CONCURRENCY,
41+
_WorkflowWorker,
42+
)
4043
from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner
4144
from .workflow_sandbox import SandboxedWorkflowRunner
4245

@@ -142,7 +145,7 @@ def __init__(
142145
maximum=5
143146
),
144147
disable_payload_error_limit: bool = False,
145-
max_workflow_task_payload_concurrency: int = _DEFAULT_WORKFLOW_TASK_PAYLOAD_CONCURRENCY,
148+
max_workflow_task_external_storage_concurrency: int = _DEFAULT_WORKFLOW_TASK_EXTERNAL_STORAGE_CONCURRENCY,
146149
) -> None:
147150
"""Create a worker to process workflows and/or activities.
148151
@@ -317,10 +320,10 @@ def __init__(
317320
and cause a task failure if the size limit is exceeded. The default is False.
318321
See https://docs.temporal.io/troubleshooting/blob-size-limit-error for more
319322
details.
320-
max_workflow_task_payload_concurrency: Maximum number of payload
321-
operations (codec encode/decode, external storage I/O, etc.)
322-
that may run concurrently within a single workflow task
323-
activation. Defaults to 1. WARNING: This setting is experimental.
323+
max_workflow_task_external_storage_concurrency: Maximum number of
324+
external storage I/O operations (store/retrieve) that may run
325+
concurrently within a single workflow task activation.
326+
Defaults to 10. WARNING: This setting is experimental.
324327
325328
"""
326329
config = WorkerConfig(
@@ -366,7 +369,7 @@ def __init__(
366369
activity_task_poller_behavior=activity_task_poller_behavior,
367370
nexus_task_poller_behavior=nexus_task_poller_behavior,
368371
disable_payload_error_limit=disable_payload_error_limit,
369-
max_workflow_task_payload_concurrency=max_workflow_task_payload_concurrency,
372+
max_workflow_task_external_storage_concurrency=max_workflow_task_external_storage_concurrency,
370373
)
371374

372375
plugins_from_client = cast(
@@ -420,12 +423,14 @@ def _init_from_config(self, client: temporalio.client.Client, config: WorkerConf
420423
raise ValueError(
421424
"default_versioning_behavior must be UNSPECIFIED when use_worker_versioning is False"
422425
)
423-
max_workflow_task_payload_concurrency = config.get(
424-
"max_workflow_task_payload_concurrency",
425-
_DEFAULT_WORKFLOW_TASK_PAYLOAD_CONCURRENCY,
426+
max_workflow_task_external_storage_concurrency = config.get(
427+
"max_workflow_task_external_storage_concurrency",
428+
_DEFAULT_WORKFLOW_TASK_EXTERNAL_STORAGE_CONCURRENCY,
426429
)
427-
if max_workflow_task_payload_concurrency < 1:
428-
raise ValueError("max_workflow_task_payload_concurrency must be positive")
430+
if max_workflow_task_external_storage_concurrency < 1:
431+
raise ValueError(
432+
"max_workflow_task_external_storage_concurrency must be positive"
433+
)
429434

430435
# Prepend applicable client interceptors to the given ones
431436
client_config = config["client"].config(active_config=True) # type: ignore[reportTypedDictNotRequiredAccess]
@@ -530,7 +535,7 @@ def check_activity(activity: str):
530535
assert_local_activity_valid=check_activity,
531536
encode_headers=client_config["header_codec_behavior"]
532537
!= HeaderCodecBehavior.NO_CODEC,
533-
max_workflow_task_payload_concurrency=max_workflow_task_payload_concurrency,
538+
max_workflow_task_external_storage_concurrency=max_workflow_task_external_storage_concurrency,
534539
)
535540

536541
tuner = config.get("tuner")
@@ -977,7 +982,7 @@ class WorkerConfig(TypedDict, total=False):
977982
activity_task_poller_behavior: PollerBehavior
978983
nexus_task_poller_behavior: PollerBehavior
979984
disable_payload_error_limit: bool
980-
max_workflow_task_payload_concurrency: int
985+
max_workflow_task_external_storage_concurrency: int
981986

982987

983988
def _warn_if_activity_executor_max_workers_is_inconsistent(

temporalio/worker/_workflow.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
# Set to true to log all activations and completions
4848
LOG_PROTOS = False
4949

50-
_DEFAULT_WORKFLOW_TASK_PAYLOAD_CONCURRENCY: int = 1
50+
_DEFAULT_WORKFLOW_TASK_EXTERNAL_STORAGE_CONCURRENCY: int = 10
5151

5252

5353
class _WorkflowWorker: # type:ignore[reportUnusedClass]
@@ -76,7 +76,7 @@ def __init__(
7676
should_enforce_versioning_behavior: bool,
7777
assert_local_activity_valid: Callable[[str], None],
7878
encode_headers: bool,
79-
max_workflow_task_payload_concurrency: int,
79+
max_workflow_task_external_storage_concurrency: int,
8080
) -> None:
8181
self._bridge_worker = bridge_worker
8282
self._namespace = namespace
@@ -115,8 +115,8 @@ def __init__(
115115
self._on_eviction_hook = on_eviction_hook
116116
self._disable_safe_eviction = disable_safe_eviction
117117
self._encode_headers = encode_headers
118-
self._max_workflow_task_payload_concurrency = (
119-
max_workflow_task_payload_concurrency
118+
self._max_workflow_task_external_storage_concurrency = (
119+
max_workflow_task_external_storage_concurrency
120120
)
121121
self._throw_after_activation: Exception | None = None
122122

@@ -300,7 +300,7 @@ async def _handle_activation(
300300
act,
301301
data_converter,
302302
decode_headers=self._encode_headers,
303-
concurrency_limit=self._max_workflow_task_payload_concurrency,
303+
storage_concurrency_limit=self._max_workflow_task_external_storage_concurrency,
304304
)
305305
if not workflow:
306306
assert init_job
@@ -409,7 +409,7 @@ async def _handle_activation(
409409
completion,
410410
data_converter,
411411
encode_headers=self._encode_headers,
412-
concurrency_limit=self._max_workflow_task_payload_concurrency,
412+
storage_concurrency_limit=self._max_workflow_task_external_storage_concurrency,
413413
)
414414
except temporalio.converter._payload_limits._PayloadSizeError as err:
415415
logger.warning(err.message)
@@ -893,11 +893,29 @@ async def _encode_payload_sequence(
893893
) -> list[temporalio.api.common.v1.Payload]:
894894
return await self._get_current_dc()._encode_payload_sequence(payloads)
895895

896+
async def _external_store_payload_sequence(
897+
self, payloads: Sequence[temporalio.api.common.v1.Payload]
898+
) -> list[temporalio.api.common.v1.Payload]:
899+
return await self._get_current_dc()._external_store_payload_sequence(payloads)
900+
901+
async def _external_retrieve_payload_sequence(
902+
self, payloads: Sequence[temporalio.api.common.v1.Payload]
903+
) -> list[temporalio.api.common.v1.Payload]:
904+
return await self._get_current_dc()._external_retrieve_payload_sequence(
905+
payloads
906+
)
907+
896908
async def _decode_payload_sequence(
897909
self, payloads: Sequence[temporalio.api.common.v1.Payload]
898910
) -> list[temporalio.api.common.v1.Payload]:
899911
return await self._get_current_dc()._decode_payload_sequence(payloads)
900912

913+
def _validate_payload_limits(
914+
self,
915+
payloads: Sequence[temporalio.api.common.v1.Payload],
916+
) -> None:
917+
self._get_current_dc()._validate_payload_limits(payloads)
918+
901919

902920
class _InterruptDeadlockError(BaseException):
903921
pass

0 commit comments

Comments
 (0)