
[Dynamo] minor enhancements to attention and register a few functions #345

Merged
xinli-git merged 2 commits into hidet-org:main from xinli-git:minor_enhancements
Aug 14, 2023

Conversation

@xinli-git
Contributor

Encountered a few minor issues when compiling a transformer-based model with torch.compile at very large batch sizes; submitting the fix here.

@yaoyaoding yaoyaoding requested a review from hjjq August 9, 2023 20:32
Collaborator

@hjjq hjjq left a comment


Thanks @xinli-git, LGTM. Just a minor question about torch_sum

Comment on lines +984 to +1005

```python
@register_function(torch.sum)
@register_method(torch.Tensor.sum)
def torch_sum(x: Tensor, *, dtype: Optional[DataType] = None) -> Tensor:
    if dtype:
        x = x.astype(dtype_from_torch(dtype))
    output = ops.sum(x, dims=list(range(len(x.shape))), keep_dim=True)
    return output


@register_function(torch.sum)
@register_method(torch.Tensor.sum)
def torch_sum(
    x: Tensor, dim, keepdim=False, *, dtype: Optional[DataType] = None, out: Optional[Tensor] = None
) -> Tensor:
    if out is not None:
        raise NotImplementedError("hidet: does not support torch.sum(..., out=...)")
    if dtype:
        x = x.astype(dtype_from_torch(dtype))
    output = ops.sum(x, dims=dim, keep_dim=keepdim)
    return output
```


Collaborator

Why do we need two torch_sums here?

Contributor Author

I followed the convention for the mean method above. I'm not entirely sure either, as I thought Python does not support overloading. Perhaps @yaoyaoding knows?

Member

Python itself does not support function overloading. We use the inspect module to support overloading in hidet; this is needed because some PyTorch functions/methods have multiple signatures.
The implementation can be found here and here.
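The dispatch idea described here can be sketched with just the standard library. This is a hypothetical illustration, not hidet's actual registry code (`OverloadedFunction` and the `_sum_*` helpers are made-up names), but the core mechanism is the same: try `inspect.Signature.bind` against each registered variant and call the first one that accepts the arguments:

```python
import inspect
from typing import Callable, List


class OverloadedFunction:
    """Collect several functions under one name and dispatch to the
    first whose signature accepts the given arguments."""

    def __init__(self) -> None:
        self.overloads: List[Callable] = []

    def register(self, fn: Callable) -> Callable:
        self.overloads.append(fn)
        return fn

    def __call__(self, *args, **kwargs):
        for fn in self.overloads:
            try:
                # bind() raises TypeError if the arguments don't fit this signature
                inspect.signature(fn).bind(*args, **kwargs)
            except TypeError:
                continue
            return fn(*args, **kwargs)
        raise TypeError("no overload matches the given arguments")


my_sum = OverloadedFunction()


@my_sum.register
def _sum_all(x, *, dtype=None):
    return ("sum over all dims", dtype)


@my_sum.register
def _sum_dims(x, dim, keepdim=False, *, dtype=None):
    return ("sum over given dims", dim, keepdim)


print(my_sum([1, 2, 3]))         # binds to _sum_all (no dim given)
print(my_sum([1, 2, 3], dim=0))  # binds to _sum_dims
```

Registration order matters in this sketch: the first signature that binds wins, which is why the discussion below about which `torch_sum` variant a call resolves to is relevant.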

Contributor Author

nice!

Collaborator

@hjjq hjjq Aug 10, 2023

I see. But do torch.Tensor.sum and torch.sum have the same signature? If they do, then there's no need for overloading? https://pytorch.org/docs/stable/generated/torch.sum.html#torch.sum
Also, it doesn't seem that either of them has the out argument.

Contributor Author

> Also it doesn't seem that either of them has the out argument.

Right, let me fix this.

I think the overload is for sum(x, *, dtype) and sum(x, dims, keepdim, ...)?

Member

I usually jump to the signatures in the Python code to check the variants of the torch functions:

[screenshot: torch.sum overload signatures]

and it's interesting that the code has an out parameter but the documentation does not.

Collaborator

> I think the overload is for sum(x, *, dtype) and sum(x, dims, keepdim, ...)?

I see.
Also, keepdim in the first case (L989) should default to False?
Lastly, torch.Tensor.sum seems to have a slightly different signature, where dim has a default value (whereas torch.sum doesn't have a default, making dim mandatory). So in the case below, it will resolve to the first overload because dim is missing, and possibly produce wrong results?

```python
a = torch.randn(...)
b = a.sum(keepdim=True)
```

Collaborator

> I usually jump to the signatures in the python code to check the variants of the torch functions ... and it's interesting that the code has the out parameter but the documentation does not.

[screenshot: hjjq's local torch.sum signatures, without out]
Interestingly, my pytorch code doesn't have out. Maybe we have different versions/builds of torch.

Contributor Author

It is weird; I see two places where Torch generates

```python
@overload
def xxx(args): ...
```

One under ./_C/__init__.pyi (does not have "out"), another under ./_C/_VariableFunctions.pyi (has "out"). However, both are just generated signatures that don't really represent the actual implementation; I think they are there to make IDEs work.

The actual implementation should be at aten/src/ATen/native/native_functions.yaml, which has "out".

Even if the actual op does not support "out", having an optional out argument should not break the inspect.Signature.bind function, so we should still be fine, and it would be better to include "out" here.
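The claim about `inspect.Signature.bind` is easy to check directly. In this small self-contained check, `fake_sum` is a hypothetical stand-in with the same signature as the second `torch_sum` overload above:

```python
import inspect


def fake_sum(x, dim, keepdim=False, *, dtype=None, out=None):
    """Hypothetical stand-in with the registered signature, including optional out."""
    return (dim, keepdim, out)


sig = inspect.signature(fake_sum)

# A call that omits `out` still binds cleanly; `out` simply does not
# appear among the bound arguments.
bound = sig.bind([1, 2, 3], 0, keepdim=True)
print("out" in bound.arguments)  # False

# A call that does pass out=... also binds, so the function body gets
# a chance to raise NotImplementedError, as the PR does.
bound_out = sig.bind([1, 2, 3], 0, out="buffer")
print("out" in bound_out.arguments)  # True
```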

@xinli-git xinli-git merged commit edb6503 into hidet-org:main Aug 14, 2023
@xinli-git xinli-git deleted the minor_enhancements branch August 14, 2023 21:06