# rl_dapo_math_async.py
import os
from pathlib import Path
from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
from xtuner.v1.data_proto import SampleParams
from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig
from xtuner.v1.model import get_model_config_from_hf
from xtuner.v1.rl.utils import AcceleratorResourcesConfig
from xtuner.v1.rl.rollout.worker import RolloutConfig
from xtuner.v1.rl.judger import DapoMathJudgerConfig
from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig
from xtuner.v1.rl.trainer import WorkerConfig
from xtuner.v1.rl.agent_loop import ColocatedAgentLoopManagerConfig, TaskSpecConfig, SingleTurnAgentLoopConfig, AsyncProduceStrategyConfig, SamplerConfig
from xtuner.v1.rl.evaluator import EvaluatorConfig
from xtuner.v1.rl.loss import GRPOLossConfig
from xtuner.v1.train.rl_colocate_trainer import RLColocateTrainerConfig
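
# Launch configuration for asynchronous DAPO training on math data.
# Required environment variables: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH.
# Optional: ENABLE_RETURN_ROUTED_EXPERTS, WORLD_SIZE, LOSS_TYPE, LOSS_MODE, SP_SIZE.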
work_dir = os.environ["WORK_DIR"]
model_path = os.environ["MODEL_PATH"]
data_path = os.environ["DATA_PATH"]
eval_data_path = os.environ["EVAL_DATA_PATH"]
enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0")
NNODE = int(os.environ.get("WORLD_SIZE", "1"))
# basic settings
experimental_name = "dapo_math"
total_epochs = 1
global_batch_size = 512
prompt_repeat_k = 16  # responses sampled per prompt (the GRPO group size)
rollout_tp_size = 1  # tensor parallelism of the rollout engine
rollout_ep_size = 1  # expert parallelism of the rollout engine
max_prompt_length = 2048
max_response_length = 8192
pack_max_length = 32768  # token budget of a packed training sequence
train_optimizer_steps = 16  # optimizer steps per rollout batch
hf_interval = 50
enable_initial_evaluate = True
evaluate_step = 5  # evaluate every 5 rollout steps
# 1. resources
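# 8 GPU workers per node, each with 12 CPUs and 16 GB of host memory.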
resources = AcceleratorResourcesConfig(
accelerator="GPU",
num_workers=8 * NNODE,
num_cpus_per_worker=12,
cpu_memory_per_worker=16 * 1024**3, # 16 GB
)
# 2. rollout
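# Inference engine serving the current policy in bfloat16; its context window
# must cover the longest prompt plus the longest response.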
rollout_config = RolloutConfig(
env=experimental_name,
device=resources.accelerator,
model_path=model_path,
dtype="bfloat16",
tensor_parallel_size=rollout_tp_size,
expert_parallel_size=rollout_ep_size,
gpu_memory_utilization=0.8,
context_length=max_response_length + max_prompt_length,
enable_return_routed_experts=(enable_return_routed_experts == "1"),
    rollout_max_batch_size_per_instance=2048,
)
# 3. judger
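# DAPO-style soft overlong punishment: responses whose length falls inside the
# last overlong_buffer_len tokens before max_response_len receive a linearly
# growing penalty, scaled by overlong_penalty_factor.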
from xtuner.v1.rl.utils import get_eos_token
from transformers import AutoTokenizer
eos_token_id = get_eos_token(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id)
judger_config = DapoMathJudgerConfig(
judger_name="dapo_math",
eos_token=eos_token_str,
    enable_overlong_buffer=True,
    max_response_len=max_response_length,
    overlong_buffer_len=4096,
    overlong_penalty_factor=1.0,
    tokenizer=tokenizer,
)
# 4. train worker
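# FSDP-sharded training worker: AdamW at a constant 1e-6 LR with no warmup.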
lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6)
fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1)
model_cfg = get_model_config_from_hf(Path(model_path))
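# For MoE checkpoints, drop the auxiliary balancing/z losses during RL.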
if hasattr(model_cfg, "balancing_loss_cfg"):
model_cfg.balancing_loss_cfg = None
if hasattr(model_cfg, "z_loss_cfg"):
model_cfg.z_loss_cfg = None
optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1)
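# DAPO clip-higher: a larger upper clip range (0.28) than lower (0.2) leaves
# more room to raise the probability of low-probability tokens. The KL term
# is disabled, as in DAPO.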
loss_cfg = GRPOLossConfig(
policy_loss_cfg=dict(
cliprange_high=0.28,
cliprange_low=0.2,
loss_type=os.environ.get("LOSS_TYPE", "vanilla"),
clip_ratio_c=10.0,
log_prob_diff_min=-20.0,
log_prob_diff_max=20.0,
),
ignore_idx=-100,
use_kl_loss=False,
kl_loss_coef=0.0,
kl_loss_type="low_var_kl",
mode=os.environ.get("LOSS_MODE", "chunk"),
chunk_size=512,
)
train_worker_cfg = WorkerConfig(
model_cfg=model_cfg,
load_from=model_path,
optim_cfg=optim_cfg,
loss_cfg=loss_cfg,
lr_cfg=lr_cfg,
fsdp_cfg=fsdp_cfg,
sp_size=int(os.environ.get("SP_SIZE", "1")),
optimizer_steps=train_optimizer_steps,
pack_max_length=pack_max_length,
)
# 5. train agent loop manager
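# Prompts are tokenized up to max_prompt_length, and the sampler repeats each
# prompt prompt_repeat_k times to form a GRPO group.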
train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path)
tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length)
train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}]
dataloader_cfg = DataloaderConfig(
dataset_config_list=train_dataset_cfg,
pack_max_length=pack_max_length,
collator="fake_collator",
pack_level="none",
)
sampler_config = SamplerConfig(
dataloader_cfg=dataloader_cfg,
prompt_repeat_k=prompt_repeat_k,
)
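# Training-time sampling: temperature 1.0 with no top-k/top-p truncation, so
# rollouts are drawn from the unmodified policy distribution.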
training_sample_params = SampleParams(
max_tokens=max_response_length,
top_k=0,
top_p=1.0,
temperature=1.0,
min_tokens=0,
)
agent_loop_config = SingleTurnAgentLoopConfig(
hf_checkpoint=model_path,
sample_params=training_sample_params,
)
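# Async production strategy: the over-sampling and partial-rollout settings
# below are presumably there to keep a batch from stalling on long-tail
# generations; unfinished rollouts can be resumed in a later step.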
produce_strategy_config = AsyncProduceStrategyConfig(
produce_batch_over_sample_threshold=0.2,
produce_batch_enable_partial_rollout=True,
tail_batch_stale_threshold=1,
tail_batch_trigger_size=256,
)
agent_loop_manager_cfg = ColocatedAgentLoopManagerConfig(
tasks=TaskSpecConfig(
task_name="train_task",
agent_loop_config=agent_loop_config,
produce_strategy_config=produce_strategy_config,
sampler_config=sampler_config,
),
)
# 6. eval agent loop manager
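# Evaluation decodes greedily (temperature 0, top_k 1; top_p is irrelevant at
# temperature 0) with a single response per prompt.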
eval_dataset = DatasetConfig(
name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0
)
eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}]
eval_dataloader_cfg = DataloaderConfig(
dataset_config_list=eval_dataset_cfg,
pack_max_length=pack_max_length,
collator="fake_collator",
pack_level="none",
)
eval_sampler_config = SamplerConfig(
dataloader_cfg=eval_dataloader_cfg,
prompt_repeat_k=1,
)
evaluation_sample_params = SampleParams(
max_tokens=max_response_length,
top_k=1,
top_p=0.7,
temperature=0.0,
min_tokens=0,
)
eval_agent_loop_config = SingleTurnAgentLoopConfig(
hf_checkpoint=model_path,
sample_params=evaluation_sample_params,
)
eval_agent_loop_manager_cfg = ColocatedAgentLoopManagerConfig(
tasks=TaskSpecConfig(
task_name="eval_task",
agent_loop_config=eval_agent_loop_config,
sampler_config=eval_sampler_config,
),
)
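# Eval metric: fraction of samples whose judger-assigned "acc" reward is positive.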
def dapo_compute_metric(samples):
return {"accuracy": sum(s.reward["acc"] > 0 for s in samples) / len(samples)}
evaluator_config = EvaluatorConfig(compute_metric_func=dapo_compute_metric)
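# 7. trainer
# Colocated setup: rollout and training share the same GPU workers (hence
# "colocate") rather than occupying separate resource pools.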
trainer = RLColocateTrainerConfig(
resources=resources,
    train_worker_cfg=train_worker_cfg,  # TODO: unify the cfg vs. config suffix naming
rollout_config=rollout_config,
judger_config=judger_config,
tokenizer_path=model_path,
replay_buffer_config=AsyncReplayBufferConfig(),
agent_loop_manager_cfg=agent_loop_manager_cfg,
eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg,
evaluator_config=evaluator_config,
load_from=model_path,
global_batch_size=global_batch_size,
enable_evaluate=True,
    enable_initial_evaluate=enable_initial_evaluate,
rollout_steps=500,
evaluate_step=evaluate_step,
work_dir=work_dir,
seed=123,
debug_rollout=False,
)