Problem:
Transformers4Rec supports multi-GPU training for the next-item prediction task because it uses the HF Trainer (RMP #522), which under the hood supports DataParallel / DistributedDataParallel.
The binary classification / regression tasks are currently not supported by the HF Trainer; instead they are trained with a custom model.fit() method we provide, which does not support DataParallel / DistributedDataParallel.
Goal:
- Change the implementation of binary classification / regression tasks so that they can be trained (with multi-GPU) by using HF Trainer.
Starting Point:
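As context for the change, here is a minimal sketch in plain PyTorch (not the actual Transformers4Rec or HF Trainer code) of the multi-GPU wrapping the HF Trainer performs under the hood when several GPUs are visible. `ToyBinaryHead` is a hypothetical stand-in for a binary-classification prediction head; the point is that any task routed through the Trainer inherits this wrapping for free, which the custom model.fit() path does not.

```python
import torch
from torch import nn


class ToyBinaryHead(nn.Module):
    """Hypothetical stand-in for a binary-classification prediction head."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.linear = nn.Linear(dim, 1)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Return one logit per example.
        return self.linear(features).squeeze(-1)


model = ToyBinaryHead()
if torch.cuda.device_count() > 1:
    # Roughly what the HF Trainer does in the single-process multi-GPU case:
    # wrap the model so batches are split across the visible devices.
    model = nn.DataParallel(model)

# A training step looks the same whether or not the model was wrapped.
x = torch.randn(4, 8)
y = torch.randint(0, 2, (4,)).float()
loss = nn.BCEWithLogitsLoss()(model(x), y)
loss.backward()
```

For true multi-node / multi-process training the Trainer uses DistributedDataParallel instead, but the principle is the same: the wrapping lives in the trainer, not in the task, which is why moving binary classification / regression onto the HF Trainer gives them multi-GPU support without task-specific code.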