This repository was archived by the owner on Aug 5, 2025. It is now read-only.

Refactor dispatch and redistribute to expose local tensor APIs #476

Merged
mrshenli merged 24 commits into main from gh/mrshenli/4/head on Oct 3, 2022
Conversation


@mrshenli (Contributor) commented Sep 19, 2022

Stack from ghstack (oldest at bottom):

Since `make_fx` cannot yet handle tensor subclasses correctly,
refactor the dispatch code to expose APIs that take local tensors, and
then trace those APIs instead.
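The shape of the refactor can be sketched with a toy example. All names here (`operator_dispatch`, `prepare_inputs`, the wrapper classes) are simplified stand-ins for illustration, not the actual DTensor code; the point is only that the traceable path sees plain local values, never the tensor subclass:

```python
# Toy sketch of the refactor: the dispatch entry point is split so that
# the traceable helpers only ever see plain local values, never the
# tensor subclass. All names are illustrative stand-ins, not the real
# DTensor APIs.

class LocalTensor:
    """Stand-in for a plain local tensor shard."""
    def __init__(self, value):
        self.value = value

class DTensorLike:
    """Stand-in for a tensor subclass that make_fx cannot yet handle."""
    def __init__(self, local):
        self.local = local

def prepare_inputs(args):
    """Unwrap subclass inputs into plain local tensors (traceable)."""
    return [a.local if isinstance(a, DTensorLike) else a for a in args]

def run_local_op(op, local_args):
    """Run the op on plain local tensors only (traceable)."""
    return op(*[t.value for t in local_args])

def operator_dispatch(op, args):
    """Subclass-aware entry point; tracing targets the helpers above."""
    return run_local_op(op, prepare_inputs(args))

x = DTensorLike(LocalTensor(3))
y = DTensorLike(LocalTensor(4))
print(operator_dispatch(lambda a, b: a + b, [x, y]))  # 7
```

Tracing `prepare_inputs` and `run_local_op` (rather than `operator_dispatch`) keeps the subclass out of the traced graph.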

mrshenli added a commit that referenced this pull request Sep 19, 2022
@aazzolini (Contributor) left a comment


Could you provide pseudo-code of the intended use of the functions being exposed?



def operator_dispatch(
def prepare_inputs(
Contributor

Is the goal to use `prepare_inputs` directly, and bypass `operator_dispatch`, when using `make_fx`?

If so, then we will probably miss some operators' implementations that directly use DTensor.

Could you provide quick pseudo-code of the intended call sequence?

Contributor Author

This is how it is used in the PR stacked on top:

https://github.com/pytorch/tau/blob/a06ce4426bbddb84dc75eb9f0c10894c5c80bf41/test/spmd/test_tracing.py#L267-L301

We use `make_fx` to trace two things: 1. redistributed inputs, and 2. the local op. IIUC, we don't redistribute the output at the moment in DTensor? If we do that in the future, we will also need to add that to the trace.

Ideally, I wanted to trace `dispatch_with_local_tensors(local_args, arg_specs, output_specs)`, which would trigger redistribute on the inputs, the local compute op, and redistribute on the outputs.
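That intended call sequence could be sketched roughly as follows. The signature of `dispatch_with_local_tensors` and the spec format used here are hypothetical placeholders, not the real API:

```python
# Rough sketch of the intended traceable sequence: redistribute the
# inputs, run the local op, then (if added later) redistribute the
# output. Spec formats and signatures are hypothetical placeholders.

def redistribute_local(local_tensor, src_spec, dst_spec):
    """Placeholder: a real implementation would issue collectives when
    src_spec != dst_spec; here the tensor is returned unchanged."""
    return local_tensor

def dispatch_with_local_tensors(op, local_args, arg_specs, output_spec):
    # 1. Redistribute each input to the spec the op expects.
    redistributed = [
        redistribute_local(t, src, dst)
        for t, (src, dst) in zip(local_args, arg_specs)
    ]
    # 2. Run the local computation op.
    out = op(*redistributed)
    # 3. Output redistribution would also be traced here if DTensor
    #    starts redistributing outputs in the future.
    return redistribute_local(out, output_spec, output_spec)

result = dispatch_with_local_tensors(
    lambda a, b: a * b,
    local_args=[2, 5],
    arg_specs=[("shard", "shard"), ("shard", "shard")],
    output_spec="shard",
)
print(result)  # 10
```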

Contributor

The issue I see is that if we do it this way, we will be missing 1) the decompositions, and 2) the custom op implementations (those that don't have a propagation rule but require a direct DTensor-aware implementation). I think for some ops we will actually need those.

Would it be possible to work around the `make_fx` limitations somehow and still be able to trace the implementation of DTensor at some level?

return new_local_tensor


def redistribute_spmd_tensor(
Contributor

Can we merge this with the Redistribute autograd function?

Contributor Author

In that case, I assume `pack_args_kwargs_with_local_tensor` will then also need to call into the Redistribute autograd function, which is invoked in the `__torch_dispatch__` function under no_grad mode?

https://github.com/pytorch/tau/blob/c71e2866015ca23beee4e17c9d8dae415d5f86b4/spmd/tensor/utils.py#L57

Contributor

Oh, I mean we should probably also change `pack_args_kwargs_with_local_tensor` to call into `_redistribute_with_local_tensor` directly, so that we can safely delete this `redistribute_spmd_tensor` API.
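The suggested change could look roughly like this. These are toy stand-ins, not the real DTensor code; in particular, the real `_redistribute_with_local_tensor` takes more arguments (such as placement specs and the device mesh), which is where extra verbosity can come from:

```python
# Toy sketch of the suggested change: the packing helper calls the
# local-tensor redistribute function directly, so a DTensor-wrapping
# redistribute API is no longer needed. All names and signatures are
# illustrative stand-ins, not the real DTensor code.

def _redistribute_with_local_tensor(local_tensor, current_spec, target_spec):
    """Placeholder: a real implementation would issue collectives when
    the specs differ; here the tensor is returned unchanged."""
    return local_tensor

def pack_args_with_local_tensors(local_args, current_specs, target_specs):
    # Call the local-tensor helper directly instead of wrapping each
    # argument back into a DTensor and going through a DTensor-level
    # redistribute API.
    return [
        _redistribute_with_local_tensor(t, cur, tgt)
        for t, cur, tgt in zip(local_args, current_specs, target_specs)
    ]

packed = pack_args_with_local_tensors(
    [1, 2], ["shard", "shard"], ["shard", "replicate"]
)
print(packed)  # [1, 2]
```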

Contributor Author

Sure, let me update that.

Contributor Author

Hey @wanchaol, I tried updating `pack_args_kwargs_with_local_tensor` and removing `redistribute_spmd_tensor`, but the code becomes a bit verbose, as `_redistribute_with_local_tensor` takes more arguments and requires one additional DTensor wrapping. Let me know if you prefer to get rid of `redistribute_spmd_tensor`; I can do that in a follow-up PR.


@wanchaol (Contributor) left a comment

lgtm!

mrshenli added a commit that referenced this pull request Oct 3, 2022
@mrshenli mentioned this pull request Oct 3, 2022
@mrshenli changed the base branch from gh/mrshenli/4/base to main on October 3, 2022 16:40
@mrshenli merged commit 5e38ab6 into main on Oct 3, 2022
richqyz pushed a commit that referenced this pull request Oct 5, 2022
@facebook-github-bot deleted the gh/mrshenli/4/head branch November 3, 2022 14:19
