Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions src/guidellm/scheduler/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import math
import random
from abc import abstractmethod
from multiprocessing import Event, Value, synchronize
from multiprocessing import synchronize
from multiprocessing.context import BaseContext
from multiprocessing.sharedctypes import Synchronized
from typing import Annotated, ClassVar, Literal, TypeVar

Expand Down Expand Up @@ -103,7 +104,10 @@ def requests_limit(self) -> PositiveInt | None:
return None

def init_processes_timings(
self, worker_count: PositiveInt, max_concurrency: PositiveInt
self,
worker_count: PositiveInt,
max_concurrency: PositiveInt,
mp_context: BaseContext,
):
"""
Initialize shared timing state for multi-process coordination.
Expand All @@ -117,9 +121,9 @@ def init_processes_timings(
self.worker_count = worker_count
self.max_concurrency = max_concurrency

self._processes_init_event = Event()
self._processes_request_index = Value("i", 0)
self._processes_start_time = Value("d", -1.0)
self._processes_init_event = mp_context.Event()
self._processes_request_index = mp_context.Value("i", 0)
self._processes_start_time = mp_context.Value("d", -1.0)

def init_processes_start(self, start_time: float):
"""
Expand Down Expand Up @@ -593,7 +597,12 @@ def requests_limit(self) -> PositiveInt | None:
"""
return self.max_concurrency

def init_processes_timings(self, worker_count: int, max_concurrency: int):
def init_processes_timings(
self,
worker_count: PositiveInt,
max_concurrency: PositiveInt,
mp_context: BaseContext,
):
"""
Initialize Poisson-specific timing state.

Expand All @@ -603,10 +612,10 @@ def init_processes_timings(self, worker_count: int, max_concurrency: int):
:param worker_count: Number of worker processes to coordinate
:param max_concurrency: Maximum number of concurrent requests allowed
"""
self._offset = Value("d", -1.0)
self._offset = mp_context.Value("d", -1.0)
# Call base implementation last to avoid
# setting Event before offset is ready
super().init_processes_timings(worker_count, max_concurrency)
super().init_processes_timings(worker_count, max_concurrency, mp_context)

def init_processes_start(self, start_time: float):
"""
Expand Down
4 changes: 3 additions & 1 deletion src/guidellm/scheduler/worker_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,9 @@ async def create_processes(self):
# Initialize worker processes
self.processes = []
self.strategy.init_processes_timings(
worker_count=num_processes, max_concurrency=max_conc
worker_count=num_processes,
max_concurrency=max_conc,
mp_context=self.mp_context,
)
for rank in range(num_processes):
# Distribute any remainder across the first N ranks
Expand Down
21 changes: 16 additions & 5 deletions tests/unit/scheduler/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import math
import time
from multiprocessing import get_context
from typing import Literal, TypeVar

import pytest
Expand Down Expand Up @@ -502,7 +503,9 @@ async def test_timing_without_rampup(self):
### WRITTEN BY AI ###
"""
strategy = AsyncConstantStrategy(rate=10.0, rampup_duration=0.0)
strategy.init_processes_timings(worker_count=1, max_concurrency=100)
strategy.init_processes_timings(
worker_count=1, max_concurrency=100, mp_context=get_context()
)
start_time = 1000.0
strategy.init_processes_start(start_time)

Expand All @@ -525,7 +528,9 @@ async def test_timing_with_rampup(self):
rate = 10.0
rampup_duration = 2.0
strategy = AsyncConstantStrategy(rate=rate, rampup_duration=rampup_duration)
strategy.init_processes_timings(worker_count=1, max_concurrency=100)
strategy.init_processes_timings(
worker_count=1, max_concurrency=100, mp_context=get_context()
)
start_time = 1000.0
strategy.init_processes_start(start_time)

Expand Down Expand Up @@ -574,7 +579,9 @@ async def test_timing_with_rampup_edge_cases(self):

# Test with very short rampup_duration
strategy = AsyncConstantStrategy(rate=100.0, rampup_duration=0.01)
strategy.init_processes_timings(worker_count=1, max_concurrency=100)
strategy.init_processes_timings(
worker_count=1, max_concurrency=100, mp_context=get_context()
)
start_time = 2000.0
strategy.init_processes_start(start_time)

Expand All @@ -584,7 +591,9 @@ async def test_timing_with_rampup_edge_cases(self):

# Test with very long rampup_duration
strategy2 = AsyncConstantStrategy(rate=1.0, rampup_duration=100.0)
strategy2.init_processes_timings(worker_count=1, max_concurrency=100)
strategy2.init_processes_timings(
worker_count=1, max_concurrency=100, mp_context=get_context()
)
start_time2 = 3000.0
strategy2.init_processes_start(start_time2)

Expand Down Expand Up @@ -613,7 +622,9 @@ async def test_timing_rampup_transition(self):
rate = 10.0
rampup_duration = 2.0
strategy = AsyncConstantStrategy(rate=rate, rampup_duration=rampup_duration)
strategy.init_processes_timings(worker_count=1, max_concurrency=100)
strategy.init_processes_timings(
worker_count=1, max_concurrency=100, mp_context=get_context()
)
start_time = 5000.0
strategy.init_processes_start(start_time)

Expand Down
Loading