Skip to content

Commit 5872a8c

Browse files
dvrogozhpytorchmergebot
authored andcommitted
Use task submitter TLS in gloo working threads (#142184)
Fixes: #86830 CC: @albanD Pull Request resolved: #142184 Approved by: https://github.com/albanD
1 parent 692b5e7 commit 5872a8c

3 files changed

Lines changed: 13 additions & 0 deletions

File tree

test/distributed/test_c10d_gloo.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,6 +1188,11 @@ def test_allgather_noncontiguous_input(self):
11881188
# Take a column of 2D tensor, such that memory is not dense
11891189
self._test_allgather_basics(lambda t: t.expand(2, 2).contiguous()[:, 0])
11901190

1191+
@requires_gloo()
1192+
def test_allgather_inference_mode(self):
1193+
with torch.inference_mode():
1194+
self._test_allgather_basics(lambda t: t.clone())
1195+
11911196
def _test_allgather_stress(self, inputs, fn):
11921197
store = c10d.FileStore(self.file_name, self.world_size)
11931198
pg = self._create_process_group_gloo(

torch/csrc/distributed/c10d/ProcessGroupGloo.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,7 @@ void ProcessGroupGloo::AsyncWork::execute(
426426
work->recordFunctionBeforeCallback_();
427427
}
428428
try {
429+
at::ThreadLocalStateGuard g(work->getTLS());
429430
work->run();
430431
} catch (...) {
431432
work->finishWorkGlooError(std::current_exception());

torch/csrc/distributed/c10d/ProcessGroupGloo.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
#include <torch/csrc/distributed/c10d/Types.hpp>
2323
#include <torch/csrc/distributed/c10d/Utils.hpp>
2424

25+
#include <ATen/ThreadLocalState.h>
26+
2527
namespace c10d {
2628

2729
constexpr const char* GLOO_BACKEND_NAME = "gloo";
@@ -73,6 +75,10 @@ class TORCH_API ProcessGroupGloo : public Backend {
7375
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
7476
uint64_t getSequencenumber() const override;
7577

78+
inline at::ThreadLocalState getTLS() const {
79+
return tls_;
80+
}
81+
7682
protected:
7783
friend class ProcessGroupGloo;
7884

@@ -87,6 +93,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
8793
c10::intrusive_ptr<at::ivalue::Future> future_;
8894
std::function<void()> recordFunctionBeforeCallback_;
8995
const uint64_t seq_;
96+
at::ThreadLocalState tls_;
9097
};
9198

9299
// Wrap c10d store as Gloo store

0 commit comments

Comments
 (0)