diff --git a/src/guidellm/scheduler/strategies.py b/src/guidellm/scheduler/strategies.py index 4b1adf129..ff8e76a4c 100644 --- a/src/guidellm/scheduler/strategies.py +++ b/src/guidellm/scheduler/strategies.py @@ -19,7 +19,8 @@ import math import random from abc import abstractmethod -from multiprocessing import Event, Value, synchronize +from multiprocessing import synchronize +from multiprocessing.context import BaseContext from multiprocessing.sharedctypes import Synchronized from typing import Annotated, ClassVar, Literal, TypeVar @@ -103,7 +104,10 @@ def requests_limit(self) -> PositiveInt | None: return None def init_processes_timings( - self, worker_count: PositiveInt, max_concurrency: PositiveInt + self, + worker_count: PositiveInt, + max_concurrency: PositiveInt, + mp_context: BaseContext, ): """ Initialize shared timing state for multi-process coordination. @@ -117,9 +121,9 @@ def init_processes_timings( self.worker_count = worker_count self.max_concurrency = max_concurrency - self._processes_init_event = Event() - self._processes_request_index = Value("i", 0) - self._processes_start_time = Value("d", -1.0) + self._processes_init_event = mp_context.Event() + self._processes_request_index = mp_context.Value("i", 0) + self._processes_start_time = mp_context.Value("d", -1.0) def init_processes_start(self, start_time: float): """ @@ -593,7 +597,12 @@ def requests_limit(self) -> PositiveInt | None: """ return self.max_concurrency - def init_processes_timings(self, worker_count: int, max_concurrency: int): + def init_processes_timings( + self, + worker_count: PositiveInt, + max_concurrency: PositiveInt, + mp_context: BaseContext, + ): """ Initialize Poisson-specific timing state. @@ -603,10 +612,10 @@ def init_processes_timings(self, worker_count: int, max_concurrency: int): :param worker_count: Number of worker processes to coordinate :param max_concurrency: Maximum number of concurrent requests allowed """ - self._offset = Value("d", -1.0) + self._offset = mp_context.Value("d", -1.0) # Call base implementation last to avoid # setting Event before offset is ready - super().init_processes_timings(worker_count, max_concurrency) + super().init_processes_timings(worker_count, max_concurrency, mp_context) def init_processes_start(self, start_time: float): """ diff --git a/src/guidellm/scheduler/worker_group.py b/src/guidellm/scheduler/worker_group.py index 823617850..13a0074a4 100644 --- a/src/guidellm/scheduler/worker_group.py +++ b/src/guidellm/scheduler/worker_group.py @@ -221,7 +221,9 @@ async def create_processes(self): # Initialize worker processes self.processes = [] self.strategy.init_processes_timings( - worker_count=num_processes, max_concurrency=max_conc + worker_count=num_processes, + max_concurrency=max_conc, + mp_context=self.mp_context, ) for rank in range(num_processes): # Distribute any remainder across the first N ranks diff --git a/tests/unit/scheduler/test_strategies.py b/tests/unit/scheduler/test_strategies.py index bd8d65f42..667720e4a 100644 --- a/tests/unit/scheduler/test_strategies.py +++ b/tests/unit/scheduler/test_strategies.py @@ -2,6 +2,7 @@ import math import time +from multiprocessing import get_context from typing import Literal, TypeVar import pytest @@ -502,7 +503,9 @@ async def test_timing_without_rampup(self): ### WRITTEN BY AI ### """ strategy = AsyncConstantStrategy(rate=10.0, rampup_duration=0.0) - strategy.init_processes_timings(worker_count=1, max_concurrency=100) + strategy.init_processes_timings( + worker_count=1, max_concurrency=100, mp_context=get_context() + ) start_time = 1000.0 strategy.init_processes_start(start_time) @@ -525,7 +528,9 @@ async def test_timing_with_rampup(self): rate = 10.0 rampup_duration = 2.0 strategy = AsyncConstantStrategy(rate=rate, rampup_duration=rampup_duration) - strategy.init_processes_timings(worker_count=1, max_concurrency=100) + strategy.init_processes_timings( + worker_count=1, max_concurrency=100, mp_context=get_context() + ) start_time = 1000.0 strategy.init_processes_start(start_time) @@ -574,7 +579,9 @@ async def test_timing_with_rampup_edge_cases(self): # Test with very short rampup_duration strategy = AsyncConstantStrategy(rate=100.0, rampup_duration=0.01) - strategy.init_processes_timings(worker_count=1, max_concurrency=100) + strategy.init_processes_timings( + worker_count=1, max_concurrency=100, mp_context=get_context() + ) start_time = 2000.0 strategy.init_processes_start(start_time) @@ -584,7 +591,9 @@ async def test_timing_with_rampup_edge_cases(self): # Test with very long rampup_duration strategy2 = AsyncConstantStrategy(rate=1.0, rampup_duration=100.0) - strategy2.init_processes_timings(worker_count=1, max_concurrency=100) + strategy2.init_processes_timings( + worker_count=1, max_concurrency=100, mp_context=get_context() + ) start_time2 = 3000.0 strategy2.init_processes_start(start_time2) @@ -613,7 +622,9 @@ async def test_timing_rampup_transition(self): rate = 10.0 rampup_duration = 2.0 strategy = AsyncConstantStrategy(rate=rate, rampup_duration=rampup_duration) - strategy.init_processes_timings(worker_count=1, max_concurrency=100) + strategy.init_processes_timings( + worker_count=1, max_concurrency=100, mp_context=get_context() + ) start_time = 5000.0 strategy.init_processes_start(start_time)