Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions examples/multi-agent/run_train_chain.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,28 @@
# cd /mnt/shared-storage-user/marti/OpenRLHF

# 基础配置
MODEL_DIR="/your_model_path"
MODEL_DIR="Qwen"
SHORT_NAME=${1:-"Qwen2.5-3B-Instruct"}
PRETRAIN="${MODEL_DIR}/${SHORT_NAME}"
PROMPT_MAX_LEN=4096
GENERATE_MAX_LEN=32768
EVAL_GENERATE_MAX_LEN=32768
OVERLONG_BUFFER_LEN=2048
MAX_LEN=42024 # 30000
MAX_LEN=30000 #42024
ADVANTAGE="group_norm"

ROOT_DIR=""
# Repo root: use ROOT_DIR if set, else derive from this script's path
if [ -z "${ROOT_DIR}" ]; then
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)"
ROOT_DIR="$(cd "${SCRIPT_DIR}/../.." && pwd)"
fi
[ -n "${ROOT_DIR}" ] || ROOT_DIR="$(pwd)"

# Hugging Face cache under repo (avoids PermissionError on ~/.cache/huggingface)
export HF_DATASETS_CACHE="${ROOT_DIR}/.cache/huggingface/datasets"
export HF_HOME="${ROOT_DIR}/.cache/huggingface"
mkdir -p "$HF_DATASETS_CACHE"

TASK="MATH"
PROMPT_DATA="json@${ROOT_DIR}/data/${TASK}"

Expand Down Expand Up @@ -104,12 +115,12 @@ mkdir -p "${ROOT_DIR}/outputs/ckpt/${ADVANTAGE}-${SHORT_NAME}-${TASK}-ch-${EXP}-
# --reward_alloc "${REWARD_ALLOC_ARGS}" \


python3 -m openrlhf.cli.multi_agent_train_ppo_ray \
python3 -m marti.cli.multi_agent_train_ppo_ray \
--default_agent "$DEFAULT_AGENT" \
--agents "$AGENT0" "$AGENT1" "$AGENT2" \
--workflow_args "$WORKFLOW_ARGS" \
--workflow_func_path openrlhf/agent_workflows/chain_workflow.py \
--processor_func_path openrlhf/agent_workflows/chain_processor.py \
--workflow_func_path marti/agent_workflows/chain_workflow.py \
--processor_func_path marti/agent_workflows/chain_processor.py \
--parallel_loading \
--ref_num_nodes 1 \
--ref_num_gpus_per_node 1 \
Expand Down
2 changes: 1 addition & 1 deletion marti/trainer/ray/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def create_vllm_engines(
).remote(
model=pretrain,
enforce_eager=enforce_eager,
worker_extension_cls="openrlhf.trainer.ray.vllm_worker_wrap.WorkerWrap",
worker_extension_cls="marti.trainer.ray.vllm_worker_wrap.WorkerWrap",
tensor_parallel_size=tensor_parallel_size,
seed=seed + i,
distributed_executor_backend=distributed_executor_backend,
Expand Down
8 changes: 7 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,26 @@ einops
flash-attn==2.8.3
grpcio>=1.74.0
isort
json5
jsonlines
latex2sympy2
loralib
optimum
optree>=0.13.0
packaging
peft
pylatexenc
pynvml>=12.0.0
ray[default]==2.48.0
srsly
tensorboard
torch
torch==2.9
torchdata
torchmetrics
tqdm
transformers==4.57.0
transformers_stream_generator
vllm>0.8.5.post1
wandb
wheel
word2number