Skip to content

Refactor step inputs#4504

Open
grimoire wants to merge 3 commits intoInternLM:mainfrom
grimoire:refactor-step-inputs
Open

Refactor step inputs#4504
grimoire wants to merge 3 commits intoInternLM:mainfrom
grimoire:refactor-step-inputs

Conversation

@grimoire
Copy link
Copy Markdown
Collaborator

@grimoire grimoire commented Apr 7, 2026

  • Consolidate all decoding-loop state update logic into per-paradigm StepInputs subclasses, replacing the old StepInputs in agent.py that delegated to 8 scattered methods across 3 strategy classes
  • Remove update_prefill_for_next_step, update_decoding_for_next_step, update_extra_inputs from ModelAgentStrategy; update_inputs, merge from ModelInputsStrategy; merge_sampling_delta, step_sampling_delta, update_sampling_delta from SamplingStrategy — all were exclusively called from StepInputs
  • Rename methods for clarity: mergemerge_prefill, update_deltareindex, stepstep_decode

Motivation

Before this change, understanding the update lifecycle for a single paradigm (e.g., AR) required reading 3 files (ar/model_agent.py, ar/model_inputs.py, ar/sampling.py). Adding a new field to ModelInputs meant updating methods in all 3 files. Now each paradigm's update logic lives in a single step_inputs.py file.

What changed

New files (4):

  • strategies/base/step_inputs.py — abstract StepInputs base class with lifecycle docs
  • strategies/ar/step_inputs.pyARStepInputs with all AR update logic
  • strategies/dllm/step_inputs.pyDLLMStepInputs with all DLLM update logic
  • strategies/ar_spec/step_inputs.pyARSpecStepInputs with all AR Spec update logic

Key modifications:

  • engine/model_agent/agent.py — removed old StepInputs class (~85 lines), construction now uses strategy_factory.build_step_inputs()
  • strategies/base/ — removed 8 abstract methods from the 3 strategy ABCs
  • strategies/{ar,dllm,ar_spec}/ — removed inlined methods from strategy implementations, added build_step_inputs() factory methods
  • ar/model_inputs.pyindex_select promoted to standalone index_select_model_inputs(); get_model_inputs_next_decoding moved here from ar/model_agent.py

Test

qwen3.5-35b-a3b gpqa

dataset version metric mode stepinputs-gpqa
GPQA_diamond_repeat_4 772ea0 accuracy (4 runs average) gen 83.84

qwen3.5-35b-a3b MTP aime2025

dataset version metric mode stepinputs_aime
aime2025_repeat_32 5e9f4f accuracy (32 runs average) gen 88.96

qwen3.5-35b-a3b ruler

Select 64 data, same result as main branch

@grimoire grimoire requested a review from RunningLeon April 7, 2026 12:39
@grimoire grimoire marked this pull request as ready for review April 8, 2026 03:09
Copilot AI review requested due to automatic review settings April 8, 2026 03:09
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR refactors the decoding-loop “step state” management by moving prefill/decode update logic into paradigm-specific StepInputs subclasses, removing the previous cross-strategy delegation methods.

Changes:

  • Introduces a new StepInputs ABC and adds paradigm implementations for AR, DLLM, and AR-Spec.
  • Simplifies strategy ABCs by removing now-unneeded “delta/merge/step” update hooks, and updates the model agent loop to call StepInputs directly.
  • Moves AR helper logic into ar/model_inputs.py and updates DLLM/AR-Spec code to align with the new lifecycle (merge_prefill, reindex, step_decode).

Reviewed changes

Copilot reviewed 20 out of 20 changed files in this pull request and generated no comments.

Show a summary per file
File Description
lmdeploy/pytorch/strategies/base/step_inputs.py Adds the new StepInputs base class and defines the decoding-loop lifecycle contract.
lmdeploy/pytorch/strategies/ar/step_inputs.py Implements AR-specific step-state transitions and sampling-delta handling.
lmdeploy/pytorch/strategies/dllm/step_inputs.py Implements DLLM-specific step-state transitions including mask-aware sampling delta updates.
lmdeploy/pytorch/strategies/ar_spec/step_inputs.py Implements AR-Spec-specific step-state transitions including draft-token and mRoPE handling.
lmdeploy/pytorch/engine/model_agent/agent.py Replaces the old inline StepInputs with strategy_factory.build_step_inputs() and calls into the new API.
lmdeploy/pytorch/strategies/base/init.py Extends StrategyFactoryBase with build_step_inputs().
lmdeploy/pytorch/strategies/base/model_agent.py Removes obsolete ABC hooks (update_*_for_next_step, update_extra_inputs) now owned by StepInputs.
lmdeploy/pytorch/strategies/base/model_inputs.py Removes obsolete ABC hooks (merge, update_inputs) now owned by StepInputs.
lmdeploy/pytorch/strategies/base/sampling.py Removes obsolete ABC hooks (merge/step/update_sampling_delta) now owned by StepInputs.
lmdeploy/pytorch/strategies/ar/model_inputs.py Promotes AR decoding helpers (get_model_inputs_next_decoding, index_select_model_inputs) to module-level functions.
lmdeploy/pytorch/strategies/ar/model_agent.py Removes AR “next step” update methods now implemented in ARStepInputs.
lmdeploy/pytorch/strategies/ar/sampling.py Removes delta/merge/step methods now implemented in ARStepInputs.
lmdeploy/pytorch/strategies/ar/init.py Adds build_step_inputs() factory method for AR.
lmdeploy/pytorch/strategies/dllm/model_agent.py Removes DLLM “next step” update methods now implemented in DLLMStepInputs.
lmdeploy/pytorch/strategies/dllm/model_inputs.py Removes DLLM merge/update methods now implemented in DLLMStepInputs.
lmdeploy/pytorch/strategies/dllm/sampling.py Removes delta/merge/step methods now implemented in DLLMStepInputs; aligns repeated attrs with new ngram field names.
lmdeploy/pytorch/strategies/dllm/init.py Adds build_step_inputs() factory method for DLLM.
lmdeploy/pytorch/strategies/ar_spec/model_agent.py Removes AR-Spec “next step” update methods now implemented in ARSpecStepInputs.
lmdeploy/pytorch/strategies/ar_spec/model_inputs.py Removes AR-Spec merge/update methods now implemented in ARSpecStepInputs.
lmdeploy/pytorch/strategies/ar_spec/init.py Adds build_step_inputs() factory method for AR-Spec.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants