`truncation_mode="keep_end"` can cause NaNs in output logs
Hi
trl 0.16.0 - dpo_trainer.py
In `concatenated_forward()`, when `"keep_end"` truncation is applied, the following code keeps only the last `self.max_length` tokens (default 1024):
```python
input_ids = input_ids[:, -self.max_length :]
attention_mask = attention_mask[:, -self.max_length :]
loss_mask = loss_mask[:, -self.max_length :]
```
In batches where some answers are very short, this can leave `input_ids` consisting entirely of padding tokens, with a `loss_mask` that is all zeros: the answer is so short that none of its tokens are "captured" from the right.
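A minimal sketch of that failure mode, assuming right-padded sequences and a toy `max_length` of 4 standing in for the default 1024:

```python
import torch

max_length = 4  # stand-in for self.max_length (default 1024)

# Toy batch: one short, right-padded example of length 8.
# The two real tokens sit at the start; everything after is padding (0).
input_ids = torch.tensor([[5, 6, 0, 0, 0, 0, 0, 0]])
attention_mask = torch.tensor([[1, 1, 0, 0, 0, 0, 0, 0]])
loss_mask = torch.tensor([[0, 1, 0, 0, 0, 0, 0, 0]], dtype=torch.bool)

# The "keep_end" truncation from concatenated_forward():
input_ids = input_ids[:, -max_length:]
attention_mask = attention_mask[:, -max_length:]
loss_mask = loss_mask[:, -max_length:]

# All real tokens were cut away; only padding survives,
# and the loss_mask is now all zeros.
print(attention_mask.sum().item())  # 0
print(loss_mask.any().item())       # False
```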
This has the following effect:
- `mean_chosen_logits = logits[:num_examples][loss_mask[:num_examples]].mean()` produces a NaN in the output log, because filtering with an all-zero `loss_mask` yields an empty tensor, and the mean of an empty tensor is NaN. (The same applies to `mean_rejected_logits`, depending on where the short samples land.)
I noticed these NaNs in my outputs, which is why I looked into it.
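The NaN itself is easy to reproduce in isolation; indexing with an all-`False` mask returns an empty tensor, whose mean is NaN:

```python
import torch

logits = torch.randn(1, 8)
loss_mask = torch.zeros(1, 8, dtype=torch.bool)  # nothing left to average over

# Empty selection -> mean over zero elements -> NaN
mean_logits = logits[loss_mask].mean()
print(torch.isnan(mean_logits).item())  # True
```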
- For the logps the effect is different, since we have `per_token_logps[~loss_mask] = 0`. Because the all-zero `loss_mask` covers the whole sequence, every per-token logp is zeroed, so the sum is exactly 0 and `output["chosen_logps"]` ends up containing a 0, which is then passed on to the xxx_loss() function.
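The zero-logps case, sketched with the same all-`False` mask (note that a logp of exactly 0 corresponds to probability 1, so the loss silently receives a nonsense value rather than a NaN):

```python
import torch

per_token_logps = torch.randn(1, 8)
loss_mask = torch.zeros(1, 8, dtype=torch.bool)  # fully truncated sample

# The masking step from the trainer: zero out everything outside the mask.
per_token_logps[~loss_mask] = 0

# Summing over the sequence yields exactly 0 for this sample.
chosen_logps = per_token_logps.sum(-1)
print(chosen_logps.item())  # 0.0
```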
I was wondering whether this is a known issue (and, if so, whether it is considered harmless).
My current plan to resolve this is to add a few lines in `get_batch_loss_metrics()` to 1) check whether any logps in `model_output` are 0, and 2) filter the corresponding samples (chosen/rejected) out of the batch, along with all their other tensors.
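A rough sketch of that filtering idea. The helper name is hypothetical, and the exact-zero check relies on the masked-sum behavior above (a real, non-degenerate logp is essentially never exactly 0):

```python
import torch

def drop_degenerate_pairs(chosen_logps, rejected_logps):
    """Hypothetical helper: drop pairs where either logp is exactly 0,
    which signals a fully truncated (all-padding) completion."""
    keep = (chosen_logps != 0) & (rejected_logps != 0)
    return chosen_logps[keep], rejected_logps[keep]

# Toy logps: the second pair has a degenerate chosen sample,
# the third a degenerate rejected sample.
chosen = torch.tensor([-12.5, 0.0, -8.5])
rejected = torch.tensor([-15.0, -9.5, 0.0])

c, r = drop_degenerate_pairs(chosen, rejected)
print(c.tolist(), r.tolist())  # [-12.5] [-15.0]
```

In the real trainer the same boolean `keep` mask would also have to be applied to any per-sample metrics computed from `model_output`.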
Thank you
JDB