Conversation
    Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        A tuple with the loss, logits and labels (each being optional).
    """
    has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
No need to test for the old deprecated argument names, since they have all been changed in the lib, and the user can now set their own names if they are still using an old model.
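Following that suggestion, a minimal sketch of what the check can look like once the deprecated keys are dropped and the label keys come from a configurable list. The `has_labels` helper and its `label_names` parameter are illustrative stand-ins for the `label_names` field this PR adds to `TrainingArguments`, not the exact code:

```python
def has_labels(inputs, label_names=("labels",)):
    # True only when every configured label key is present in the inputs dict.
    # The default mirrors the common single-label case; question answering
    # models would pass e.g. ("start_positions", "end_positions").
    return all(inputs.get(k) is not None for k in label_names)
```

Note the switch from `any` over hard-coded deprecated names to `all` over the configured names: evaluation needs every expected label to be present before it can compute a loss.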
Codecov Report

@@            Coverage Diff             @@
##           master    #7191      +/-   ##
==========================================
- Coverage   80.86%   79.41%    -1.46%
==========================================
  Files         169      169
  Lines       32293    32322      +29
==========================================
- Hits        26115    25668     -447
- Misses       6178     6654     +476

Continue to review full report at Codecov.
LysandreJik left a comment:
LGTM, very nice addition.
def nested_xla_mesh_reduce(tensors, name):
    if is_torch_tpu_available():
        import torch_xla.core.xla_model as xm

        if isinstance(tensors, (list, tuple)):
            return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
        return xm.mesh_reduce(name, tensors, torch.cat)
    else:
        raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
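The recursion above gives every tensor in a nested output its own reduce key by appending the index at each level, so no two leaves share a `mesh_reduce` name. A TPU-free sketch of just that naming scheme (`nested_names` is a hypothetical helper written for illustration, not part of the diff):

```python
def nested_names(tensors, name):
    # Mirror nested_xla_mesh_reduce's recursion, but return the reduce key
    # each leaf would use instead of calling xm.mesh_reduce on it.
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_names(t, f"{name}_{i}") for i, t in enumerate(tensors))
    return name

# nested_names(("loss", ("logits", "labels")), "eval")
# → ("eval_0", ("eval_1_0", "eval_1_1"))
```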
Did you get a chance to test this on TPU?
No, I was planning to ask you about it this morning.
* Trainer accepts multiple labels
* Missing import
* Fix docstrings
This is a follow-up from #7126. The same kinds of models that can output multiple predictions expect multiple labels (not named "labels"), so the evaluation code needs to be changed for this. To support models built by users, I added a `label_names` field in the `TrainingArguments` which contains the label names. If the user does not set it, it defaults to `["labels"]` for most models and `["start_positions", "end_positions"]` for question answering models, so it works seamlessly for all Transformers models.

I ended up writing a few util functions that concat/numpify tensors or nested lists/tuples of tensors, to avoid testing for this everywhere in `Trainer`. I think the design is cleaner this way, and it also supports models with crazy outputs (if we set `output_attentions=True`, for instance). I also added a test for the multiple labels predictions.
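A minimal sketch of what such concat/numpify utilities can look like. This numpy version is for illustration only: the actual Trainer helpers operate on torch tensors, and the function names here are assumed, not taken from the diff:

```python
import numpy as np

def nested_concat(tensors, new_tensors, axis=0):
    # Recursively concatenate arrays, or nested lists/tuples of arrays,
    # preserving the container type at every level.
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_concat(t, n, axis) for t, n in zip(tensors, new_tensors))
    return np.concatenate((tensors, new_tensors), axis=axis)
```

Because the recursion bottoms out at the leaves, the same function handles a bare array, a tuple of arrays, or arbitrarily nested structures (such as the per-layer attention tuples produced with `output_attentions=True`) without any special-casing at the call site.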