Skip to content

Commit

Permalink
fix #6050
Browse files (browse the repository at this point in the history)
  • Loading branch information
hiyouga authored Nov 16, 2024
1 parent 6c08478 commit dc82821
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/llamafactory/train/dpo/trainer.py
Original file line number Diff line number Diff line change
@@ -255,10 +255,10 @@ def get_batch_loss_metrics(
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().item()
metrics[f"{prefix}rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item()
-metrics[f"{prefix}logps/rejected"] = policy_chosen_logps.mean().item()
-metrics[f"{prefix}logps/chosen"] = policy_rejected_logps.mean().item()
-metrics[f"{prefix}logits/rejected"] = policy_chosen_logits.mean().item()
-metrics[f"{prefix}logits/chosen"] = policy_rejected_logits.mean().item()
+metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.mean().item()
+metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.mean().item()
+metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.mean().item()
+metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.mean().item()
if self.loss_type == "orpo":
metrics[f"{prefix}sft_loss"] = sft_loss.mean().item()
metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).mean().item()
Expand Down

0 comments on commit dc82821

Please sign in to comment.