11import asyncio
2+ import inspect
3+ import os
24from abc import ABC , abstractmethod
3- from typing import Callable
5+ from typing import Awaitable , Callable , Literal , TypeAlias , cast
46
5- from pydantic import BaseModel , ConfigDict
7+ import ray
8+ from pydantic import BaseModel , ConfigDict , Field
9+ from ray .actor import ActorClass , ActorProxy
10+ from ray .util .placement_group import PlacementGroup , placement_group
11+ from ray .util .scheduling_strategies import PlacementGroupSchedulingStrategy
612
713from xtuner .v1 .data_proto import RolloutState , SampleParams
8- from xtuner .v1 .rl .judger import NativeJudger , RouterJudger
14+ from xtuner .v1 .rl .judger import Judger , JudgerConfig
915from xtuner .v1 .rl .rollout import RolloutController
1016from xtuner .v1 .rl .utils import create_task
11- from xtuner .v1 .utils import get_logger
17+ from xtuner .v1 .utils import get_logger , ray_method
1218from xtuner .v1 .utils .processing_utils import load_processor , load_tokenizer
1319
1420
# Seconds to wait for a placement group to become ready (default 30 seconds).
# os.getenv returns a *string* whenever the variable is set, so the value is
# coerced to float to stay usable as a numeric timeout.
PG_READY_TIMEOUT = float(os.getenv("XTUNER_PG_READY_TIMEOUT", "30"))
22+
# A plain function judger: takes a RolloutState and returns one, either
# directly or through an awaitable.
JudgerCallable: TypeAlias = Callable[[RolloutState], RolloutState | Awaitable[RolloutState]]

# A built judger: either a Judger instance or a plain callable.
JudgerLike: TypeAlias = Judger | JudgerCallable

# What an agent loop holds: a single judger, judgers keyed by name, or none.
JudgerSpec: TypeAlias = JudgerLike | dict[str, JudgerLike] | None

# Config-side counterparts of the aliases above: a JudgerConfig still has to
# be built, whereas a callable is used as-is.
JudgerConfigLike: TypeAlias = JudgerConfig | JudgerCallable
JudgerConfigSpec: TypeAlias = JudgerConfigLike | dict[str, JudgerConfigLike] | None
28+
29+
class AgentLoopConfig(ABC, BaseModel):
    """Base configuration for building an agent loop.

    Depending on ``type`` the loop is built in the current process
    (``"local"``) or as a Ray actor scheduled on a dedicated CPU placement
    group (``"ray.actor"``). Subclasses implement :meth:`build_local`.
    """

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
    hf_checkpoint: str
    sample_params: SampleParams
    judger_config: JudgerConfigSpec = None
    type: Literal["local", "ray.actor"] = "local"
    num_cpus: float = Field(default=1, gt=0, description="CPU cores required by the AgentLoop actor itself.")
    cpu_memory: int = Field(default=1024**3, gt=0, description="CPU memory in bytes required by AgentLoop.")

    def _get_agent_loop_cpu_bundle(self) -> dict[str, float | int]:
        """Return the resource bundle reserved for the agent loop itself."""
        return {"CPU": self.num_cpus, "memory": self.cpu_memory}

    def _get_judger_cpu_bundles(self) -> list[dict[str, float | int]]:
        """Collect the CPU bundles requested by every ``JudgerConfig``.

        Plain callables (and a missing judger config) contribute no bundles.
        """
        if isinstance(self.judger_config, dict):
            candidates = list(self.judger_config.values())
        else:
            candidates = [self.judger_config]

        bundles: list[dict[str, float | int]] = []
        for config in candidates:
            if isinstance(config, JudgerConfig):
                bundles.extend(config.get_cpu_bundles())
        return bundles

    def _build_cpu_placement_group(self, strategy: str = "SPREAD") -> PlacementGroup:
        """Create the placement group for the agent loop and its judgers.

        Bundle 0 is the agent loop bundle; judger bundles follow in the order
        returned by :meth:`_get_judger_cpu_bundles`.

        Args:
            strategy: Placement strategy forwarded to ``placement_group``.

        Returns:
            The placement group, once it has been reported ready.

        Raises:
            AssertionError: If Ray has not been initialized.
            Exception: Whatever ``ray.get`` raises while waiting for readiness
                (e.g. on timeout); the placement group is removed first.
        """
        # Raised explicitly (instead of a bare ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if not ray.is_initialized():
            raise AssertionError("Ray must be initialized before building AgentLoop placement groups.")

        bundle_specs = [self._get_agent_loop_cpu_bundle(), *self._get_judger_cpu_bundles()]
        pg = placement_group(bundles=bundle_specs, strategy=strategy)
        try:
            # float(): the timeout may originate from an environment variable,
            # in which case it can arrive as a string.
            ray.get(pg.ready(), timeout=float(PG_READY_TIMEOUT))
        except Exception:
            # Do not leave a pending or half-reserved placement group behind
            # when the readiness wait fails.
            ray.util.remove_placement_group(pg)
            raise
        return pg

    def build_judger(self, pg: PlacementGroup | None = None, start_bundle_idx: int = 0) -> JudgerSpec:
        """Build the judger(s) described by ``judger_config``.

        Args:
            pg: Placement group forwarded to each ``JudgerConfig.build``.
            start_bundle_idx: Index of the first bundle available to judgers.

        Returns:
            ``None`` when no judger is configured, a single built judger or
            callable, or a dict with the same keys as ``judger_config``.

        Raises:
            ValueError: If an entry is neither a ``JudgerConfig`` nor callable.
        """
        judger_config = self.judger_config
        if judger_config is None:
            return None

        if isinstance(judger_config, dict):
            judger_dict = {}
            bundle_idx = start_bundle_idx
            for key, config in judger_config.items():
                if isinstance(config, JudgerConfig):
                    judger_dict[key] = config.build(pg=pg, start_bundle_idx=bundle_idx)
                    # Advance past the bundles this judger reports as its own.
                    bundle_idx += config.get_num_placement_group_bundles()
                elif callable(config):
                    judger_dict[key] = config
                else:
                    raise ValueError(f"Invalid judger config type: {type(config)} for key {key}")
            return judger_dict

        if isinstance(judger_config, JudgerConfig):
            return judger_config.build(pg=pg, start_bundle_idx=start_bundle_idx)
        if callable(judger_config):
            return judger_config
        raise ValueError(f"Invalid judger config type: {type(judger_config)}")

    def build(self, rollout_controller, logger=None) -> "AgentLoop | RayAgentLoopProxy":
        """Build the agent loop according to ``type``.

        ``"local"`` delegates to :meth:`build_local`; ``"ray.actor"`` first
        creates the CPU placement group and then starts the Ray actor on it.

        Raises:
            ValueError: If ``type`` is not a supported value.
        """
        if self.type == "local":
            return self.build_local(
                rollout_controller=rollout_controller,
                logger=logger,
            )
        if self.type == "ray.actor":
            pg = self._build_cpu_placement_group()
            return self._build_ray_actor(
                rollout_controller=rollout_controller,
                pg=pg,
                logger=logger,
            )
        raise ValueError(f"Invalid agent loop type: {self.type}")

    @abstractmethod
    def build_local(
        self,
        rollout_controller,
        logger=None,
        pg: PlacementGroup | None = None,
        start_bundle_idx: int = 0,
    ) -> "AgentLoop":
        """Build the agent loop in the current process.

        Args:
            rollout_controller: Rollout controller handed to the agent loop.
            logger: Optional logger.
            pg: Placement group for judgers, if any.
            start_bundle_idx: First bundle index available to judgers.
        """
        ...

    def _build_ray_actor(
        self,
        rollout_controller: RolloutController,
        pg: PlacementGroup,
        logger=None,
    ) -> "RayAgentLoopProxy":
        """Start the agent loop as a Ray actor pinned to bundle 0 of ``pg``."""
        scheduling_strategy = PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_bundle_index=0,
            placement_group_capture_child_tasks=True,
        )
        return RayAgentLoop.options(
            num_cpus=self.num_cpus,
            scheduling_strategy=scheduling_strategy,
        ).remote(
            self,
            rollout_controller,
            logger,
        )
22130
23131
24132class AgentLoop (ABC ):
@@ -27,7 +135,7 @@ def __init__(
27135 rollout_ctl : RolloutController ,
28136 sample_params : SampleParams ,
29137 hf_checkpoint : str ,
30- judger : Callable | NativeJudger | RouterJudger | None = None ,
138+ judger : JudgerSpec = None ,
31139 logger = None ,
32140 ) -> None :
33141 self .rollout_ctl = rollout_ctl
@@ -57,10 +165,49 @@ async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> l
async def judge_sample(self, rollout_state: RolloutState) -> RolloutState:
    """Score a rollout with the configured judger.

    Supported judgers: a ``Judger`` instance, a Ray actor handle exposing
    ``judge``, or a plain callable (sync or async). A dict is accepted only
    when it holds exactly one judger.

    Args:
        rollout_state: The rollout to judge.

    Returns:
        The rollout state returned by the judger, or the input unchanged when
        no judger is configured.

    Raises:
        ValueError: If the judger dict is empty or the judger type is unsupported.
        NotImplementedError: If more than one judger is configured.
        TypeError: If the judger does not return a ``RolloutState``.
    """
    if self.judger is None:
        return rollout_state

    judger = self.judger
    if isinstance(judger, dict):
        # Without this guard ``next(iter(...))`` raises StopIteration, which
        # a coroutine re-raises as an opaque RuntimeError.
        if not judger:
            raise ValueError("Judger dict is empty; provide at least one judger or set judger to None.")
        if len(judger) > 1:
            raise NotImplementedError("Multiple judgers require a custom AgentLoop.judge_sample implementation.")
        judger = next(iter(judger.values()))

    if isinstance(judger, Judger):
        rollout_state = await judger.judge(rollout_state)
    elif isinstance(judger, ray.actor.ActorHandle):
        rollout_state = await judger.judge.remote(rollout_state)
    elif callable(judger):
        judger_result = judger(rollout_state)
        # Plain callables may be sync or async; only await when needed.
        if inspect.isawaitable(judger_result):
            rollout_state = await judger_result
        else:
            rollout_state = judger_result
    else:
        raise ValueError(f"Invalid judger type: {type(judger)}")

    if not isinstance(rollout_state, RolloutState):
        raise TypeError(f"Judger must return RolloutState, but got {type(rollout_state)}")
    return rollout_state
191+
192+
class AgentLoopActor:
    """Actor-side wrapper that owns a locally built agent loop and forwards calls to it."""

    def __init__(self, agent_loop_config: AgentLoopConfig, rollout_controller: RolloutController, logger=None):
        # The loop is built inside the actor process, on the placement group
        # this actor is running in. Judger bundles start at index 1.
        # NOTE(review): presumably because bundle 0 is the one the actor
        # itself is scheduled on; confirm against the code that creates it.
        self.agent_loop = agent_loop_config.build_local(
            rollout_controller=rollout_controller,
            logger=logger,
            pg=ray.util.get_current_placement_group(),
            start_bundle_idx=1,
        )

    @ray_method
    async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState:
        """Forward to the wrapped agent loop's ``generate_sample``."""
        sample = await self.agent_loop.generate_sample(rollout_state, **kwargs)
        return sample

    @ray_method
    async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]:
        """Forward to the wrapped agent loop's ``generate_group``."""
        group = await self.agent_loop.generate_group(rollout_state, **kwargs)
        return group
211+
# Remote (Ray actor) class for AgentLoopActor; the cast only informs the
# type checker about the wrapped class.
RayAgentLoop = cast(
    ActorClass[AgentLoopActor],
    ray.remote(AgentLoopActor),
)

# Static type of the handle obtained from ``RayAgentLoop.remote(...)``.
RayAgentLoopProxy: TypeAlias = ActorProxy[AgentLoopActor]
0 commit comments