
Put row lengths on the same device on gpu#113

Merged
oliverholworthy merged 3 commits into NVIDIA-Merlin:main from edknv:torch/multi-gpu-row-lengths
Mar 22, 2023

Conversation

@edknv
Contributor

@edknv edknv commented Mar 21, 2023

In a multi-GPU setting, tensors may be generated on different devices. This PR moves the result of torch.cumsum(row_lengths, 0) onto the same device as the zero_value tensor; if they are on different devices, torch.cat() cannot concatenate them, e.g.:

  File "/usr/local/lib/python3.8/dist-packages/merlin/dataloader/torch.py", line 169, in _row_lengths_to_offsets
    return torch.cat((zero_value, torch.cumsum(row_lengths, 0)))
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cpu! (when checking argument for argument tensors in method wrapper_cat)
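The fix in this PR can be sketched in isolation as follows. This is a minimal, device-agnostic illustration (run here on CPU), not the dataloader code itself; `device` stands in for the loader's `self.device` attribute, which would be e.g. `cuda:1` in a multi-GPU run.

```python
import torch

# Stand-in for the dataloader's self.device attribute.
device = torch.device("cpu")  # e.g. torch.device("cuda:1") in a multi-GPU run

row_lengths = torch.tensor([2, 3, 1])

# zero_value lives on the loader's device; moving the cumsum result onto
# the same device before torch.cat avoids the RuntimeError above.
zero_value = torch.zeros(1, dtype=row_lengths.dtype, device=device)
offsets = torch.cat((zero_value, torch.cumsum(row_lengths, 0).to(device=device)))
print(offsets.tolist())  # [0, 2, 5, 6]
```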

@edknv edknv self-assigned this Mar 21, 2023
@edknv edknv added the bug Something isn't working label Mar 21, 2023
@edknv edknv added this to the Merlin 23.03 milestone Mar 21, 2023
Comment thread: merlin/dataloader/torch.py (outdated)

      if len(row_lengths.shape) == 2:
          zero_value = zero_value.view(-1, 1)
  -   return torch.cat((zero_value, torch.cumsum(row_lengths, 0)))
  +   return torch.cat((zero_value, torch.cumsum(row_lengths, 0).to(device=self.device)))
Contributor

Assuming torch.cat preserves the device of the values being concatenated, maybe an alternative to this would be to change the line where zero_value is defined to use row_lengths.device instead of self.device?
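The suggested alternative can be sketched as a standalone helper. This is a hypothetical, self-contained version of the `_row_lengths_to_offsets` method, assuming the only loader state it needs is the device of `row_lengths` itself:

```python
import torch

def row_lengths_to_offsets(row_lengths: torch.Tensor) -> torch.Tensor:
    """Convert per-row lengths into cumulative offsets with a leading zero."""
    # Create zero_value on row_lengths.device (instead of self.device), so
    # torch.cat never mixes devices regardless of where the input lives.
    zero_value = torch.zeros(1, dtype=row_lengths.dtype, device=row_lengths.device)
    if len(row_lengths.shape) == 2:
        zero_value = zero_value.view(-1, 1)
    return torch.cat((zero_value, torch.cumsum(row_lengths, 0)))

print(row_lengths_to_offsets(torch.tensor([2, 3, 1])).tolist())  # [0, 2, 5, 6]
```

Since torch.cumsum returns a tensor on the same device as its input, both operands of torch.cat are then guaranteed to match.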

Contributor

That might make the row-lengths-to-offsets method more robust, but the fact that the error is showing up suggests that even with this fix we might end up with a mismatch between the loader.device attribute and the output tensor device?

Contributor Author

Thanks for the suggestion. Honestly I'm not sure which option is best, but your suggestion also works, and I think it's the better one. Changed in de593b2.


Labels

bug Something isn't working

Projects

None yet

Development

2 participants