Skip to content

Commit ad14668

Browse files
committed
support Ray actor AgentLoop and move judger resource config into AgentLoop
1 parent 7aa405e commit ad14668

File tree

11 files changed

+276
-52
lines changed

11 files changed

+276
-52
lines changed

tests/rl/test_agent_loop.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
import unittest
3-
import asyncio
3+
import copy
44
import ray
55
import tempfile
66
import torch
@@ -79,14 +79,14 @@ async def test_gsm8k_agent_loop(self):
7979
judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", judger_type="router")
8080
agent_loop_cfg = SingleTurnAgentLoopConfig(
8181
hf_checkpoint=self.model_path,
82-
sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0)
82+
sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0),
83+
judger_config=judger_config,
8384
)
84-
# 2. 创建 rollout_controller, judger
85+
# 2. 创建 rollout_controller
8586
pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg)
8687
rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg)
87-
gsm8k_judger = judger_config.build()
8888
# 3. 创建 AgentLoop
89-
agent_loop = agent_loop_cfg.build(rollout_controller=rollout_controller, judger=gsm8k_judger)
89+
agent_loop = agent_loop_cfg.build(rollout_controller=rollout_controller)
9090
# 4. 构造输入数据
9191
prompt_repeat_k = 4
9292
rollout_state = FAKE_INPUT_ITEM
@@ -104,6 +104,52 @@ async def test_gsm8k_agent_loop(self):
104104
self.assertGreater(len(single_rollout_state.response_ids), 0)
105105
self.assertEqual(single_rollout_state.reward["score"], 1)
106106

107+
async def test_gsm8k_agent_loop_with_ray_actor_judger(self):
    """Exercise a ``ray.actor`` AgentLoop whose GSM8K judger is also a Ray actor."""
    # 1. Initialize the configs.
    self.init_config()
    rollout_config = RolloutConfig(
        env="test_agent_loop_ray_actor",
        model_path=self.model_path,
        model_name=os.path.basename(self.model_path).lower(),
        tokenizer_path=self.model_path,
        context_length=self.context_length,
        worker_log_dir=self.worker_log_dir,
    )
    judger_config = GSM8KJudgerConfig(
        judger_name="openai/gsm8k",
        judger_type="ray.actor",
        num_cpus_per_actor=1,
    )
    agent_loop_cfg = SingleTurnAgentLoopConfig(
        hf_checkpoint=self.model_path,
        sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0),
        judger_config=judger_config,
        type="ray.actor",
        num_cpus=1,
    )

    # 2. The resource bundles are derived purely from the configs, so check them before touching Ray.
    agent_loop_bundle = agent_loop_cfg._get_agent_loop_cpu_bundle()
    self.assertEqual(agent_loop_bundle["CPU"], 1)
    judger_bundles = agent_loop_cfg._get_judger_cpu_bundles()
    self.assertEqual(judger_bundles, [{"CPU": 1, "memory": 1024**3}])

    # 3. Create the rollout controller and the AgentLoop (the judger is built from the config).
    pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg)
    rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg)
    agent_loop = agent_loop_cfg.build(rollout_controller=rollout_controller)

    # 4. Build the inputs; every sample gets its own deep copy of the fake item.
    prompt_repeat_k = 2
    group_in_rollout_state = [copy.deepcopy(FAKE_INPUT_ITEM) for _ in range(prompt_repeat_k)]
    rollout_state = copy.deepcopy(FAKE_INPUT_ITEM)

    # 5. Run the group request first, then the single-sample request, through the actor handle.
    group_rollout_state = await agent_loop.generate_group.remote(group_in_rollout_state)
    single_rollout_state = await agent_loop.generate_sample.remote(rollout_state)

    # 6. Every returned state must be completed, non-empty and judged correct.
    self.assertEqual(len(group_rollout_state), prompt_repeat_k)
    for state in [*group_rollout_state, single_rollout_state]:
        self.assertEqual(state.status, Status.COMPLETED)
        self.assertGreater(len(state.response_ids), 0)
        self.assertEqual(state.reward["score"], 1)
152+
107153
async def test_gsm8k_agent_loop_manager(self):
108154
# 1. 初始化 config
109155
self.init_config()
@@ -118,7 +164,8 @@ async def test_gsm8k_agent_loop_manager(self):
118164
judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", judger_type="router")
119165
agent_loop_cfg = SingleTurnAgentLoopConfig(
120166
hf_checkpoint=self.model_path,
121-
sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0)
167+
sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0),
168+
judger_config=judger_config,
122169
)
123170
sampler_config = SamplerConfig(
124171
dataloader_cfg=DataloaderConfig(
@@ -146,16 +193,14 @@ async def test_gsm8k_agent_loop_manager(self):
146193
)
147194
],
148195
)
149-
# 2. 创建 rollout_controller, judger
196+
# 2. 创建 rollout_controller
150197
pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg)
151198
rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg)
152-
gsm8k_judger = judger_config.build()
153199
# 3. 创建 AgentLoopManager
154200
replay_buffer_cfg = SyncReplayBufferConfig()
155201
replay_buffer = replay_buffer_cfg.build()
156202
agent_loop_manager = agent_loop_manager_cfg.build(
157203
rollout_controller=rollout_controller,
158-
judger=gsm8k_judger,
159204
tokenizer=self.tokenizer,
160205
replay_buffer=replay_buffer,
161206
)

tests/rl/test_async_rollout.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ def _build_agent_loop_manager(
118118
)
119119
manager = manager_cfg.build(
120120
rollout_controller=rollout_ctl,
121-
judger=None,
122121
tokenizer=tokenizer,
123122
replay_buffer=replay_buffer,
124123
logger=None,

tests/rl/test_rl_colocate_trainer_integration.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def build_trainer_config(self, work_dir, checkpoint_interval=1, checkpoint_maxke
159159
agent_loop_config = SingleTurnAgentLoopConfig(
160160
hf_checkpoint=model_path,
161161
sample_params=training_sample_params,
162+
judger_config=judger_config,
162163
)
163164
produce_strategy_config = SyncProduceStrategyConfig()
164165
agent_loop_manager_cfg = AgentLoopManagerConfig(
@@ -180,6 +181,7 @@ def build_trainer_config(self, work_dir, checkpoint_interval=1, checkpoint_maxke
180181
eval_agent_loop_config = SingleTurnAgentLoopConfig(
181182
hf_checkpoint=model_path,
182183
sample_params=SampleParams(max_tokens=512, top_k=1, temperature=0.0),
184+
judger_config=judger_config,
183185
)
184186
eval_agent_loop_manager_cfg = AgentLoopManagerConfig(
185187
tasks=[
@@ -198,7 +200,6 @@ def build_trainer_config(self, work_dir, checkpoint_interval=1, checkpoint_maxke
198200
resources=resources,
199201
train_worker_cfg=train_worker_cfg,
200202
rollout_config=rollout_config,
201-
judger_config=judger_config,
202203
tokenizer_path=model_path,
203204
replay_buffer_config=SyncReplayBufferConfig(),
204205
agent_loop_manager_cfg=agent_loop_manager_cfg,

xtuner/v1/rl/agent_loop/__init__.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1-
from .agent_loop import AgentLoop, AgentLoopConfig
1+
from .agent_loop import (
2+
AgentLoop,
3+
AgentLoopActor,
4+
AgentLoopConfig,
5+
JudgerConfigSpec,
6+
JudgerLike,
7+
JudgerSpec,
8+
RayAgentLoop,
9+
RayAgentLoopProxy,
10+
)
211
from .agent_loop_manager import (
312
AgentLoopManager,
413
AgentLoopManagerConfig,
@@ -21,7 +30,13 @@
2130
"AgentLoopConfig",
2231
"SingleTurnAgentLoopConfig",
2332
"AgentLoop",
33+
"AgentLoopActor",
34+
"RayAgentLoop",
35+
"RayAgentLoopProxy",
2436
"SingleTurnAgentLoop",
37+
"JudgerLike",
38+
"JudgerSpec",
39+
"JudgerConfigSpec",
2540
"AgentLoopManagerConfig",
2641
"AgentLoopManager",
2742
"TaskSpecConfig",

xtuner/v1/rl/agent_loop/agent_loop.py

Lines changed: 158 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,132 @@
11
import asyncio
2+
import inspect
3+
import os
24
from abc import ABC, abstractmethod
3-
from typing import Callable
5+
from typing import Awaitable, Callable, Literal, TypeAlias, cast
46

5-
from pydantic import BaseModel, ConfigDict
7+
import ray
8+
from pydantic import BaseModel, ConfigDict, Field
9+
from ray.actor import ActorClass, ActorProxy
10+
from ray.util.placement_group import PlacementGroup, placement_group
11+
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
612

713
from xtuner.v1.data_proto import RolloutState, SampleParams
8-
from xtuner.v1.rl.judger import NativeJudger, RouterJudger
14+
from xtuner.v1.rl.judger import Judger, JudgerConfig
915
from xtuner.v1.rl.rollout import RolloutController
1016
from xtuner.v1.rl.utils import create_task
11-
from xtuner.v1.utils import get_logger
17+
from xtuner.v1.utils import get_logger, ray_method
1218
from xtuner.v1.utils.processing_utils import load_processor, load_tokenizer
1319

1420

21+
# Seconds to wait for a placement group to become ready (default 30).
# os.getenv returns a str whenever the variable is set, so coerce it to a number;
# the value is later passed as the ``timeout`` of ``ray.get``.
PG_READY_TIMEOUT = float(os.getenv("XTUNER_PG_READY_TIMEOUT", "30"))
22+
23+
JudgerCallable: TypeAlias = Callable[[RolloutState], RolloutState | Awaitable[RolloutState]]
24+
JudgerLike: TypeAlias = Judger | JudgerCallable
25+
JudgerSpec: TypeAlias = JudgerLike | dict[str, JudgerLike] | None
26+
JudgerConfigLike: TypeAlias = JudgerConfig | JudgerCallable
27+
JudgerConfigSpec: TypeAlias = JudgerConfigLike | dict[str, JudgerConfigLike] | None
28+
29+
1530
class AgentLoopConfig(ABC, BaseModel):
1631
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
1732
hf_checkpoint: str
1833
sample_params: SampleParams
34+
judger_config: JudgerConfigSpec = None
35+
type: Literal["local", "ray.actor"] = "local"
36+
num_cpus: float = Field(default=1, gt=0, description="CPU cores required by the AgentLoop actor itself.")
37+
cpu_memory: int = Field(default=1024**3, gt=0, description="CPU memory in bytes required by AgentLoop.")
38+
39+
def _get_agent_loop_cpu_bundle(self) -> dict[str, float | int]:
40+
return {"CPU": self.num_cpus, "memory": self.cpu_memory}
41+
42+
def _get_judger_cpu_bundles(self) -> list[dict[str, float | int]]:
43+
if self.judger_config is None:
44+
return []
45+
if isinstance(self.judger_config, dict):
46+
judger_configs = [config for config in self.judger_config.values() if isinstance(config, JudgerConfig)]
47+
elif isinstance(self.judger_config, JudgerConfig):
48+
judger_configs = [self.judger_config]
49+
else:
50+
judger_configs = []
51+
52+
bundles: list[dict[str, float | int]] = []
53+
for judger_config in judger_configs:
54+
bundles.extend(judger_config.get_cpu_bundles())
55+
return bundles
56+
57+
def _build_cpu_placement_group(self, strategy: str = "SPREAD") -> PlacementGroup:
    """Create a placement group for the AgentLoop and its judgers and wait until it is ready.

    Bundle 0 is the AgentLoop's own bundle; the judger bundles follow it.

    Args:
        strategy: Placement strategy passed to ``placement_group``.

    Returns:
        The ready placement group.

    Raises:
        AssertionError: If Ray has not been initialized.
        Exception: Whatever ``ray.get`` raises while waiting (for example on
            timeout); the placement group is removed before re-raising.
    """
    assert ray.is_initialized(), "Ray must be initialized before building AgentLoop placement groups."
    # PG_READY_TIMEOUT may come straight from an environment variable, i.e. be a
    # string; convert before any resource is requested so a bad value fails early.
    timeout = float(PG_READY_TIMEOUT)
    bundle_specs = [self._get_agent_loop_cpu_bundle(), *self._get_judger_cpu_bundles()]
    pg = placement_group(bundles=bundle_specs, strategy=strategy)
    try:
        ray.get(pg.ready(), timeout=timeout)
    except Exception:
        # Do not leave a reservation request behind when it never became ready.
        ray.util.remove_placement_group(pg)
        raise
    return pg
63+
64+
def build_judger(self, pg: PlacementGroup | None = None, start_bundle_idx: int = 0) -> JudgerSpec:
    """Turn ``judger_config`` into the judger object(s) used by the AgentLoop.

    Args:
        pg: Placement group forwarded to each ``JudgerConfig.build`` call.
        start_bundle_idx: Bundle index given to the first ``JudgerConfig``.

    Returns:
        ``None`` when no judger is configured, a single judger, or a dict of
        judgers with the same keys as ``judger_config``. Callables are
        returned as-is.

    Raises:
        ValueError: If an entry is neither a ``JudgerConfig`` nor callable.
    """
    if self.judger_config is None:
        return None

    # Single-entry form.
    if not isinstance(self.judger_config, dict):
        if isinstance(self.judger_config, JudgerConfig):
            return self.judger_config.build(pg=pg, start_bundle_idx=start_bundle_idx)
        if callable(self.judger_config):
            return self.judger_config
        raise ValueError(f"Invalid judger config type: {type(self.judger_config)}")

    # Dict form: each JudgerConfig starts at the bundle index after the previous one's bundles.
    built = {}
    next_bundle = start_bundle_idx
    for key, config in self.judger_config.items():
        if isinstance(config, JudgerConfig):
            built[key] = config.build(pg=pg, start_bundle_idx=next_bundle)
            next_bundle += config.get_num_placement_group_bundles()
            continue
        if not callable(config):
            raise ValueError(f"Invalid judger config type: {type(config)} for key {key}")
        built[key] = config
    return built
86+
87+
def build(self, rollout_controller, logger=None) -> "AgentLoop | RayAgentLoopProxy":
    """Build the AgentLoop described by this config.

    Dispatches on ``self.type``: ``"local"`` builds in the current process via
    ``build_local``; ``"ray.actor"`` first creates a CPU placement group and
    then builds a Ray actor in it.

    Raises:
        ValueError: If ``self.type`` is neither of the supported values.
    """
    kind = self.type

    if kind == "ray.actor":
        pg = self._build_cpu_placement_group()
        return self._build_ray_actor(
            rollout_controller=rollout_controller,
            pg=pg,
            logger=logger,
        )

    if kind == "local":
        return self.build_local(
            rollout_controller=rollout_controller,
            logger=logger,
        )

    raise ValueError(f"Invalid agent loop type: {self.type}")
19101

20102
@abstractmethod
21-
def build(self, rollout_controller, judger=None, logger=None) -> "AgentLoop": ...
103+
def build_local(
104+
self,
105+
rollout_controller,
106+
logger=None,
107+
pg: PlacementGroup | None = None,
108+
start_bundle_idx: int = 0,
109+
) -> "AgentLoop": ...
110+
111+
def _build_ray_actor(
    self,
    rollout_controller: RolloutController,
    pg: PlacementGroup,
    logger=None,
) -> "RayAgentLoopProxy":
    """Start a ``RayAgentLoop`` actor on bundle 0 of ``pg`` and return its handle.

    Args:
        rollout_controller: Passed through to the actor constructor.
        pg: Placement group the actor is scheduled in.
        logger: Passed through to the actor constructor.
    """
    # The actor is scheduled on bundle index 0; child tasks are captured into the same group.
    strategy = PlacementGroupSchedulingStrategy(
        placement_group=pg,
        placement_group_bundle_index=0,
        placement_group_capture_child_tasks=True,
    )
    actor_cls = RayAgentLoop.options(num_cpus=self.num_cpus, scheduling_strategy=strategy)
    # The config itself is sent to the actor, which builds the AgentLoop on its side.
    return actor_cls.remote(self, rollout_controller, logger)
22130

23131

24132
class AgentLoop(ABC):
@@ -27,7 +135,7 @@ def __init__(
27135
rollout_ctl: RolloutController,
28136
sample_params: SampleParams,
29137
hf_checkpoint: str,
30-
judger: Callable | NativeJudger | RouterJudger | None = None,
138+
judger: JudgerSpec = None,
31139
logger=None,
32140
) -> None:
33141
self.rollout_ctl = rollout_ctl
@@ -57,10 +165,49 @@ async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> l
57165
async def judge_sample(self, rollout_state: RolloutState) -> RolloutState:
58166
if self.judger is None:
59167
return rollout_state
60-
if callable(self.judger):
61-
rollout_state = await self.judger(rollout_state)
62-
elif isinstance(self.judger, RouterJudger) or isinstance(self.judger, NativeJudger):
63-
rollout_state = await self.judger.judge(rollout_state) # type: ignore[operator]
168+
169+
judger = self.judger
170+
if isinstance(judger, dict):
171+
if len(judger) > 1:
172+
raise NotImplementedError("Multiple judgers require a custom AgentLoop.judge_sample implementation.")
173+
judger = next(iter(judger.values()))
174+
175+
if isinstance(judger, Judger):
176+
rollout_state = await judger.judge(rollout_state)
177+
elif isinstance(judger, ray.actor.ActorHandle):
178+
rollout_state = await judger.judge.remote(rollout_state)
179+
elif callable(judger):
180+
judger_result = judger(rollout_state)
181+
if inspect.isawaitable(judger_result):
182+
rollout_state = await judger_result
183+
else:
184+
rollout_state = judger_result
64185
else:
65-
raise ValueError(f"Invalid judger type: {type(self.judger)}")
186+
raise ValueError(f"Invalid judger type: {type(judger)}")
187+
188+
if not isinstance(rollout_state, RolloutState):
189+
raise TypeError(f"Judger must return RolloutState, but got {type(rollout_state)}")
66190
return rollout_state
191+
192+
193+
class AgentLoopActor:
    """Wrapper that builds a local AgentLoop from a config and forwards generation calls to it."""

    def __init__(self, agent_loop_config: AgentLoopConfig, rollout_controller: RolloutController, logger=None):
        # Build inside the placement group this code is currently running in.
        # Building starts at bundle index 1; _build_ray_actor schedules the
        # actor itself on bundle index 0.
        build_kwargs = dict(
            rollout_controller=rollout_controller,
            logger=logger,
            pg=ray.util.get_current_placement_group(),
            start_bundle_idx=1,
        )
        self.agent_loop = agent_loop_config.build_local(**build_kwargs)

    @ray_method
    async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState:
        """Forward a single-sample generation request to the wrapped AgentLoop."""
        result = await self.agent_loop.generate_sample(rollout_state, **kwargs)
        return result

    @ray_method
    async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
        """Forward a group generation request to the wrapped AgentLoop."""
        results = await self.agent_loop.generate_group(rollout_state, **kwargs)
        return results
210+
211+
212+
RayAgentLoop = cast(ActorClass[AgentLoopActor], ray.remote(AgentLoopActor))
213+
RayAgentLoopProxy: TypeAlias = ActorProxy[AgentLoopActor]

0 commit comments

Comments
 (0)