
<< and >> operators seem silently broken for a DTensor left operand and a scalar right operand #156533

@vkuzo


🐛 Describe the bug

In pytorch/ao#2420, I found that << and >> were silently wrong for DTensor operands. Specifically:

User code:

tmp = max_abs_int32 >> hp_mbits

If max_abs_int32 is not a DTensor, this gave correct results. If max_abs_int32 is a DTensor, the >> had no effect: the local value of tmp was equal to the local value of max_abs_int32. I fixed it by replacing >> with torch.bitwise_right_shift, and I'm filing this issue in case someone can make the >> operator work as expected here.
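A minimal sketch of that workaround, assuming max_abs_int32 is an integer tensor and hp_mbits is a Python int (the helper name shift_right_workaround is hypothetical, not part of pytorch/ao). It calls torch.bitwise_right_shift directly instead of going through the >> operator; for a plain tensor the two forms give the same result:

```python
import torch


def shift_right_workaround(max_abs_int32: torch.Tensor, hp_mbits: int) -> torch.Tensor:
    # Hypothetical helper: use the explicit function form instead of
    # `max_abs_int32 >> hp_mbits`, which was observed to be a no-op
    # when max_abs_int32 is a DTensor.
    return torch.bitwise_right_shift(max_abs_int32, hp_mbits)


# Plain-tensor usage: here the function form and the operator form agree.
x = torch.tensor([1024, 255, 8], dtype=torch.int32)
tmp = shift_right_workaround(x, 3)
print(tmp.tolist())
```

For a DTensor input, the same call is what the report describes as producing the expected local values.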

In the same PR, a similar issue occurred with <<: it had no effect on DTensor operands, but worked correctly once I replaced it with torch.bitwise_left_shift.

I'm on an NVIDIA B200 machine, mentioning it on the off chance that this is hardware-specific.

Versions

https://www.internalfb.com/phabricator/paste/view/P1847216180

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @weifengpy @tianyu-l @XilunWu @SherlockNoMad @ppwwyyxx @H-Huang

Metadata

Assignees: No one assigned
Labels: bot-triaged, module: dtensor, oncall: distributed, oncall: distributed parallelisms, ptd-bot-triaged, triaged