Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions torchtitan/experiments/rl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,19 @@
Unified approach for running TorchTitan models with vLLM inference.

To register TorchTitan models with vLLM:
from torchtitan.experiments.rl.models.vllm_registry import register_model_to_vllm_model_registry
register_model_to_vllm_model_registry(model_spec)
from torchtitan.experiments.rl.models.vllm_registry import registry_to_vllm
registry_to_vllm(
model_spec,
parallelism=parallelism_config,
compile_config=compile_config,
)
"""

from torchtitan.experiments.rl.models.vllm_registry import (
register_model_to_vllm_model_registry,
)
from torchtitan.experiments.rl.models.vllm_registry import registry_to_vllm
from torchtitan.experiments.rl.models.vllm_wrapper import TorchTitanVLLMModelWrapper


__all__ = [
"TorchTitanVLLMModelWrapper",
"register_model_to_vllm_model_registry", # Export register function for manual use
"registry_to_vllm", # Export register function for manual use
]
52 changes: 20 additions & 32 deletions torchtitan/experiments/rl/actors/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
import torch
import torchstore as ts
from monarch.actor import Actor, endpoint
from torchtitan.config import Configurable
from torchtitan.config.configs import CompileConfig, DebugConfig, ParallelismConfig
from torchtitan.config import (
CompileConfig,
Configurable,
DebugConfig,
ParallelismConfig,
)
from torchtitan.distributed.utils import set_batch_invariance
from torchtitan.experiments.rl.models.vllm_registry import (
register_model_to_vllm_model_registry,
VLLM_MODEL_NAME,
registry_to_vllm,
TORCHTITAN_CONFIG_FORMAT,
)
from torchtitan.experiments.rl.types import Completion
from torchtitan.protocols.model_spec import ModelSpec
Expand Down Expand Up @@ -136,31 +140,6 @@ class Config(Configurable.Config):
debug: DebugConfig = field(default_factory=DebugConfig)
"""Debug and determinism settings."""

def __post_init__(self):
# Generator only supports TP. vLLM handles its own parallelism
# and we only apply TP via the core parallelize function.
p = self.parallelism
if p.data_parallel_replicate_degree != 1:
raise ValueError(
f"Generator does not support data parallel replication, "
f"got dp_replicate={p.data_parallel_replicate_degree}"
)
if p.pipeline_parallel_degree > 1:
raise ValueError(
f"Generator does not support pipeline parallelism, "
f"got pp={p.pipeline_parallel_degree}"
)
if p.context_parallel_degree > 1:
raise ValueError(
f"Generator does not support context parallelism, "
f"got cp={p.context_parallel_degree}"
)
if p.expert_parallel_degree > 1:
raise ValueError(
f"Generator does not support expert parallelism, "
f"got ep={p.expert_parallel_degree}"
)

def __init__(
self,
config: Config,
Expand All @@ -180,9 +159,10 @@ def __init__(
# (RLTrainer) as num_prompts_per_step * sampling.n.
self._max_num_seqs = max_num_seqs

# Register TorchTitan model with vLLM before any engine creation
register_model_to_vllm_model_registry(
# Register TorchTitan model + parser with vLLM
registry_to_vllm(
model_spec,
parallelism=config.parallelism,
compile_config=compile_config,
)

Expand All @@ -197,16 +177,24 @@ def __init__(

# Build vLLM engine
engine_kwargs = dict(
# ``model`` is the path to the HF checkpoint directory. The
# config is sourced from torchtitan's ModelSpec via
# ``config_format=TORCHTITAN_CONFIG_FORMAT`` (no config.json
# read), but vLLM still uses this path to locate the
# tokenizer assets and the safetensors weight shards.
model=model_path,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this path for?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now it serves 2 purpose:

  1. Loading tokenizer. This can be removed by passing tokenizer=tokenizer_path to EngineArgs.
    2.Initial_weight_loading: Will sort out the weight loading part for both trainer and generator in next PR.

After lifting both, we can pass some fake path, say "torchtitan", to vllm

trust_remote_code=True,
# Use the torchtitan custom config parser (registered by
# registry_to_vllm above). It builds PretrainedConfig from
# ModelSpec instead of reading config.json from disk.
config_format=TORCHTITAN_CONFIG_FORMAT,
dtype=config.model_dtype,
tensor_parallel_size=config.parallelism.tensor_parallel_degree,
# Monarch already spawned TP workers via proc mesh. "external_launcher"
# tells vLLM to run one worker per process (no subprocess spawning)
distributed_executor_backend="external_launcher",
gpu_memory_utilization=config.gpu_memory_limit,
enforce_eager=not config.cudagraph.enable,
hf_overrides={"architectures": [VLLM_MODEL_NAME]},
attention_config=AttentionConfig(
backend=AttentionBackendEnum.CUSTOM,
),
Expand Down
6 changes: 4 additions & 2 deletions torchtitan/experiments/rl/actors/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
)
from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.config import CommConfig, Configurable, TORCH_DTYPE_MAP
from torchtitan.config.configs import (
from torchtitan.config import (
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is just consolidate the import path

ActivationCheckpointConfig,
CommConfig,
CompileConfig,
Configurable,
DebugConfig,
ParallelismConfig,
TORCH_DTYPE_MAP,
TrainingConfig,
)
from torchtitan.distributed import ParallelDims, utils as dist_utils
Expand Down
12 changes: 11 additions & 1 deletion torchtitan/experiments/rl/config_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.config.configs import (
from torchtitan.config import (
CompileConfig,
DebugConfig,
ParallelismConfig,
Expand Down Expand Up @@ -59,6 +59,8 @@ def rl_grpo_qwen3_0_6b() -> RLTrainer.Config:
parallelism=ParallelismConfig(
tensor_parallel_degree=4,
data_parallel_replicate_degree=1,
enable_sequence_parallel=False,
disable_loss_parallel=True,
),
sampling=SamplingConfig(
n=group_size,
Expand Down Expand Up @@ -104,6 +106,8 @@ def rl_grpo_qwen3_1_7b() -> RLTrainer.Config:
data_parallel_shard_degree=1,
tensor_parallel_degree=4,
data_parallel_replicate_degree=1,
enable_sequence_parallel=False,
disable_loss_parallel=True,
),
sampling=SamplingConfig(
n=group_size,
Expand Down Expand Up @@ -148,6 +152,8 @@ def rl_grpo_qwen3_14b() -> RLTrainer.Config:
parallelism=ParallelismConfig(
tensor_parallel_degree=8,
data_parallel_replicate_degree=1,
enable_sequence_parallel=False,
disable_loss_parallel=True,
),
sampling=SamplingConfig(
n=group_size,
Expand Down Expand Up @@ -190,6 +196,8 @@ def rl_grpo_qwen3_debug() -> RLTrainer.Config:
parallelism=ParallelismConfig(
tensor_parallel_degree=1,
data_parallel_replicate_degree=1,
enable_sequence_parallel=False,
disable_loss_parallel=True,
),
sampling=SamplingConfig(
n=group_size,
Expand Down Expand Up @@ -242,6 +250,8 @@ def rl_grpo_qwen3_0_6b_batch_invariant() -> RLTrainer.Config:
parallelism=ParallelismConfig(
tensor_parallel_degree=2,
data_parallel_replicate_degree=1,
enable_sequence_parallel=False,
disable_loss_parallel=True,
),
sampling=SamplingConfig(
n=group_size,
Expand Down
5 changes: 3 additions & 2 deletions torchtitan/experiments/rl/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@ def generate():

# Register TorchTitan model with vLLM before engine creation
from torchtitan.experiments.rl.models.vllm_registry import (
register_model_to_vllm_model_registry,
registry_to_vllm,
VLLM_MODEL_NAME,
)

register_model_to_vllm_model_registry(
registry_to_vllm(
config.model_spec,
parallelism=gen_config.parallelism,
compile_config=config.compile,
)
logger.info("Registered TorchTitan model with vLLM")
Expand Down
44 changes: 41 additions & 3 deletions torchtitan/experiments/rl/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,12 @@
from monarch.actor import this_host
from monarch.spmd import setup_torch_elastic_env_async

from torchtitan.config import Configurable, ParallelismConfig
from torchtitan.config.configs import CompileConfig
from torchtitan.config.manager import ConfigManager
from torchtitan.config import (
CompileConfig,
ConfigManager,
Configurable,
ParallelismConfig,
)
from torchtitan.experiments.rl.actors.generator import SamplingConfig, VLLMGenerator
from torchtitan.experiments.rl.actors.trainer import PolicyTrainer
from torchtitan.experiments.rl.types import (
Expand Down Expand Up @@ -234,6 +237,41 @@ def __post_init__(self):
"and has not been validated for determinism."
)

# VLLMGenerator only supports TP. vLLM handles its own parallelism;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry my earlier comment introduced some noise
#3242 (comment)

I meant to say that we should put some check in post_init, but not ParallelismConfig's post_init because it'll be shared by core trainer, policy trainer, and generator. I think the proper place to put these checks is generator config's post_init.

# we only apply TP via the core parallelize function.
if self.generator.parallelism.data_parallel_replicate_degree != 1:
raise ValueError(
f"Generator does not support data parallel replication, "
f"got dp_replicate={self.generator.parallelism.data_parallel_replicate_degree}"
)
if self.generator.parallelism.pipeline_parallel_degree > 1:
raise ValueError(
f"Generator does not support pipeline parallelism, "
f"got pp={self.generator.parallelism.pipeline_parallel_degree}"
)
if self.generator.parallelism.context_parallel_degree > 1:
raise ValueError(
f"Generator does not support context parallelism, "
f"got cp={self.generator.parallelism.context_parallel_degree}"
)
if self.generator.parallelism.expert_parallel_degree > 1:
raise ValueError(
f"Generator does not support expert parallelism, "
f"got ep={self.generator.parallelism.expert_parallel_degree}"
)
if self.generator.parallelism.enable_sequence_parallel:
raise ValueError(
"Generator does not support sequence parallelism: "
"spmd_types erasure mode requires sequence length to be "
"evenly divisible by TP, which doesn't hold for inference "
"(uneven batches). Set enable_sequence_parallel=False."
)
if not self.generator.parallelism.disable_loss_parallel:
raise ValueError(
"Generator requires disable_loss_parallel=True, "
f"got disable_loss_parallel={self.generator.parallelism.disable_loss_parallel}"
)

def __init__(self, config: Config):
self.config = config
self._proc_meshes = []
Expand Down
3 changes: 1 addition & 2 deletions torchtitan/experiments/rl/models/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
SequenceParallel,
)

from torchtitan.config import ParallelismConfig
from torchtitan.config.configs import CompileConfig
from torchtitan.config import CompileConfig, ParallelismConfig
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.compile import apply_compile
from torchtitan.distributed.tensor_parallel import NoParallel
Expand Down
Loading
Loading