Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/v1/config/rl_dapo_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@
tasks=TaskSpecConfig(
task_name="train_task",
agent_loop_config=agent_loop_config,
judger_config=judger_config,
produce_strategy_config=produce_strategy_config,
sampler_config=sampler_config,
),
Expand Down Expand Up @@ -178,6 +179,7 @@
tasks=TaskSpecConfig(
task_name="eval_task",
agent_loop_config=eval_agent_loop_config,
judger_config=judger_config,
sampler_config=eval_sampler_config,
),
)
Expand All @@ -191,7 +193,6 @@ def dapo_compute_metric(samples):
resources=resources,
train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config
rollout_config=rollout_config,
judger_config=judger_config,
tokenizer_path=model_path,
replay_buffer_config=SyncReplayBufferConfig(),
agent_loop_manager_cfg=agent_loop_manager_cfg,
Expand Down
3 changes: 2 additions & 1 deletion examples/v1/config/rl_dapo_math_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@
tasks=TaskSpecConfig(
task_name="train_task",
agent_loop_config=agent_loop_config,
judger_config=judger_config,
produce_strategy_config=produce_strategy_config,
sampler_config=sampler_config,
),
Expand Down Expand Up @@ -181,6 +182,7 @@
tasks=TaskSpecConfig(
task_name="eval_task",
agent_loop_config=eval_agent_loop_config,
judger_config=judger_config,
sampler_config=eval_sampler_config,
),
)
Expand All @@ -194,7 +196,6 @@ def dapo_compute_metric(samples):
resources=resources,
train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config
rollout_config=rollout_config,
judger_config=judger_config,
tokenizer_path=model_path,
replay_buffer_config=AsyncReplayBufferConfig(),
agent_loop_manager_cfg=agent_loop_manager_cfg,
Expand Down
3 changes: 2 additions & 1 deletion examples/v1/config/rl_dapo_math_async_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def group_samples_filter_func(rollout_states):
tasks=TaskSpecConfig(
task_name="train_task",
agent_loop_config=agent_loop_config,
judger_config=judger_config,
produce_strategy_config=produce_strategy_config,
sampler_config=sampler_config,
),
Expand Down Expand Up @@ -196,6 +197,7 @@ def group_samples_filter_func(rollout_states):
tasks=TaskSpecConfig(
task_name="eval_task",
agent_loop_config=eval_agent_loop_config,
judger_config=judger_config,
sampler_config=eval_sampler_config,
),
)
Expand All @@ -209,7 +211,6 @@ def dapo_compute_metric(samples):
resources=resources,
train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config
rollout_config=rollout_config,
judger_config=judger_config,
tokenizer_path=model_path,
replay_buffer_config=AsyncReplayBufferConfig(),
agent_loop_manager_cfg=agent_loop_manager_cfg,
Expand Down
3 changes: 2 additions & 1 deletion examples/v1/config/rl_grpo_geo3k_judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@
tasks=TaskSpecConfig(
task_name="train_task",
agent_loop_config=agent_loop_config,
judger_config=judger_config,
produce_strategy_config=produce_strategy_config,
sampler_config=sampler_config,
),
Expand Down Expand Up @@ -192,6 +193,7 @@
tasks=TaskSpecConfig(
task_name="eval_task",
agent_loop_config=eval_agent_loop_config,
judger_config=judger_config,
sampler_config=eval_sampler_config,
),
)
Expand All @@ -204,7 +206,6 @@
resources=resources,
train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config
rollout_config=rollout_config,
judger_config=judger_config,
tokenizer_path=model_path,
replay_buffer_config=SyncReplayBufferConfig(),
agent_loop_manager_cfg=agent_loop_manager_cfg,
Expand Down
3 changes: 2 additions & 1 deletion examples/v1/config/rl_grpo_gsm8k_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
tasks=TaskSpecConfig(
task_name="train_task",
agent_loop_config=agent_loop_config,
judger_config=judger_config,
produce_strategy_config=produce_strategy_config,
sampler_config=sampler_config,
),
Expand Down Expand Up @@ -174,6 +175,7 @@
tasks=TaskSpecConfig(
task_name="eval_task",
agent_loop_config=eval_agent_loop_config,
judger_config=judger_config,
sampler_config=eval_sampler_config,
),
)
Expand All @@ -186,7 +188,6 @@
resources=resources,
train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config
rollout_config=rollout_config,
judger_config=judger_config,
tokenizer_path=model_path,
replay_buffer_config=AsyncReplayBufferConfig(),
agent_loop_manager_cfg=agent_loop_manager_cfg,
Expand Down
3 changes: 2 additions & 1 deletion examples/v1/config/rl_grpo_gsm8k_judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
tasks=TaskSpecConfig(
task_name="train_task",
agent_loop_config=agent_loop_config,
judger_config=judger_config,
produce_strategy_config=produce_strategy_config,
sampler_config=sampler_config,
),
Expand Down Expand Up @@ -169,6 +170,7 @@
tasks=TaskSpecConfig(
task_name="eval_task",
agent_loop_config=eval_agent_loop_config,
judger_config=judger_config,
sampler_config=eval_sampler_config,
),
)
Expand All @@ -181,7 +183,6 @@
resources=resources,
train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config
rollout_config=rollout_config,
judger_config=judger_config,
tokenizer_path=model_path,
replay_buffer_config=SyncReplayBufferConfig(),
agent_loop_manager_cfg=agent_loop_manager_cfg,
Expand Down
3 changes: 2 additions & 1 deletion examples/v1/config/rl_grpo_gsm8k_with_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
tasks=TaskSpecConfig(
task_name="train_task",
agent_loop_config=agent_loop_config,
judger_config=judger_config,
produce_strategy_config=produce_strategy_config,
sampler_config=sampler_config,
),
Expand Down Expand Up @@ -191,6 +192,7 @@
tasks=TaskSpecConfig(
task_name="eval_task",
agent_loop_config=eval_agent_loop_config,
judger_config=judger_config,
sampler_config=eval_sampler_config,
),
)
Expand All @@ -203,7 +205,6 @@
resources=resources,
train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config
rollout_config=rollout_config,
judger_config=judger_config,
tokenizer_path=model_path,
replay_buffer_config=SyncReplayBufferConfig(),
agent_loop_manager_cfg=agent_loop_manager_cfg,
Expand Down
3 changes: 2 additions & 1 deletion examples/v1/config/rl_multi_task_gsm8k_dapo_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@
task_name="train_task:dapo_math",
weight=dapo_task_weight,
agent_loop_config=dapo_train_agent_loop_config,
judger_config=judger_config,
produce_strategy_config=SyncProduceStrategyConfig(),
sampler_config=dapo_train_sampler_config,
),
Expand Down Expand Up @@ -269,6 +270,7 @@
task_name="eval_task:dapo_math",
weight=dapo_task_weight,
agent_loop_config=dapo_eval_agent_loop_config,
judger_config=judger_config,
sampler_config=dapo_eval_sampler_config,
),
TaskSpecConfig(
Expand All @@ -291,7 +293,6 @@ def compute_metric(samples):
resources=resources,
train_worker_cfg=train_worker_cfg,
rollout_config=rollout_config,
judger_config=judger_config,
tokenizer_path=model_path,
replay_buffer_config=SyncReplayBufferConfig(),
agent_loop_manager_cfg=agent_loop_manager_cfg,
Expand Down
74 changes: 63 additions & 11 deletions tests/rl/test_agent_loop.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import os
import unittest
import asyncio
import copy
import ray
import tempfile
import torch
from transformers import AutoTokenizer
from xtuner.v1.rl.rollout.worker import RolloutConfig
from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers
from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig, AgentLoopManagerConfig, TaskSpecConfig, SyncProduceStrategyConfig, SamplerConfig
from xtuner.v1.data_proto import RolloutState, Status, SampleParams
from xtuner.v1.rl.agent_loop import (
SingleTurnAgentLoopConfig,
AgentLoopManagerConfig,
TaskSpecConfig,
SyncProduceStrategyConfig,
SamplerConfig,
)
from xtuner.v1.data_proto import RolloutState, Status, SampleParams
from xtuner.v1.rl.rollout import RolloutController
from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig
from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig
Expand Down Expand Up @@ -79,14 +85,16 @@ async def test_gsm8k_agent_loop(self):
judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", judger_type="router")
agent_loop_cfg = SingleTurnAgentLoopConfig(
hf_checkpoint=self.model_path,
sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0)
sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0),
)
# 2. 创建 rollout_controller, judger
# 2. 创建 rollout_controller
pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg)
rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg)
gsm8k_judger = judger_config.build()
# 3. 创建 AgentLoop
agent_loop = agent_loop_cfg.build(rollout_controller=rollout_controller, judger=gsm8k_judger)
agent_loop = agent_loop_cfg.build(
rollout_controller=rollout_controller,
judger=judger_config.build(),
)
# 4. 构造输入数据
prompt_repeat_k = 4
rollout_state = FAKE_INPUT_ITEM
Expand All @@ -104,6 +112,51 @@ async def test_gsm8k_agent_loop(self):
self.assertGreater(len(single_rollout_state.response_ids), 0)
self.assertEqual(single_rollout_state.reward["score"], 1)

async def test_gsm8k_agent_loop_with_ray_actor_judger(self):
    """Exercise the single-turn GSM8K agent loop with a ``ray.actor`` judger.

    The judger config uses ``judger_type="ray.actor"`` and the agent loop
    config requests ``num_ray_actors=1``; the built agent loop is then driven
    through ``generate_group.remote`` and ``generate_sample.remote``.  Every
    returned rollout state must be COMPLETED, carry a non-empty response and
    have a reward score of 1.
    """
    self.init_config()

    def assert_solved(state):
        # Shared expectations for group members and the single sample.
        self.assertEqual(state.status, Status.COMPLETED)
        self.assertGreater(len(state.response_ids), 0)
        self.assertEqual(state.reward["score"], 1)

    # Configs: rollout backend, Ray-actor judger, and the agent loop itself.
    rollout_config = RolloutConfig(
        env="test_agent_loop_ray_actor",
        model_path=self.model_path,
        model_name=os.path.basename(self.model_path).lower(),
        tokenizer_path=self.model_path,
        context_length=self.context_length,
        worker_log_dir=self.worker_log_dir,
    )
    judger_config = GSM8KJudgerConfig(
        judger_name="openai/gsm8k",
        judger_type="ray.actor",
        num_cpus_per_actor=1,
    )
    greedy_params = SampleParams(max_tokens=self.max_response_length, temperature=0.0)
    agent_loop_cfg = SingleTurnAgentLoopConfig(
        hf_checkpoint=self.model_path,
        sample_params=greedy_params,
        num_ray_actors=1,
        num_cpus=1,
    )

    # Rollout controller placed on the test's accelerator resources.
    placement_group = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg)
    rollout_controller = ray.remote(RolloutController).remote(rollout_config, placement_group)

    # The judger is built from its config and handed to the agent loop.
    judger = judger_config.build()
    agent_loop = agent_loop_cfg.build(rollout_controller=rollout_controller, judger=judger)

    # Independent copies of the fake input so no state is shared between calls.
    group_size = 2
    single_input = copy.deepcopy(FAKE_INPUT_ITEM)
    group_inputs = []
    for _ in range(group_size):
        group_inputs.append(copy.deepcopy(FAKE_INPUT_ITEM))

    group_outputs = await agent_loop.generate_group.remote(group_inputs)
    single_output = await agent_loop.generate_sample.remote(single_input)

    self.assertEqual(len(group_outputs), group_size)
    for member in group_outputs:
        assert_solved(member)
    assert_solved(single_output)

async def test_gsm8k_agent_loop_manager(self):
# 1. 初始化 config
self.init_config()
Expand All @@ -118,7 +171,7 @@ async def test_gsm8k_agent_loop_manager(self):
judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", judger_type="router")
agent_loop_cfg = SingleTurnAgentLoopConfig(
hf_checkpoint=self.model_path,
sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0)
sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0),
)
sampler_config = SamplerConfig(
dataloader_cfg=DataloaderConfig(
Expand All @@ -141,21 +194,20 @@ async def test_gsm8k_agent_loop_manager(self):
TaskSpecConfig(
task_name="test_gsm8k",
agent_loop_config=agent_loop_cfg,
judger_config=judger_config,
produce_strategy_config=SyncProduceStrategyConfig(),
sampler_config=sampler_config,
)
],
)
# 2. 创建 rollout_controller, judger
# 2. 创建 rollout_controller
pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg)
rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg)
gsm8k_judger = judger_config.build()
# 3. 创建 AgentLoopManager
replay_buffer_cfg = SyncReplayBufferConfig()
replay_buffer = replay_buffer_cfg.build()
agent_loop_manager = agent_loop_manager_cfg.build(
rollout_controller=rollout_controller,
judger=gsm8k_judger,
tokenizer=self.tokenizer,
replay_buffer=replay_buffer,
)
Expand Down
1 change: 0 additions & 1 deletion tests/rl/test_async_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def _build_agent_loop_manager(
)
manager = manager_cfg.build(
rollout_controller=rollout_ctl,
judger=None,
tokenizer=tokenizer,
replay_buffer=replay_buffer,
logger=None,
Expand Down
3 changes: 2 additions & 1 deletion tests/rl/test_rl_colocate_trainer_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def build_trainer_config(self, work_dir, checkpoint_interval=1, checkpoint_maxke
TaskSpecConfig(
task_name="train_task",
agent_loop_config=agent_loop_config,
judger_config=judger_config,
produce_strategy_config=produce_strategy_config,
sampler_config=sampler_config,
)
Expand All @@ -186,6 +187,7 @@ def build_trainer_config(self, work_dir, checkpoint_interval=1, checkpoint_maxke
TaskSpecConfig(
task_name="eval_task",
agent_loop_config=eval_agent_loop_config,
judger_config=judger_config,
sampler_config=eval_sampler_config,
)
],
Expand All @@ -198,7 +200,6 @@ def build_trainer_config(self, work_dir, checkpoint_interval=1, checkpoint_maxke
resources=resources,
train_worker_cfg=train_worker_cfg,
rollout_config=rollout_config,
judger_config=judger_config,
tokenizer_path=model_path,
replay_buffer_config=SyncReplayBufferConfig(),
agent_loop_manager_cfg=agent_loop_manager_cfg,
Expand Down
21 changes: 20 additions & 1 deletion xtuner/v1/rl/agent_loop/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from .agent_loop import AgentLoop, AgentLoopConfig
from xtuner.v1.rl.judger import JudgerConfigSpec, JudgerLike, JudgerSpec, JudgerSpecConfig

from .agent_loop import (
AgentLoop,
AgentLoopActor,
AgentLoopConfig,
AgentLoopSpec,
RayAgentLoop,
RayAgentLoopProxy,
RouterAgentLoop,
)
from .agent_loop_manager import (
AgentLoopManager,
AgentLoopManagerConfig,
Expand All @@ -21,7 +31,16 @@
"AgentLoopConfig",
"SingleTurnAgentLoopConfig",
"AgentLoop",
"AgentLoopSpec",
"AgentLoopActor",
"RouterAgentLoop",
"RayAgentLoop",
"RayAgentLoopProxy",
"SingleTurnAgentLoop",
"JudgerLike",
"JudgerSpec",
"JudgerConfigSpec",
"JudgerSpecConfig",
"AgentLoopManagerConfig",
"AgentLoopManager",
"TaskSpecConfig",
Expand Down
Loading
Loading