
Basic dynamo support for traceable collectives #94440

Closed
wconstab wants to merge 28 commits into gh/wconstab/93/base from gh/wconstab/93/head

Conversation

Contributor

@wconstab wconstab commented Feb 8, 2023

Stack from ghstack (oldest at bottom):

Make traceable collectives work with torchdynamo,
bypassing problems with tracing the AsyncTensor subclass.

Accept a suboptimal solution for now, and optimize it later.
For now, the wait happens immediately, which generally forces an early sync.

Later, find a way in either dynamo or the AOT stack to handle
AsyncCollectiveTensor so the wait lands in the optimal place.

Note on implementation:

  • Dynamo traces 'user-level' functional collective (fc) APIs that are designed to
    behave differently in eager vs compiled mode. In eager mode, there is work-obj
    registration, and a wrapper subclass inserts a 'wait' call at the appropriate time.
    In compile/trace mode, wait is called immediately, and work-obj registration
    must be handled by the compile backend at runtime.
  • Dynamo needs to trace into some of the helper functions in the 'user-level'
    API, such as '_expand_group', which is essentially a constant transformation.
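The eager-vs-compiled split described above can be sketched with a toy model. All names here (`ToyWork`, `ToyAsyncTensor`, `toy_all_reduce`) are illustrative stand-ins, not the actual PyTorch API:

```python
class ToyWork:
    """Stands in for a c10d work object."""
    def __init__(self):
        self.waited = False

    def wait(self):
        self.waited = True


class ToyAsyncTensor:
    """Wrapper subclass: defers 'wait' until the value is first used (eager mode)."""
    def __init__(self, value, work):
        self._value = value
        self._work = work

    def materialize(self):
        if not self._work.waited:
            self._work.wait()  # wait inserted at first use
        return self._value


def toy_all_reduce(value, compiling=False):
    work = ToyWork()
    if compiling:
        # Compile/trace mode: call wait immediately so the trace never
        # sees the wrapper subclass (the suboptimal-but-simple choice).
        work.wait()
        return value, work
    return ToyAsyncTensor(value, work), work


# Eager: wait is deferred until first use.
t, w = toy_all_reduce(41)
assert not w.waited
out = t.materialize() + 1
assert w.waited and out == 42

# Compiled: wait happens immediately, forcing an early sync.
t2, w2 = toy_all_reduce(41, compiling=True)
assert w2.waited and t2 == 41
```

The trade-off is visible directly: the eager wrapper lets unrelated work run before the sync, while the compile-mode path pays the sync up front and leaves work-obj bookkeeping to the backend.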

cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

Attempt to make traceable allreduce work with torchdynamo,
bypassing problems with tracing the AsyncTensor subclass.

Accept a suboptimal solution for now, and optimize it later.
The long-term solution should be "torchdispatch tensor subclass support", which is WIP.

Option 1: Dynamo inlines Traceable Collectives, and handles the AsyncTensor subclass
- AsyncTensor needs special handling since it is an unsupported subclass
- Special-case Dynamo to perform the simple function of AsyncTensor on its behalf:
  insert a 'Wait' op either "right after" the op that returned the AsyncTensor (least optimal)
  or right before the first user of the AsyncTensor (optimal but more complicated)

  (a) add 'torch.distributed.traceable_collectives'
      to `_dynamo.config.skipfiles_inline_module_allowlist`
  (b) implement AsyncTensor special behavior

Option 2: Dynamo allows traceable collectives in the graph, doesn't inline
- We avoid Dynamo worrying about AsyncTensor at all. Just let it come up later,
  when doing dispatch tracing during AOTAutograd
  (a) maybe the default already allows this op in the graph!
  (b) still need AsyncTensor handling? Because we seem to be fake-tensor tracing the op,
      which fails when wrapping a FakeTensor in an AsyncTensor
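Option 1's "optimal but more complicated" placement can be sketched on a toy linear op list (real code would run as an FX graph pass; the tuple encoding here is purely illustrative):

```python
def insert_waits(ops):
    """ops: list of (out, op, ins) tuples, where collectives are named
    'all_reduce'. Returns a new op list with a ('<out>_w', 'wait', [out])
    node inserted right before the first op that reads a collective's output."""
    pending = {}  # collective output name -> still needs a wait
    result = []
    for out, op, ins in ops:
        for i in ins:
            if pending.pop(i, None):
                # First user of a collective output: wait goes directly before it.
                result.append((f"{i}_w", "wait", [i]))
        result.append((out, op, ins))
        if op == "all_reduce":
            pending[out] = True
    return result


g = [("a", "all_reduce", ["x"]),
     ("b", "mul", ["y", "y"]),   # does not use 'a': no wait yet
     ("c", "add", ["a", "b"])]   # first user of 'a': wait goes here
assert insert_waits(g) == [
    ("a", "all_reduce", ["x"]),
    ("b", "mul", ["y", "y"]),
    ("a_w", "wait", ["a"]),
    ("c", "add", ["a", "b"]),
]
```

The "least optimal" variant is the degenerate case where the wait is appended immediately after the collective, before any intervening ops can overlap with the communication.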

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Feb 8, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94440

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Merge-Blocking SEV

There is 1 active merge-blocking SEV. Please view it below:

If you must merge, use @pytorchbot merge -f.

✅ No Failures

As of commit b74309c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

wconstab added a commit that referenced this pull request Feb 9, 2023
ghstack-source-id: 285432f
Pull Request resolved: #94440
wconstab added a commit that referenced this pull request Feb 9, 2023
ghstack-source-id: 5fe4823
Pull Request resolved: #94440
@kumpera
Contributor

kumpera commented Apr 24, 2023

Maybe there's an alternative to having dynamo fully handle AsyncCollectiveTensor.

We could:

  1. Do wait scheduling in Inductor, on its FX graph.
  2. If a tensor is not used in the FX graph, escape it as an AsyncCollectiveTensor.
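The suggestion above amounts to partitioning collective outputs: results consumed inside the graph get a scheduled wait, while results that escape the graph stay wrapped (as AsyncCollectiveTensor) so the caller pays for the sync lazily. A minimal sketch, with illustrative names only:

```python
def plan_waits(collective_outs, used_in_graph):
    """Partition collective outputs into those to wait on inside the graph
    (the backend schedules the wait) and those to escape as async wrappers."""
    wait_inside = [t for t in collective_outs if t in used_in_graph]
    escape = [t for t in collective_outs if t not in used_in_graph]
    return wait_inside, escape


inside, escaped = plan_waits(["a", "b", "c"], used_in_graph={"a", "c"})
assert inside == ["a", "c"]   # waits scheduled in the graph
assert escaped == ["b"]       # leaves the graph as an async wrapper
```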

@wconstab
Contributor Author

Maybe there's an alternative to having dynamo fully handling AsyncCollectiveTensor.

I'm not sure which version of this PR you read, or if my update failed to push. My earlier approach was to handle AsyncCollectiveTensor in dynamo, but for the latest PR it is just skipped and dynamo traces the collective and the wait immediately.

wconstab added a commit that referenced this pull request Apr 24, 2023

ghstack-source-id: 7b94d6f
Pull Request resolved: #94440
wconstab added a commit that referenced this pull request Apr 26, 2023

ghstack-source-id: 6936d1f
Pull Request resolved: #94440
Contributor

@kumpera kumpera left a comment


I'm not entirely sure on wh

wconstab added a commit that referenced this pull request Apr 26, 2023

ghstack-source-id: 8ba2e58
Pull Request resolved: #94440
@wconstab wconstab added the `ciflow/trunk` label (Trigger trunk jobs on your pull request) Apr 27, 2023
@wconstab
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pytorchmergebot pushed a commit that referenced this pull request May 3, 2023
… build failures (#100424)

Summary:
This diff is reverting D45387167.
D45387167 (Basic dynamo support for traceable collectives (#94440), by wconstab) has been identified as causing the following test or build failures (internal)

If you believe this diff has been generated in error you may Commandeer and Abandon it.

Test Plan: NA

Reviewed By: s4ayub

Differential Revision: D45448312

Pull Request resolved: #100424
Approved by: https://github.com/rohan-varma, https://github.com/kumpera
@facebook-github-bot
Contributor

This pull request has been reverted by 287f74c. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert, and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).

@facebook-github-bot facebook-github-bot deleted the gh/wconstab/93/head branch June 8, 2023 19:19


5 participants