[Dynamo] minor enhancements to attention and register a few functions #345
xinli-git merged 2 commits into hidet-org:main
Conversation
hjjq left a comment:
Thanks @xinli-git, LGTM. Just a minor question about `torch_sum`.
```python
@register_function(torch.sum)
@register_method(torch.Tensor.sum)
def torch_sum(x: Tensor, *, dtype: Optional[DataType] = None) -> Tensor:
    if dtype:
        x = x.astype(dtype_from_torch(dtype))
    output = ops.sum(x, dims=list(range(len(x.shape))), keep_dim=True)
    return output
```

```python
@register_function(torch.sum)
@register_method(torch.Tensor.sum)
def torch_sum(
    x: Tensor, dim, keepdim=False, *, dtype: Optional[DataType] = None, out: Optional[Tensor] = None
) -> Tensor:
    if out is not None:
        raise NotImplementedError("hidet: does not support torch.sum(..., out=...)")
    if dtype:
        x = x.astype(dtype_from_torch(dtype))
    output = ops.sum(x, dims=dim, keep_dim=keepdim)
    return output
```
Why do we need two `torch_sum`s here?
I followed the convention for the `mean` method above. Not entirely sure either, as I thought Python does not support overloading. Perhaps @yaoyaoding knows?
I see. But do `torch.Tensor.sum` and `torch.sum` have the same signature? If they do, then there is no need for overloading? https://pytorch.org/docs/stable/generated/torch.sum.html#torch.sum
Also, it doesn't seem that either of them has the `out` argument.
> Also, it doesn't seem that either of them has the `out` argument.
Right, let me fix this.
I think the overload is for `sum(x, *, dtype)` and `sum(x, dims, keepdim, ...)`?
> I think the overload is for `sum(x, *, dtype)` and `sum(x, dims, keepdim, ...)`?
I see.
Also, `keepdim` in the first case (L989) should default to False?
Lastly, `torch.Tensor.sum` seems to have a slightly different signature, where `dim` has a default value (whereas `torch.sum` has no default, making `dim` mandatory). So in the case below, it will resolve to the first case because `dim` is missing, and possibly produce wrong results:
```python
a = torch.randn(...)
b = a.sum(keepdim=True)
```
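For context, here is a minimal sketch of signature-based overload resolution, assuming (as discussed below) that the interpreter tries `inspect.Signature.bind` against each registered variant. The stand-in functions `torch_sum_v1`/`torch_sum_v2` and the `resolve` helper are hypothetical, not hidet's actual code:

```python
import inspect
from typing import Optional

# Hypothetical stand-ins for the two registered torch_sum variants above.
def torch_sum_v1(x, *, dtype: Optional[str] = None): ...
def torch_sum_v2(x, dim, keepdim=False, *, dtype: Optional[str] = None, out=None): ...

def resolve(overloads, *args, **kwargs):
    # Try each registered variant and return the first whose signature
    # the call binds to. This sketches signature-based dispatch; it is
    # not hidet's actual resolver.
    for fn in overloads:
        try:
            inspect.signature(fn).bind(*args, **kwargs)
            return fn
        except TypeError:
            continue
    return None

# a.sum(keepdim=True) is valid PyTorch (torch.Tensor.sum gives dim a default),
# but under strict binding it matches neither variant here: v1 rejects the
# keepdim keyword, and v2 requires a positional dim.
print(resolve([torch_sum_v1, torch_sum_v2], object(), keepdim=True))  # None
```

Under these assumptions the call falls through both variants; either way, `a.sum(keepdim=True)` is not handled correctly as registered.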
It is weird. I see two places where Torch generates

```python
@overload
def xxx(args, )
```

One is under `./_C/__init__.pyi` (does not have `out`), the other under `./_C/_VariableFunctions.pyi` (has `out`). However, both are just generated signatures that don't really represent the actual implementation; I think they are there to make IDEs work.
The actual implementation should be at `aten/src/ATen/native/native_functions.yaml`, which has `out`.
Even if the actual op does not support `out`, having an optional `out` argument should not break the `inspect.Signature.bind` function, so we should still be fine, and it would be better to include `out` here.
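To illustrate the last point, a quick check (the function below mirrors the second registered variant and is only a stand-in) showing that an optional `out` parameter does not break `inspect.Signature.bind` for calls that omit it:

```python
import inspect

def torch_sum(x, dim, keepdim=False, *, dtype=None, out=None): ...

sig = inspect.signature(torch_sum)

# A call that never passes `out` still binds cleanly: optional
# parameters simply remain unbound, so keeping `out` in the
# registered signature is harmless.
bound = sig.bind("x", [0], keepdim=True)
print(bound.arguments)  # {'x': 'x', 'dim': [0], 'keepdim': True}
```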
Encountered a few minor issues when compiling a transformer-based model using torch.compile with very large batch sizes; submitting the fix here.
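For reference, a minimal sketch of the kind of setup described, assuming a CUDA device; the model choice and batch size are made-up placeholders, while `backend='hidet'` is hidet's documented torch.compile integration:

```python
import torch

# Hypothetical repro: a transformer-style module compiled through
# torch.compile with the hidet backend, using a very large batch.
model = torch.nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True).cuda().eval()
x = torch.randn(8192, 64, 512, device='cuda')  # (batch, seq, d_model); sizes are illustrative

with torch.no_grad():
    compiled = torch.compile(model, backend='hidet')
    y = compiled(x)
```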