new feature: On policy distillation #344
Open
sfc-gh-thonguyen wants to merge 7 commits into snowflakedb:main from sfc-gh-thonguyen:thong/on_policy_distillation
Changes from all commits (7 commits):
796db60  Add on policy distillation  (sfc-gh-thonguyen)
949a91b  fix flake8  (sfc-gh-thonguyen)
d7e5e61  remove redundant changes  (sfc-gh-thonguyen)
2e5b0e7  minor fix  (sfc-gh-thonguyen)
d590efc  minor fix  (sfc-gh-thonguyen)
aa563ca  optimize generator  (sfc-gh-thonguyen)
8f073ff  Merge branch 'main' into thong/on_policy_distillation  (sfc-gh-mwyatt)
The pull request adds one new file (@@ -0,0 +1,151 @@):

# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration for On-Policy Distillation Trainer.

On-Policy Distillation trains a student model by having it generate its own
trajectories, then using a teacher model to provide per-token supervision via
reverse KL divergence. This contrasts with traditional (off-policy) distillation
where the teacher generates trajectories for the student to imitate.
"""

from typing import Dict
from typing import Union
from typing import cast

from pydantic import Field
from pydantic import ValidationInfo
from pydantic import field_validator
from pydantic import model_validator
from typing_extensions import Self

from arctic_training.config.model import ModelConfig
from arctic_training.config.trainer import TrainerConfig
from arctic_training.config.utils import HumanInt
from arctic_training.registry import get_registered_model_factory


class OnPolicyDistillationTrainerConfig(TrainerConfig):
    """Configuration for On-Policy Distillation Trainer.

    On-policy distillation trains the student on its own generated trajectories,
    with the teacher providing dense per-token feedback via reverse KL divergence.
    """

    teacher_model: ModelConfig
    """
    Configuration for the teacher model used in on-policy distillation.
    The teacher model provides per-token log probabilities for computing
    the reverse KL divergence loss against student-generated trajectories.
    """

    teacher_deepspeed: Dict = {}
    """
    DeepSpeed configuration for the teacher model. This is automatically
    computed based on the main model's DeepSpeed config and should not
    be provided by the user.
    """

    disable_teacher_dropout: bool = True
    """
    Whether to disable dropout in the teacher model during training.
    Recommended to keep True for stable distillation signal.
    """

    num_rollouts_per_prompt: int = Field(default=4, ge=1)
    """
    Number of trajectory samples to generate from the student per prompt.
    Higher values provide more diverse on-policy samples but increase compute.
    """

    max_new_tokens: HumanInt = Field(default=2048, ge=1)
    """
    Maximum number of new tokens to generate for each student trajectory.
    Should be set based on expected response length for the task.
    """

    generation_temperature: float = Field(default=1.0, gt=0.0)
    """
    Temperature for student trajectory generation.
    Higher values produce more diverse samples but may reduce quality.
    """

    generation_top_p: float = Field(default=1.0, gt=0.0, le=1.0)
    """
    Top-p (nucleus) sampling parameter for student generation.
    """

    generation_top_k: int = Field(default=0, ge=0)
    """
    Top-k sampling parameter for student generation. 0 means no top-k filtering.
    """

    beta: float = Field(default=1.0, gt=0.0)
    """
    Coefficient for the reverse KL divergence loss.
    Controls the strength of the distillation signal.
    """

    generation_batch_size: int = Field(default=0, ge=0)
    """
    Batch size for trajectory generation. If 0, uses micro_batch_size.
    May need to be smaller than micro_batch_size due to memory constraints
    during generation.
    """

    @field_validator("teacher_model", mode="before")
    @classmethod
    def init_teacher_model_config(cls, v: Union[Dict, ModelConfig], info: ValidationInfo) -> ModelConfig:
        """Initialize teacher model config from dict or ModelConfig."""
        subconfig = cls._get_subconfig_object(
            v=v,
            info=info,
            get_class_fn=get_registered_model_factory,
            attr_name="teacher_model_factory",
        )
        return cast(ModelConfig, subconfig)

    @model_validator(mode="after")
    def build_teacher_deepspeed_config(self) -> Self:
        """Build DeepSpeed config for teacher model."""
        if len(self.teacher_deepspeed) != 0:
            raise ValueError(
                "Teacher model DeepSpeed config is computed based on the main model "
                "DeepSpeed config and should not be passed by the user."
            )

        teacher_deepspeed = dict(
            train_batch_size=self.deepspeed["train_batch_size"],
            train_micro_batch_size_per_gpu=self.deepspeed["train_micro_batch_size_per_gpu"],
            steps_per_print=self.deepspeed["steps_per_print"],
            zero_optimization=dict(
                stage=3 if self.deepspeed["zero_optimization"]["stage"] == 3 else 0,
                stage3_param_persistence_threshold=1e4,
                memory_efficient_linear=False,
            ),
            bfloat16=dict(enabled=True),
            gradient_clipping=1.0,
            prescale_gradients=False,
            wall_clock_breakdown=False,
        )
        self.teacher_deepspeed = teacher_deepspeed
        return self

    @model_validator(mode="after")
    def set_generation_batch_size(self) -> Self:
        """Set generation batch size to micro_batch_size if not specified."""
        if self.generation_batch_size == 0:
            self.generation_batch_size = self.micro_batch_size
        return self
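The trainer that consumes these fields is not part of the diff shown above. As a rough, hypothetical sketch (not the PR's implementation) of the per-token reverse KL signal that `beta` scales, computed from student and teacher log-probabilities over student-generated tokens:

```python
# Hypothetical sketch only -- the trainer/loss code is not shown in this diff.
# Per-token reverse KL, KL(student || teacher), over student-generated tokens,
# scaled by the `beta` coefficient from OnPolicyDistillationTrainerConfig.
import torch
import torch.nn.functional as F


def reverse_kl_loss(
    student_logits: torch.Tensor,   # [batch, seq_len, vocab]
    teacher_logits: torch.Tensor,   # [batch, seq_len, vocab]
    completion_mask: torch.Tensor,  # [batch, seq_len], 1.0 on generated tokens
    beta: float = 1.0,
) -> torch.Tensor:
    student_logp = F.log_softmax(student_logits, dim=-1)
    # Teacher provides supervision only; no gradient flows through it.
    teacher_logp = F.log_softmax(teacher_logits.detach(), dim=-1)
    # KL(p_s || p_t) per position: sum_v p_s(v) * (log p_s(v) - log p_t(v))
    per_token_kl = (student_logp.exp() * (student_logp - teacher_logp)).sum(dim=-1)
    # Average over generated (completion) tokens only; prompt tokens carry no loss.
    num_tokens = completion_mask.sum().clamp(min=1.0)
    return beta * (per_token_kl * completion_mask).sum() / num_tokens
```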
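Likewise, the rollout generator is only configured here, not shown. Purely as an assumed illustration of how `num_rollouts_per_prompt`, `max_new_tokens`, and the `generation_*` fields could map onto a standard Hugging Face sampling call (the model name is a placeholder, not taken from the PR):

```python
# Illustrative mapping only (assumed, not the PR's generator code).
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder student model
tokenizer = AutoTokenizer.from_pretrained(model_name)
student = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Explain reverse KL divergence in one sentence.", return_tensors="pt")
rollouts = student.generate(
    **inputs,
    do_sample=True,
    max_new_tokens=2048,     # config.max_new_tokens
    temperature=1.0,         # config.generation_temperature
    top_p=1.0,               # config.generation_top_p
    top_k=0,                 # config.generation_top_k; 0 disables top-k filtering
    num_return_sequences=4,  # config.num_rollouts_per_prompt
)
```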
Review comment: This file should probably be with the trainer code; config/ is for base config classes.