Refactor dispatch and redistribute to expose local tensor APIs#476
Conversation
Since `make_fx` cannot yet handle tensor subclasses correctly, refactor dispatch code to expose APIs that take local tensors, and then trace these APIs instead.
aazzolini
left a comment
Could you provide pseudo-code for the intended use of the exposed functions?
spmd/tensor/dispatch.py
Outdated

```python
def operator_dispatch(
def prepare_inputs(
```
Is the goal to use prepare_inputs directly, and bypass operator_dispatch, when using make_fx?
If this is the case, then we will probably miss some operators' implementations that directly use DTensor.
Could you provide a quick pseudo-code of the intended call sequence?
This is how it is used in the PR on top
We use make_fx to trace two things: 1. redistributed inputs, 2. the local op. IIUC, we don't redistribute outputs at the moment in DT? If we do that in the future, we will also need to trace that step.
Ideally, I wanted to trace dispatch_with_local_tensors(local_args, arg_specs, output_specs), which would trigger redistribution of inputs, the local compute op, and redistribution of outputs.
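The intended call sequence, as rough pseudo-code (only `dispatch_with_local_tensors` and its argument names come from the comment above; the helper functions are illustrative placeholders, not APIs from this PR):

```python
# Pseudo-code sketch of the intended tracing flow. Helper names marked
# "placeholder" are illustrative, not actual APIs in this PR.
from torch.fx.experimental.proxy_tensor import make_fx

def dispatch_with_local_tensors(local_args, arg_specs, output_specs):
    # 1. redistribute inputs: collectives run on plain local tensors
    redistributed_args = redistribute_inputs(local_args, arg_specs)  # placeholder
    # 2. run the local compute op on the redistributed local tensors
    local_out = local_op(*redistributed_args)  # placeholder
    # 3. (future) redistribute outputs if output_specs require it
    return redistribute_outputs(local_out, output_specs)  # placeholder

# Everything above operates on plain torch.Tensor values, never on the
# DTensor subclass, so make_fx can trace it directly:
gm = make_fx(dispatch_with_local_tensors)(local_args, arg_specs, output_specs)
```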
The issue I see is that if we do it this way we will be missing 1) the decompositions; 2) the custom op implementations (those that don't have a propagation rule, but that require a direct DTensor-aware implementation). I think for some ops we will actually need those.
Would it be possible to work around make_fx limitations somehow and still be able to trace the implementation of DTensor at some level?
spmd/tensor/redistribute.py
Outdated

```python
    return new_local_tensor

def redistribute_spmd_tensor(
```
can we merge this with the Redistribute autograd function?
In that case, I assume pack_args_kwargs_with_local_tensor will then also need to call into the Redistribute autograd function, which happens in the __torch_dispatch__ function under no_grad mode?
Oh, I mean we should probably also change pack_args_kwargs_with_local_tensor to call into _redistribute_with_local_tensor directly, so that we could safely delete the redistribute_spmd_tensor API?
sure, let me update that.
Hey @wanchaol, I tried updating pack_args_kwargs_with_local_tensor and removing redistribute_spmd_tensor, but the code becomes a bit verbose, as _redistribute_with_local_tensor takes more arguments and requires one additional DTensor wrapping. Let me know if you prefer to get rid of redistribute_spmd_tensor; I can do that in a follow-up PR.
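For reference, a minimal sketch of the verbosity concern (pseudo-code only; the signatures are illustrative guesses, not the actual ones in this PR):

```python
# Pseudo-code: signatures below are illustrative, not the actual APIs.

# With the convenience wrapper: one call per DTensor argument.
new_local = redistribute_spmd_tensor(dtensor, target_spec)._local_tensor

# Calling the lower-level API directly: it needs the current/target spec
# pair, plus an extra DTensor wrapping of the plain local tensor first.
wrapped = make_dtensor_from_local(local_tensor, current_spec)  # placeholder
new_local = _redistribute_with_local_tensor(wrapped, current_spec, target_spec)
```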
Stack from ghstack (oldest at bottom):

Since `make_fx` cannot yet handle tensor subclasses correctly, refactor dispatch code to expose APIs that take local tensors, and then trace these APIs instead.