Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 72 additions & 20 deletions arctic_training/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
from typing import cast

import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
import torch
import torch.nn.functional as F
from pydantic import ValidationInfo
from pydantic import field_validator
from pydantic import model_validator
from typing_extensions import Self
from typing import List

try:
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
Expand All @@ -46,6 +48,16 @@
from arctic_training.trainer.trainer import Trainer
from arctic_training.trainer.utils import to_device

# the following function is from DeepSpeed-Chat
def _z3_params_to_fetch(
    param_list: List[torch.nn.Parameter],
) -> List[torch.nn.Parameter]:
    """Filter ``param_list`` down to DeepSpeed parameters marked as not available.

    A parameter is kept only when it has a ``ds_id`` attribute and its
    ``ds_status`` equals ``ZeroParamStatus.NOT_AVAILABLE``. Parameters without
    ``ds_id`` are dropped without ``ds_status`` ever being read.

    Args:
        param_list: Candidate parameters to inspect.

    Returns:
        A new list with the matching parameters, in their original order.
    """
    pending: List[torch.nn.Parameter] = []
    for param in param_list:
        # Guard first: objects lacking ``ds_id`` must not have ``ds_status`` accessed.
        if not hasattr(param, "ds_id"):
            continue
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            pending.append(param)
    return pending


def get_logprobs(
logits: torch.Tensor, labels: torch.Tensor, ignore_label_index: int
Expand All @@ -67,16 +79,39 @@ def get_logprobs(
f" must have the same shape {labels.shape}."
)

output_prob = []
output_item_count = []

labels = labels[:, 1:].clone()
logits = logits[:, :-1, :]
loss_mask = labels != ignore_label_index

# dummy token; we'll ignore the losses on these tokens later
labels[labels == ignore_label_index] = 0
for label, logit in zip(labels, logits):
logit = logit[label != ignore_label_index]
logit_logps = logit.log_softmax(-1)
# logit_logps = logit_logps[label != ignore_label_index]
label = label[label != ignore_label_index]
output_prob.append(torch.gather(logit_logps, dim=1, index=label.unsqueeze(1)).squeeze(1).sum().view(1))
output_item_count.append(torch.tensor(label.numel()).view(1).to(logit.device))

output_prob = torch.cat(output_prob, dim=0)
output_item_count = torch.cat(output_item_count, dim=0)

# loss_mask = labels != ignore_label_index

# # dummy token; we'll ignore the losses on these tokens later
# labels[labels == ignore_label_index] = 0

# per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
# labels = labels[labels != ignore_label_index]

# tmp_a = (per_token_logps * loss_mask).sum(-1)
# tmp_b = loss_mask.sum(-1)

# import pdb; pdb.set_trace()

return output_prob, output_item_count

return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)


class DPOTrainerConfig(TrainerConfig):
Expand Down Expand Up @@ -194,11 +229,12 @@ def forward_model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
use_cache=False,
output_hidden_states=True,
# output_hidden_states=True,
)
logits = outputs.logits.to(torch.float32)
logprobs, completion_sizes = get_logprobs(logits, batch["labels"], self.config.ignore_label_index)
return logits, logprobs, completion_sizes, outputs.hidden_states[-1]
return logits, logprobs, completion_sizes, None
# outputs.hidden_states[-1]

def forward_reference_model(
self, batch: Dict[str, torch.Tensor]
Expand All @@ -208,15 +244,16 @@ def forward_reference_model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
use_cache=False,
output_hidden_states=True,
# output_hidden_states=True,
)
logits = output.logits.to(torch.float32)
logprobs, completion_sizes = get_logprobs(logits, batch["labels"], self.config.ignore_label_index)
return (
logits.detach(),
logprobs.detach(),
completion_sizes.detach(),
output.hidden_states[-1].detach(),
# output.hidden_states[-1].detach(),
None,
)

def dpo_loss(
Expand Down Expand Up @@ -257,15 +294,30 @@ def loss(self, batch) -> torch.Tensor:
logits, logprobs, _, hidden_state = self.forward_model(batch)

# Activate if we have liger kernel
if self.liger_dpo_loss is not None:
losses, _, _ = self.liger_dpo_loss(
hidden_state,
self.model.module.lm_head.weight,
batch["labels"][:, 1:],
ref_input=ref_hidden_state.detach(),
ref_weight=self.ref_model.module.lm_head.weight,
)
else:
losses, chosen_rewards, rejected_rewards = self.dpo_loss(logprobs, ref_logprobs)

# if self.liger_dpo_loss is not None:
# A = self.model.module.lm_head.weight
# B = self.ref_model.module.lm_head.weight

# # Gather parameters for ZeRO-3
# params_to_fetch = []
# if hasattr(A, "ds_id"):
# params_to_fetch.extend(_z3_params_to_fetch([A]))
# if hasattr(B, "ds_id"):
# params_to_fetch.extend(_z3_params_to_fetch([B]))

# # Use parameters within the context manager where they are gathered
# with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0):
# loss_output = self.liger_dpo_loss(
# A,
# hidden_state[:, :-1],
# batch["labels"][:, 1:],
# ref_input=ref_hidden_state.detach()[:, :-1],
# ref_weight=B,
# )
# # if torch.distributed.get_rank() == 0:
# # import pdb; pdb.set_trace()
# # torch.distributed.barrier()
# return loss_output[0]
# else:
losses, chosen_rewards, rejected_rewards = self.dpo_loss(logprobs, ref_logprobs)
return losses.mean()
Loading