diff --git a/examples/multi-agent/run_train_chain.sh b/examples/multi-agent/run_train_chain.sh old mode 100644 new mode 100755 index 1e59da3..6c698e0 --- a/examples/multi-agent/run_train_chain.sh +++ b/examples/multi-agent/run_train_chain.sh @@ -6,17 +6,28 @@ # cd /mnt/shared-storage-user/marti/OpenRLHF # 基础配置 -MODEL_DIR="/your_model_path" +MODEL_DIR="Qwen" SHORT_NAME=${1:-"Qwen2.5-3B-Instruct"} PRETRAIN="${MODEL_DIR}/${SHORT_NAME}" PROMPT_MAX_LEN=4096 GENERATE_MAX_LEN=32768 EVAL_GENERATE_MAX_LEN=32768 OVERLONG_BUFFER_LEN=2048 -MAX_LEN=42024 # 30000 +MAX_LEN=30000 #42024 ADVANTAGE="group_norm" -ROOT_DIR="" +# Repo root: use ROOT_DIR if set, else derive from this script's path +if [ -z "${ROOT_DIR}" ]; then + SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)" + ROOT_DIR="$(cd "${SCRIPT_DIR}/../.." && pwd)" +fi +[ -n "${ROOT_DIR}" ] || ROOT_DIR="$(pwd)" + +# Hugging Face cache under repo (avoids PermissionError on ~/.cache/huggingface) +export HF_DATASETS_CACHE="${ROOT_DIR}/.cache/huggingface/datasets" +export HF_HOME="${ROOT_DIR}/.cache/huggingface" +mkdir -p "$HF_DATASETS_CACHE" + TASK="MATH" PROMPT_DATA="json@${ROOT_DIR}/data/${TASK}" @@ -104,12 +115,12 @@ mkdir -p "${ROOT_DIR}/outputs/ckpt/${ADVANTAGE}-${SHORT_NAME}-${TASK}-ch-${EXP}- # --reward_alloc "${REWARD_ALLOC_ARGS}" \ -python3 -m openrlhf.cli.multi_agent_train_ppo_ray \ +python3 -m marti.cli.multi_agent_train_ppo_ray \ --default_agent "$DEFAULT_AGENT" \ --agents "$AGENT0" "$AGENT1" "$AGENT2" \ --workflow_args "$WORKFLOW_ARGS" \ - --workflow_func_path openrlhf/agent_workflows/chain_workflow.py \ - --processor_func_path openrlhf/agent_workflows/chain_processor.py \ + --workflow_func_path marti/agent_workflows/chain_workflow.py \ + --processor_func_path marti/agent_workflows/chain_processor.py \ --parallel_loading \ --ref_num_nodes 1 \ --ref_num_gpus_per_node 1 \ diff --git a/marti/trainer/ray/vllm_engine.py b/marti/trainer/ray/vllm_engine.py index e4cdbb5..0362873 100755 --- a/marti/trainer/ray/vllm_engine.py +++ b/marti/trainer/ray/vllm_engine.py @@ -172,7 +172,7 @@ def create_vllm_engines( ).remote( model=pretrain, enforce_eager=enforce_eager, - worker_extension_cls="openrlhf.trainer.ray.vllm_worker_wrap.WorkerWrap", + worker_extension_cls="marti.trainer.ray.vllm_worker_wrap.WorkerWrap", tensor_parallel_size=tensor_parallel_size, seed=seed + i, distributed_executor_backend=distributed_executor_backend, diff --git a/requirements.txt b/requirements.txt index b41e72b..fb6acb5 100755 --- a/requirements.txt +++ b/requirements.txt @@ -6,20 +6,26 @@ einops flash-attn==2.8.3 grpcio>=1.74.0 isort +json5 jsonlines +latex2sympy2 loralib optimum optree>=0.13.0 packaging peft +pylatexenc pynvml>=12.0.0 ray[default]==2.48.0 +srsly tensorboard -torch +torch==2.9 torchdata torchmetrics tqdm transformers==4.57.0 transformers_stream_generator +vllm>0.8.5.post1 wandb wheel +word2number