diff --git a/arctic_training/trainer/dpo_trainer.py b/arctic_training/trainer/dpo_trainer.py
index 7dbd2b99..1b81011b 100644
--- a/arctic_training/trainer/dpo_trainer.py
+++ b/arctic_training/trainer/dpo_trainer.py
@@ -20,12 +20,14 @@
 from typing import cast
 
 import deepspeed
+from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 import torch
 import torch.nn.functional as F
 from pydantic import ValidationInfo
 from pydantic import field_validator
 from pydantic import model_validator
 from typing_extensions import Self
+from typing import List
 
 try:
     from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
@@ -46,6 +48,16 @@
 from arctic_training.trainer.trainer import Trainer
 from arctic_training.trainer.utils import to_device
 
+# the following function is from DeepSpeed-Chat
+def _z3_params_to_fetch(
+    param_list: List[torch.nn.Parameter],
+) -> List[torch.nn.Parameter]:
+    return [
+        p
+        for p in param_list
+        if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
+    ]
+
 
 def get_logprobs(
     logits: torch.Tensor, labels: torch.Tensor, ignore_label_index: int
@@ -67,16 +79,39 @@ def get_logprobs(
             f" must have the same shape {labels.shape}."
         )
 
+    output_prob = []
+    output_item_count = []
+
     labels = labels[:, 1:].clone()
     logits = logits[:, :-1, :]
-    loss_mask = labels != ignore_label_index
 
-    # dummy token; we'll ignore the losses on these tokens later
-    labels[labels == ignore_label_index] = 0
+    for label, logit in zip(labels, logits):
+        logit = logit[label != ignore_label_index]
+        logit_logps = logit.log_softmax(-1)
+        # logit_logps = logit_logps[label != ignore_label_index]
+        label = label[label != ignore_label_index]
+        output_prob.append(torch.gather(logit_logps, dim=1, index=label.unsqueeze(1)).squeeze(1).sum().view(1))
+        output_item_count.append(torch.tensor(label.numel()).view(1).to(logit.device))
+
+    output_prob = torch.cat(output_prob, dim=0)
+    output_item_count = torch.cat(output_item_count, dim=0)
+
+    # loss_mask = labels != ignore_label_index
+
+    # # dummy token; we'll ignore the losses on these tokens later
+    # labels[labels == ignore_label_index] = 0
+
+    # per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
 
-    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+    # labels = labels[labels != ignore_label_index]
+
+    # tmp_a = (per_token_logps * loss_mask).sum(-1)
+    # tmp_b = loss_mask.sum(-1)
+
+    # import pdb; pdb.set_trace()
+
+    return output_prob, output_item_count
 
-    return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
 
 class DPOTrainerConfig(TrainerConfig):
@@ -194,11 +229,12 @@
             input_ids=batch["input_ids"],
             attention_mask=batch["attention_mask"],
             use_cache=False,
-            output_hidden_states=True,
+            # output_hidden_states=True,
         )
         logits = outputs.logits.to(torch.float32)
         logprobs, completion_sizes = get_logprobs(logits, batch["labels"], self.config.ignore_label_index)
-        return logits, logprobs, completion_sizes, outputs.hidden_states[-1]
+        return logits, logprobs, completion_sizes, None
+        # outputs.hidden_states[-1]
 
     def forward_reference_model(
         self, batch: Dict[str, torch.Tensor]
@@ -208,7 +244,7 @@
             input_ids=batch["input_ids"],
             attention_mask=batch["attention_mask"],
             use_cache=False,
-            output_hidden_states=True,
+            # output_hidden_states=True,
         )
         logits = output.logits.to(torch.float32)
         logprobs, completion_sizes = get_logprobs(logits, batch["labels"], self.config.ignore_label_index)
@@ -216,7 +252,8 @@ def forward_reference_model(
             logits.detach(),
             logprobs.detach(),
             completion_sizes.detach(),
-            output.hidden_states[-1].detach(),
+            # output.hidden_states[-1].detach(),
+            None,
         )
 
     def dpo_loss(
@@ -257,15 +294,30 @@ def loss(self, batch) -> torch.Tensor:
         logits, logprobs, _, hidden_state = self.forward_model(batch)
 
         # Activate if we have liger kernel
-        if self.liger_dpo_loss is not None:
-            losses, _, _ = self.liger_dpo_loss(
-                hidden_state,
-                self.model.module.lm_head.weight,
-                batch["labels"][:, 1:],
-                ref_input=ref_hidden_state.detach(),
-                ref_weight=self.ref_model.module.lm_head.weight,
-            )
-        else:
-            losses, chosen_rewards, rejected_rewards = self.dpo_loss(logprobs, ref_logprobs)
-
+        # if self.liger_dpo_loss is not None:
+        #     A = self.model.module.lm_head.weight
+        #     B = self.ref_model.module.lm_head.weight
+
+        #     # Gather parameters for ZeRO-3
+        #     params_to_fetch = []
+        #     if hasattr(A, "ds_id"):
+        #         params_to_fetch.extend(_z3_params_to_fetch([A]))
+        #     if hasattr(B, "ds_id"):
+        #         params_to_fetch.extend(_z3_params_to_fetch([B]))
+
+        #     # Use parameters within the context manager where they are gathered
+        #     with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0):
+        #         loss_output = self.liger_dpo_loss(
+        #             A,
+        #             hidden_state[:, :-1],
+        #             batch["labels"][:, 1:],
+        #             ref_input=ref_hidden_state.detach()[:, :-1],
+        #             ref_weight=B,
+        #         )
+        #     # if torch.distributed.get_rank() == 0:
+        #     #     import pdb; pdb.set_trace()
+        #     # torch.distributed.barrier()
+        #     return loss_output[0]
+        # else:
+        losses, chosen_rewards, rejected_rewards = self.dpo_loss(logprobs, ref_logprobs)
         return losses.mean()
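
Note for reviewers: the patch replaces the vectorized masked-sum in get_logprobs with a per-example loop that drops ignore_label_index positions before the log_softmax/gather. The two formulations should agree, since log_softmax acts independently per position along the vocab dimension. A minimal standalone sketch checking that equivalence (not part of the patch; the shapes and the -100 ignore index are made up for illustration):

    import torch

    IGNORE = -100  # stand-in for config.ignore_label_index

    def logprobs_loop(logits, labels):
        # New path: filter ignored positions per example, then sum token log-probs.
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]
        probs, counts = [], []
        for label, logit in zip(labels, logits):
            keep = label != IGNORE
            logps = logit[keep].log_softmax(-1)
            label = label[keep]
            probs.append(torch.gather(logps, 1, label.unsqueeze(1)).squeeze(1).sum().view(1))
            counts.append(torch.tensor(label.numel()).view(1))
        return torch.cat(probs), torch.cat(counts)

    def logprobs_vectorized(logits, labels):
        # Old path: dummy token plus loss mask, as in the removed code.
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]
        mask = labels != IGNORE
        labels[labels == IGNORE] = 0
        per_tok = torch.gather(logits.log_softmax(-1), 2, labels.unsqueeze(2)).squeeze(2)
        return (per_tok * mask).sum(-1), mask.sum(-1)

    torch.manual_seed(0)
    logits = torch.randn(2, 8, 32)   # (batch, seq, vocab)
    labels = torch.randint(0, 32, (2, 8))
    labels[:, :3] = IGNORE           # mask the prompt tokens
    a, na = logprobs_loop(logits, labels)
    b, nb = logprobs_vectorized(logits, labels)
    assert torch.allclose(a, b) and torch.equal(na, nb)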
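The commented-out liger branch documents the ZeRO-3 workaround it was trying: under ZeRO stage 3, lm_head.weight may be a partitioned placeholder, so it must be all-gathered before being handed to a fused loss. A minimal sketch of that pattern, assuming a DeepSpeed ZeRO-3 engine is already initialized; with_gathered and fn are hypothetical names, not part of the patch:

    import deepspeed
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    def with_gathered(params, fn):
        # Hypothetical helper mirroring the commented-out liger branch:
        # select parameters that ZeRO-3 has partitioned away (NOT_AVAILABLE)
        # and all-gather them for the duration of the context.
        to_fetch = [
            p
            for p in params
            if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
        ]
        with deepspeed.zero.GatheredParameters(to_fetch, enabled=len(to_fetch) > 0):
            # Inside this block the selected weights have their full shape;
            # on exit DeepSpeed re-partitions them across ranks.
            return fn()

For example, the liger call would run as with_gathered([self.model.module.lm_head.weight, self.ref_model.module.lm_head.weight], lambda: self.liger_dpo_loss(...)), keeping every use of the gathered weights inside the context manager.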