[aotd] coerce_same_metadata_as_tangent with expected_type for e.g. AsyncCollectiveTensor #139095
IvanKobzarev wants to merge 15 commits into gh/IvanKobzarev/81/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139095
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit f4d09da with merge base 5f266b5.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
if is_subclass and not is_subclass_meta:
    # Unexpected subclass, during tracing we guessed it was a plain Tensor
    if hasattr(x, "__coerce_same_metadata_as_tangent__"):
        x = x.__coerce_same_metadata_as_tangent__(None, torch.Tensor)
```
hmm. Two things:

(1) Right now this branch is hardcoded for the case where we expected a subclass tangent but got a plain tensor tangent. But we specifically updated the coerce API to be more generic: it can technically allow a subclass to convert to any other subclass type, if they are able to handle it. It seems to me like if we are going with that more general API, we should properly handle that here: don't hardcode torch.Tensor, just directly pass in type(x).

Since we're potentially being BC-breaking (I think only DTensor uses this API, although a few people have forked DTensor out of tree over time), it might also be better to only pass the type argument in when the types are different.

(2) You added this call as a completely new one, on top of the existing call below (x = x.__coerce_same_metadata_as_tangent__(meta.meta)). It seems better to consolidate them into a single call?
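A consolidated call site along these lines could look roughly like the following sketch. The stand-in classes and the helper name `coerce_runtime_tangent` are hypothetical, not the actual AOTAutograd code:

```python
from typing import Any, Optional, Type


class PlainTensor:  # stand-in for torch.Tensor
    pass


class SubclassTensor(PlainTensor):  # stand-in for a tensor subclass, e.g. DTensor
    def __coerce_same_metadata_as_tangent__(
        self, expected_metadata: Any, expected_type: Optional[Type] = None
    ):
        if expected_type is None or expected_type is type(self):
            return self  # already the expected type; would re-apply metadata here
        if expected_type is PlainTensor:
            return PlainTensor()  # unwrap to the plain type
        return None  # cannot coerce: let the caller raise a detailed error


def coerce_runtime_tangent(x, expected_type, expected_metadata=None):
    # Single call site: only coerce when the types actually differ, so
    # implementations that ignore the new argument keep working.
    if type(x) is expected_type:
        return x
    coerced = None
    if hasattr(x, "__coerce_same_metadata_as_tangent__"):
        coerced = x.__coerce_same_metadata_as_tangent__(expected_metadata, expected_type)
    if coerced is None:
        raise RuntimeError(
            f"could not coerce tangent of type {type(x).__name__} "
            f"to expected type {expected_type.__name__}"
        )
    return coerced
```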
…for e.g.AsyncCollectiveTensor"

Based on discussion here: #138731

Introducing the ability for a subclass to implement type conversion to `expected_type`.

```
def __coerce_same_metadata_as_tangent__(
    self, expected_metadata: Any, expected_type: Optional[Type] = None
):
```

Here `expected_type=None` means `SubclassClass` is expected. E.g. for `DTensor` we may find a tangent `AsyncCollectiveTensor` where we expected `Tensor`; in this case the method will be called with `expected_type=Tensor` during runtime.

Adding an implementation to AsyncCollectiveTensor that just triggers `wait()`.

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
torch/distributed/tensor/_api.py
Outdated

```diff
-    def __coerce_same_metadata_as_tangent__(self, flatten_spec):
+    def __coerce_same_metadata_as_tangent__(self, flatten_spec, expected_type=None):
+        assert expected_type is None
```
Nit: this can be user-facing, if the user ends up running their code in such a way that e.g. the expected tangents are DTensors but the actual tangents are plain tensors. So we should make sure the error message if this assert fails is very clear (include the actual/expected types and metadata).
Or alternatively, we could have this function return None if it is unable to properly coerce, and let AOTAutograd raise the error itself if it sees a None return
Yes, with the current logic, if it returns None we will raise an error with all the details: https://github.com/pytorch/pytorch/blob/main/torch/_functorch/_aot_autograd/runtime_wrappers.py#L1486

I think we should also document the logic of coercing by metadata and type, the meaning of None, etc.
hmm, this code as written (the assert above) still seems like it will regress our error message?

Prior to this PR, if we expected a DTensor tangent but we got a plain tensor tangent, we would get this nice error message: https://github.com/pytorch/pytorch/blob/main/torch/_functorch/_aot_autograd/runtime_wrappers.py#L1505

Seems like you want to return None here?
Yes, it is better to not assert and just return None here, for a user-friendly error message.
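Concretely, the DTensor-side method could return None instead of asserting. A toy sketch, with a hypothetical stand-in class rather than the real `DTensor`:

```python
class FakeDTensor:
    """Stand-in for DTensor: it only knows how to remain its own type."""

    def __coerce_same_metadata_as_tangent__(self, flatten_spec, expected_type=None):
        if expected_type is not None:
            # Cannot coerce to a different type: return None and let
            # AOTAutograd raise the detailed actual-vs-expected error.
            return None
        # The real implementation would re-apply the metadata in flatten_spec.
        return self
```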
```python
raise RuntimeError("Not implemented")

t = self.trigger_wait()
while isinstance(t, AsyncCollectiveTensor):
```
nit: we shouldn't need this while loop either (I think we will be in a very weird place if there are ever nested AsyncCollectiveTensors, so it seems pointless to defensively program for it)
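Without the defensive loop, the coercion reduces to a single wait. A toy sketch, with a hypothetical stand-in class rather than the real AsyncCollectiveTensor:

```python
class FakeAsyncCollectiveTensor:
    """Stand-in: wraps a value that becomes available after wait()."""

    def __init__(self, elem):
        self.elem = elem

    def wait(self):
        # The real class waits on the pending collective, then unwraps.
        return self.elem

    def __coerce_same_metadata_as_tangent__(self, expected_metadata, expected_type=None):
        # No while loop: nested AsyncCollectiveTensors are not expected.
        return self.wait()
```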
```python
def maybe_coerce_to_memory_format(t, memory_format):
    if not t.is_contiguous(memory_format=meta.memory_format):
        return t.contiguous(memory_format=meta.memory_format)
    return t
```
Superficially, from reading this function it's not clear what the return type is supposed to be (here you are returning a single tensor t, while lower down you return x, [x]).
maybe_coerce_to_memory_format is only for memory_format processing of one argument.

The whole function process_runtime_tangent returns a tuple, so that flattening (the 2nd item in the tuple) happens in the same traversal as the processing, doing only one traversal in total.

So the return type is Tuple[ChangedRawItem, FlattenedChangedItems].
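That contract — coerce and flatten in one traversal, returning `Tuple[ChangedRawItem, FlattenedChangedItems]` — can be illustrated with a toy recursive version. The name and the nested-tuple stand-in are illustrative only:

```python
from typing import Any, List, Tuple


def process_item(x: Any) -> Tuple[Any, List[Any]]:
    """Return the (possibly changed) item plus its flattened leaves."""
    if isinstance(x, tuple):  # stand-in for a subclass holding inner tensors
        results = [process_item(inner) for inner in x]
        changed = tuple(item for item, _ in results)
        flat = [leaf for _, leaves in results for leaf in leaves]
        return changed, flat
    return x, [x]  # leaf case: the item itself is its only flattened entry
```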
```python
if not x.is_contiguous(memory_format=meta.memory_format):
    x = x.contiguous(memory_format=meta.memory_format)
if is_subclass and not is_subclass_meta:
```
Ok, concretely, I think:

(1) instead of special-casing the is_subclass case, we should just check if the types of the two tensors are different, and if so then unconditionally call coerce(metadata or None, type(actual))

(2) if the coerce() function returns None, the subclass has indicated that it cannot perform the coercion, and so we can raise the old error message (with even more useful information in the error than the subclass could have given, like exactly which tangent in the tangent list we are up to)
Ok. Fused the logic into one coercion and updated the PR.
```python
from torch.utils._python_dispatch import return_and_correct_aliasing


class WrapSC(torch.Tensor):
```
nit: up to you, but I find WrapperSubclass a bit clearer than WrapSC
```python
def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
    assert meta is None
    a = inner_tensors["a"]
    if type(a) is torch.Tensor:
```
I guess the purpose of this check is that you only want to run the assertions below if we are at runtime, not trace time.

This isn't very robust though, since it will fail if the inner tensor a is itself another subclass.

You probably want to use is_fake(a) to tell if you are in the middle of tracing?
Honestly, I just copied this from TwoTensor :)
```python
    self, expected_metadata: Any, expected_type: Optional[Type] = None
):
    if expected_type is torch.Tensor:
        return self.a
```
tbh this code as written is a bit confusing: the idea of this function is that it is supposed to enforce that the return type is the same as expected_type. You might want to add an assert type(self.a) == expected_type here?
Thanks, yes, this will be more general, to be able to wrap subclasses and coerce to them.
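With that assert, the test subclass's coerce hook could be shaped like this toy sketch. The wrapper class here is hypothetical and holds a single inner value `a`:

```python
class WrapperSubclass:
    """Toy wrapper over one inner value, standing in for the test tensor subclass."""

    def __init__(self, a):
        self.a = a

    def __coerce_same_metadata_as_tangent__(self, expected_metadata, expected_type=None):
        if expected_type is None:
            return self  # our own type was expected; nothing to do
        # Enforce the contract: what we return must be of expected_type.
        assert type(self.a) is expected_type, (
            f"cannot coerce inner type {type(self.a).__name__} "
            f"to expected type {expected_type.__name__}"
        )
        return self.a
```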
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#139095
Approved by: https://github.com/bdhirsh
Stack from ghstack (oldest at bottom):

Based on discussion here: #138731

Introducing the ability for a subclass to implement type conversion to `expected_type`.

Here `expected_type=None` means `SubclassClass` is expected.

E.g. for `DTensor` we may find a tangent `AsyncCollectiveTensor` where we expected `Tensor`; in this case `expected_type=Tensor` will be passed during runtime.

Adding an implementation to AsyncCollectiveTensor, that just triggers `wait()`.

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o