Skip to content

Bcast python api patch#2561

Merged
jjsjann123 merged 19 commits intodevelfrom
bcast_python_api_patch
Mar 10, 2023
Merged

Bcast python api patch#2561
jjsjann123 merged 19 commits intodevelfrom
bcast_python_api_patch

Conversation

@jjsjann123
Copy link
Copy Markdown
Collaborator

Fixes python handling of expanded broadcast dimension.
e.g. torch.randint(0, 1, (5, 5), device="cuda").bool().unsqueeze(-1).expand(5, 5, 5)

Changes contiguity representation on python.
computeContiguity currently returns an array with the length of tensor rank, elements in the array can be: True, False or None, where None indicates a given dimension is broadcast.

@jjsjann123
Copy link
Copy Markdown
Collaborator Author

cc'ing @kevinstephano @rdspring1 @jacobhinkle on the contiguity change.

@jjsjann123
Copy link
Copy Markdown
Collaborator Author

linking #2551

@IvanYashchuk
Copy link
Copy Markdown
Collaborator

Could you please add this test? We need to make sure sizes=..., strides=... interface works.

    def test_fix_2549(self):
        a = torch.ones(4, 1, dtype=torch.double, device='cuda')
        b = torch.ones(4, 4, dtype=torch.double, device='cuda')

        def nvfuser_fusion_id(fd : FusionDefinition) -> None :
            T0 = fd.define_tensor(sizes=a.shape, strides=a.stride(), dtype=DataType.Double, is_cpu=False)
            T1 = fd.define_tensor(sizes=b.shape, strides=b.stride(), dtype=DataType.Double, is_cpu=False)
            T2 = fd.ops.broadcast_in_dim(T0, output_shape=[4, 4], broadcast_dims=[0, 1])
            T3 = fd.ops.div(T1, T2)
            fd.add_output(T3)

        with FusionDefinition() as fd:
            nvfuser_fusion_id(fd)

        out = fd.execute([a, b])
        self.assertEqual(out[0], b / a)

@IvanYashchuk
Copy link
Copy Markdown
Collaborator

Is this change need to be versioned?

@jjsjann123
Copy link
Copy Markdown
Collaborator Author

Is this change need to be versioned?

Good catch. I'll bump that~~

@jjsjann123
Copy link
Copy Markdown
Collaborator Author

Linking the old PR from Xiang with the contiguity refactor #2517

Copy link
Copy Markdown
Collaborator

@zasdfgbnm zasdfgbnm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing! Just some minor comments.

if (not_broadcast(i)) {
break;
}
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Please revert this white space change (and consider using a modern editor rather than vim😈😈😈😈)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this~~ somehow my clangformat changes didn't get pushed.

and consider using a modern editor rather than vim

I'm going to unfriend you on all my social media.

jjsjann123 and others added 2 commits March 9, 2023 18:49
@jjsjann123
Copy link
Copy Markdown
Collaborator Author

all comments resolved. I'll merge this when CI finishes.

@jjsjann123 jjsjann123 merged commit 4bc286a into devel Mar 10, 2023
@jjsjann123 jjsjann123 deleted the bcast_python_api_patch branch March 10, 2023 05:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants