
Reland #2 "[C10] PG observability hooks. (#108815, #110907)"#111072

Closed
wconstab wants to merge 4 commits into gh/wconstab/205/base from gh/wconstab/205/head

Conversation

@wconstab
Contributor

@wconstab wconstab commented Oct 11, 2023

Stack from ghstack (oldest at bottom):

This reverts commit 314a502.

Changes since original PR:
Reland 1

  • rename torch.distributed.hooks to torch.distributed._hooks

Reland 2

  • make _hooks importable even if !distributed.is_available()
  • handle cuda driver exit intermittent failure caused by new cuda api usage in callback caller (see prev PR in stack)

(original PR #108815 desc copied below)

Expose a set of observability hooks into C10D so that users can
detect collective failures both faster and more easily.

The design is similar to NCCL desync debug in that it minimizes
overhead by doing most of the work off the main thread.

This PR introduces a new module, torch.distributed.hooks, that exposes the following set of methods:

register_collective_start_hook
register_collective_end_hook
register_process_group_hook
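As a sketch of the registration pattern these methods imply (the registry internals below are illustrative, not the actual C10D implementation, and the status dict fields are assumptions):

```python
# Minimal sketch of a hook registry like the one this module exposes.
# The real module feeds the hooks from a C++ event queue; here we just
# show registration and dispatch.
_start_hooks = []
_end_hooks = []

def register_collective_start_hook(hook):
    _start_hooks.append(hook)

def register_collective_end_hook(hook):
    _end_hooks.append(hook)

def _fire_start(status):
    # In the real module this runs on a background thread, not the
    # thread that issued the collective.
    for hook in _start_hooks:
        hook(status)

events = []
register_collective_start_hook(events.append)
_fire_start({"op": "allreduce", "rank": 0})
print(events)
```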

The process group hook exposes PG creation on the member ranks and is called inline from the
PG creation code. This is fine since it happens during initialization and only a limited number of times.

The collective start/end hooks are fired from a single background thread, which reads
events from a C++ queue and dispatches them to the registered callbacks.

Queue notification is, unusually, done using a pipe; this is needed so Python can abort the thread on shutdown
while keeping it a background thread, which is not possible with more conventional choices like a condvar.
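The pipe-based wakeup described above can be sketched in pure Python (this is an illustration of the pattern, not the actual C10D code, which keeps its queue on the C++ side): a daemon thread blocks reading the pipe, producers enqueue an event and write one byte to wake it, and closing the write end delivers EOF so the thread exits cleanly at shutdown.

```python
import os
import threading
from collections import deque

class EventDispatcher:
    def __init__(self, callback):
        self._queue = deque()
        self._lock = threading.Lock()
        self._rfd, self._wfd = os.pipe()
        self._callback = callback
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def emit(self, event):
        with self._lock:
            self._queue.append(event)
        os.write(self._wfd, b"x")  # one byte per event wakes the dispatcher

    def shutdown(self):
        os.close(self._wfd)        # EOF on the pipe aborts the thread
        self._thread.join()

    def _run(self):
        # os.read returns b"" at EOF, which ends the loop; a condvar wait
        # could not be interrupted this way from Python at shutdown.
        while os.read(self._rfd, 1):
            with self._lock:
                event = self._queue.popleft()
            self._callback(event)
        os.close(self._rfd)

seen = []
d = EventDispatcher(seen.append)
d.emit("allreduce_start")
d.emit("allreduce_end")
d.shutdown()
print(seen)
```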


[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Oct 11, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111072

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 74c5ace with merge base 4d29b40:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (c10d) release notes category label Oct 11, 2023
@wconstab wconstab added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 11, 2023
@wconstab wconstab force-pushed the gh/wconstab/205/head branch from de8fd74 to e3334df on October 11, 2023 19:16
wconstab added a commit that referenced this pull request Oct 11, 2023
ghstack-source-id: 07012bb
Pull Request resolved: #111072
@wconstab wconstab force-pushed the gh/wconstab/205/head branch from e3334df to 74c5ace on October 11, 2023 22:26
wconstab added a commit that referenced this pull request Oct 11, 2023

ghstack-source-id: 2c235b7
Pull Request resolved: #111072
Comment on lines +13 to +15
if not dist.is_available():
print("torch.distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
Contributor

This is a bit unconventional, but I see you are following an existing pattern, so I will submit a followup PR later (i.e. python -c "import test_hooks; print('Hello World')" should work, but it would not in your case).
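One conventional way to keep the test module importable, as the comment asks for, is a skip decorator instead of sys.exit(0) at import time (this is an assumption about the follow-up, not its actual content; DIST_AVAILABLE stands in for torch.distributed.is_available()):

```python
import unittest

DIST_AVAILABLE = False  # stand-in for torch.distributed.is_available()

@unittest.skipIf(not DIST_AVAILABLE, "torch.distributed not available")
class TestHooks(unittest.TestCase):
    def test_noop(self):
        self.assertTrue(True)

# Importing this module no longer terminates the interpreter, so
# `python -c "import test_hooks; print('Hello World')"` works; the
# tests are reported as skipped instead.
result = unittest.TextTestRunner(verbosity=0).run(
    unittest.defaultTestLoader.loadTestsFromTestCase(TestHooks)
)
print(result.testsRun, len(result.skipped))
```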

Comment on lines +152 to +153
except OSError:
pass
Contributor

Nit (imo it's better to print some warning to ease debugging later on)

Suggested change
except OSError:
pass
except OSError as e:
    warnings.warn(f"Failed to delete {self.file_name}: {e}")

Comment on lines +188 to +189
except OSError:
pass
Contributor

Same as above (i.e. print a warning). Also, those two classes overlap with each other a bit. Perhaps this could be refactored into some abstract test class?

super().setUp()
self._spawn_processes()

def tearDown(self):
Contributor

Nit (if you are doing type annotation above)

Suggested change
def tearDown(self):
def tearDown(self) -> None:


def coll_start(status):
starts.append(status)
print(f"col_start {len(starts)} rank{self.rank}")
Contributor

Wouldn't structured logging be better than print?
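The structured-logging alternative the reviewer suggests could look like the sketch below, using the standard logging module (the logger name and field layout are my assumptions, not code from this PR):

```python
import logging

logger = logging.getLogger("c10d.hooks")

def coll_start(status, rank, count):
    # Emit machine-parseable key=value fields instead of a bare print(),
    # so log aggregators can filter by rank or op.
    msg = f"coll_start count={count} rank={rank} status={status!r}"
    logger.info(msg)
    return msg

print(coll_start("allreduce", rank=0, count=1))
```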

Comment on lines +14 to +15
evt.timestamp =
std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
Contributor

Nit

Suggested change
evt.timestamp =
std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
using namespace std::chrono;
evt.timestamp = system_clock::to_time_t(system_clock::now());

namespace details {

// we start dropping events after this
const size_t MAX_QUEUE_SIZE = 512;
Contributor

Nit

Suggested change
const size_t MAX_QUEUE_SIZE = 512;
constexpr size_t MAX_QUEUE_SIZE = 512;


bool dequeue_c10d_event(EventInfo& evt) {
std::unique_lock<std::mutex> lock(event_queue_lock);
if (event_queue.size() == 0) {
Contributor

Nit

Suggested change
if (event_queue.size() == 0) {
if (event_queue.empty()) {

@wconstab
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Oct 13, 2023
C++-side callbacks allow advanced users to get
access to the collective firehose.

It's worth mentioning and discussing the dire environment in which those
callbacks are invoked: from either the main thread or the watchdog thread,
and with a PTD lock held.
Pull Request resolved: #110307
Approved by: https://github.com/fduwjj
ghstack dependencies: #111061, #111072
@facebook-github-bot facebook-github-bot deleted the gh/wconstab/205/head branch October 16, 2023 14:25
pytorchmergebot added a commit that referenced this pull request Oct 16, 2023
This reverts commit 359336e.

Reverted #110307 on behalf of https://github.com/wconstab because it sits on top of another PR, #111072, that needs to be reverted due to an internal release testing failure / multisect blame ([comment](#110307 (comment)))
wconstab added a commit to wconstab/pytorch that referenced this pull request Oct 16, 2023
…ytorch#2 "[C10] PG observability hooks. (pytorch#108815, pytorch#110907)" (pytorch#111072)" for test or build failures (pytorch#111393)

Summary:
This diff is reverting D50250526
D50250526: Reland pytorch#2 "[C10] PG observability hooks. (pytorch#108815, pytorch#110907)" (pytorch#111072) by wconstab has been identified to be causing the following test or build failures:

Tests affected:
- [cogwheel:cogwheel_ig_clips_tab_derived_feature_importance#test_ig_clips_tab_derived_feature_importance](https://www.internalfb.com/intern/test/844425021976403/)

Here's the Multisect link:
https://www.internalfb.com/multisect/3290230
Here are the tasks that are relevant to this breakage:

We're generating a revert to back out the changes in this diff, please note the backout may land if someone accepts it.

If you believe this diff has been generated in error you may Commandeer and Abandon it.


Test Plan: NA

Differential Revision: D50299914

Pulled By: wconstab
@facebook-github-bot
Contributor

@pytorchbot revert -m="Diff reverted internally" -c="ghfirst"

This Pull Request has been reverted by a revert inside Meta. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert, and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@wconstab your PR has been successfully reverted.


Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: distributed (c10d) release notes category Reverted
