
new feature: On policy distillation #344

Open
sfc-gh-thonguyen wants to merge 7 commits into snowflakedb:main from sfc-gh-thonguyen:thong/on_policy_distillation

Conversation

sfc-gh-thonguyen (Collaborator) commented Jan 28, 2026

Based on this blog post https://thinkingmachines.ai/blog/on-policy-distillation/ -- I figured ArcticTraining would be an appropriate place to have this feature.

Training was validated on the GSM8K dataset with a Qwen3-1.7B student and a Qwen3-8B teacher.
[plots: teacher perplexity and teacher logprob over training]
Lower teacher perplexity means the teacher is less surprised by the student's answers. Higher teacher logprob means the teacher agrees more with the student's answers.
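
For reference, both metrics fall out of scoring the student's sampled tokens with the teacher. A minimal sketch of how they can be computed -- tensor names and shapes here are my assumptions, not the PR's actual code:

import torch
import torch.nn.functional as F

def teacher_metrics(teacher_logits, student_tokens, mask):
    """Score the student's sampled tokens under the teacher.

    teacher_logits: [batch, seq, vocab] teacher logits at each position
    student_tokens: [batch, seq] token ids the student generated
    mask:           [batch, seq] 1.0 on generated (non-prompt) tokens
    """
    logprobs = F.log_softmax(teacher_logits, dim=-1)
    # Teacher log-probability of each token the student actually produced.
    tok_lp = logprobs.gather(-1, student_tokens.unsqueeze(-1)).squeeze(-1)
    mean_lp = (tok_lp * mask).sum() / mask.sum()
    return {
        "teacher_logprob": mean_lp.item(),  # higher = more agreement
        "teacher_perplexity": torch.exp(-mean_lp).item(),  # lower = less surprised
    }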

[plot: reverse KL over training]

Lower reverse KL means the student's answers are converging to the teacher's.
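
The training signal here is the per-token reverse KL, KL(student || teacher), evaluated on sequences the student sampled itself. A minimal sketch under the same assumed tensor layout, computing the full-vocabulary KL at each position (the PR's exact estimator may differ):

def reverse_kl_loss(student_logits, teacher_logits, mask):
    """Mean per-token KL(student || teacher) over the student's own samples."""
    student_logp = F.log_softmax(student_logits, dim=-1)
    teacher_logp = F.log_softmax(teacher_logits, dim=-1)
    # KL(p_s || p_t) = sum_v p_s(v) * (log p_s(v) - log p_t(v)) at each position.
    kl = (student_logp.exp() * (student_logp - teacher_logp)).sum(dim=-1)
    return (kl * mask).sum() / mask.sum()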

[plot: student perplexity over training]

The student's perplexity initially jumped up, meaning it is learning new behavior, and then slowly ramped down, meaning it grows more confident as training progresses.
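
Assuming "student perplexity" is the student's perplexity on its own sampled tokens (exp of the mean NLL), it can be computed with the same gather pattern as the teacher metrics above:

# Student's perplexity over its own samples (assumed definition).
student_logp = F.log_softmax(student_logits, dim=-1)
tok_lp = student_logp.gather(-1, student_tokens.unsqueeze(-1)).squeeze(-1)
student_ppl = torch.exp(-(tok_lp * mask).sum() / mask.sum())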

Full dashboard: https://snowflake.wandb.io/thongnguyen/on-policy-distillation-gsm8k/runs/zuwzrd11?nw=nwuserthongnguyen

Once this PR is in, we can claim that ArcticTraining supports RL :)

Collaborator


This file should probably live with the trainer code; config/ is for base config classes.

Comment on lines +354 to +377
def step(self, batch: Dict[str, torch.Tensor]) -> None:
"""Execute a single training step.

Overrides the base step to handle the unique requirements of
on-policy distillation (generation + training).
"""
self.model.train()

loss = self.loss(batch)

self.backward(loss)

def maybe_item(v):
return v.item() if torch.is_tensor(v) else v

self.metrics.record("loss", maybe_item(loss))

self.model.step()

self.checkpoint()

# Update step counters
self.global_step = self.model.global_steps
self.global_step_this_run = self.global_step - self.global_step_at_start_this_run
Collaborator


This appears to be an exact copy of the step method in the base trainer class. Why do we redefine it here?
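
If the base implementation really is identical, the simplest fix is to drop the override and inherit it. A minimal sketch -- BaseTrainer is a stand-in name here, not ArcticTraining's actual class:

class BaseTrainer:
    def step(self, batch):
        ...  # loss -> backward -> metrics -> model.step() -> checkpoint

class OnPolicyDistillationTrainer(BaseTrainer):
    # No step() override needed; only the generation and loss logic differ.
    def loss(self, batch):
        ...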
