Skip to content

Commit 20143e5

Browse files
suo authored and facebook-github-bot committed
Revert D21245094: [resubmit] Enable global observers API
Test Plan: revert-hammer Differential Revision: D21245094 Original commit changeset: 595e41b18206 fbshipit-source-id: 90344b361857d76ce5db75438c949dad1f5f186b
1 parent d294c06 commit 20143e5

8 files changed

Lines changed: 53 additions & 58 deletions

File tree

aten/src/ATen/ThreadLocalState.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ ThreadLocalState::ThreadLocalState(bool keep_grad_mode)
1515
grad_mode_enabled_ = GradMode::is_enabled();
1616
}
1717
#endif
18+
record_function_enabled_ = _tls_is_record_function_enabled();
1819
}
1920

2021
/* static */
@@ -26,9 +27,19 @@ void ThreadLocalState::setThreadLocalState(
2627
}
2728
#endif
2829

30+
c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);
31+
2932
ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
3033

31-
c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);
34+
_tls_set_record_function_enabled(state.record_function_enabled_);
35+
}
36+
37+
thread_local bool is_record_function_enabled_ = true;
38+
bool _tls_is_record_function_enabled() {
39+
return is_record_function_enabled_;
40+
}
41+
void _tls_set_record_function_enabled(bool is_enabled) {
42+
is_record_function_enabled_ = is_enabled;
3243
}
3344

3445
} // namespace at

aten/src/ATen/ThreadLocalState.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ class TORCH_API ThreadLocalState {
3434
bool grad_mode_enabled_;
3535
#endif
3636

37+
// Whether RecordFunctions need to be disabled;
38+
// used in core PyTorch to avoid infitite recursion
39+
// in observers framework
40+
bool record_function_enabled_;
41+
3742
friend class ThreadLocalStateGuard;
3843
};
3944

@@ -55,4 +60,8 @@ class TORCH_API ThreadLocalStateGuard {
5560
const ThreadLocalState prev_state_;
5661
};
5762

63+
// Internal, turns on/off record function observers
64+
TORCH_API bool _tls_is_record_function_enabled();
65+
TORCH_API void _tls_set_record_function_enabled(bool);
66+
5867
} // namespace at

c10/util/StringUtil.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,15 @@ struct _str_wrapper<const char*> final {
7070

7171
// For c10::str() with an empty argument list (which is common in our assert macros),
7272
// we don't want to pay the binary size for constructing and destructing a stringstream
73-
// or even constructing a string. Let's just return a reference to an empty string.
73+
// or even constructing a string. Let's just return a reference to a global empty string.
74+
#pragma GCC diagnostic push
75+
#pragma GCC diagnostic ignored "-Wpragmas"
76+
#pragma GCC diagnostic ignored "-Wglobal-constructors"
77+
const std::string empty_string_literal;
78+
#pragma GCC diagnostic pop
7479
template<>
7580
struct _str_wrapper<> final {
7681
static const std::string& call() {
77-
thread_local const std::string empty_string_literal;
7882
return empty_string_literal;
7983
}
8084
};

docs/source/notes/large_scale_deployments.rst

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,6 @@ Here's an example:
6161
/* needs_inputs */ true,
6262
/* sampling_prob */ 0.01
6363
);
64-
// Note, to enable observers in the model calling thread,
65-
// call enableObservers() in the thread before running a model
6664
}
6765
6866
bool onFunctionEnter(const RecordFunction& fn) {

torch/csrc/autograd/profiler.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,6 @@ std::unordered_map<uint16_t, std::shared_ptr<RangeEventList>>
3232
thread_local std::shared_ptr<RangeEventList> event_list;
3333
thread_local uint16_t thread_id;
3434

35-
// use RecordFunctionGuard to keep track of observers,
36-
// enable/disableProfiler are tied to the code range
37-
thread_local std::vector<std::shared_ptr<RecordFunctionGuard>> g_;
38-
3935
} // namespace
4036

4137
void registerCUDAMethods(CUDAStubs* stubs) {
@@ -201,7 +197,7 @@ void enableProfiler(ProfilerConfig config) {
201197
/* sampling_prob */ 1.0,
202198
/* scopes */ {RecordScope::FUNCTION, RecordScope::USER_SCOPE});
203199
state = new_state;
204-
g_.emplace_back();
200+
c10::impl::tls_set_dispatch_key_included(c10::DispatchKey::Profiler, true);
205201

206202
if(state == ProfilerState::CUDA) {
207203
// event recording appears to have some startup overhead, so we need to
@@ -232,8 +228,7 @@ thread_event_lists disableProfiler() {
232228

233229
popCallback();
234230
state = ProfilerState::Disabled;
235-
TORCH_INTERNAL_ASSERT(!g_.empty());
236-
g_.pop_back();
231+
c10::impl::tls_set_dispatch_key_included(c10::DispatchKey::Profiler, false);
237232

238233
if (old_state == ProfilerState::NVTX) {
239234
return thread_event_lists();

torch/csrc/autograd/record_function.cpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -212,21 +212,13 @@ void popCallback() {
212212
manager().popCallback();
213213
}
214214

215-
bool observersEnabled() {
216-
return c10::impl::tls_is_dispatch_key_included(c10::DispatchKey::Profiler);
217-
}
218-
219-
void enableObservers(bool enable) {
220-
c10::impl::tls_set_dispatch_key_included(c10::DispatchKey::Profiler, enable);
221-
}
222-
223215
void _runBeforeCallbacks(RecordFunction* rf, const std::string& funcName) {
224216
TORCH_INTERNAL_ASSERT(rf != nullptr);
225217
rf->_before(funcName);
226218
}
227219

228220
RecordFunction::RecordFunction(RecordScope scope) : scope_(scope) {
229-
if (manager().hasCallbacks() && observersEnabled()) {
221+
if (manager().hasCallbacks() && at::_tls_is_record_function_enabled()) {
230222
active_ = true;
231223
}
232224
}

torch/csrc/autograd/record_function.h

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,25 @@ struct TORCH_API RecordFunction {
220220
uint64_t callbacks_version_ = 0;
221221
};
222222

223+
class TORCH_API RecordFunctionGuard {
224+
public:
225+
explicit RecordFunctionGuard(bool is_enabled)
226+
: prev_value_(at::_tls_is_record_function_enabled()) {
227+
at::_tls_set_record_function_enabled(is_enabled);
228+
}
229+
virtual ~RecordFunctionGuard() {
230+
at::_tls_set_record_function_enabled(prev_value_);
231+
}
232+
private:
233+
bool prev_value_ = false;
234+
};
235+
236+
class TORCH_API DisableRecordFunctionGuard : public RecordFunctionGuard {
237+
public:
238+
DisableRecordFunctionGuard() : RecordFunctionGuard(false) {}
239+
virtual ~DisableRecordFunctionGuard() {}
240+
};
241+
223242
// Returns whether there're callbacks registered with pushCallback
224243
TORCH_API bool hasCallbacks();
225244

@@ -293,32 +312,5 @@ TORCH_API void pushCallback(
293312
*/
294313
TORCH_API void popCallback();
295314

296-
// Enable observers thread locally
297-
TORCH_API void enableObservers(bool enable = true);
298-
299-
// Returns whether observers are enabled (thread locally)
300-
TORCH_API bool observersEnabled();
301-
302-
class TORCH_API RecordFunctionGuard {
303-
public:
304-
explicit RecordFunctionGuard(bool is_enabled = true)
305-
: prev_value_(observersEnabled()) {
306-
enableObservers(is_enabled);
307-
}
308-
309-
virtual ~RecordFunctionGuard() {
310-
enableObservers(prev_value_);
311-
}
312-
313-
private:
314-
bool prev_value_ = false;
315-
};
316-
317-
class TORCH_API DisableRecordFunctionGuard : public RecordFunctionGuard {
318-
public:
319-
DisableRecordFunctionGuard() : RecordFunctionGuard(false) {}
320-
virtual ~DisableRecordFunctionGuard() {}
321-
};
322-
323315
} // namespace profiler
324316
}} // namespace torch::autograd

torch/csrc/jit/mobile/module.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
#include <torch/csrc/jit/mobile/observer.h>
66
#endif
77

8-
#include <torch/csrc/autograd/record_function.h>
9-
108
namespace torch {
119
namespace jit {
1210
std::ostream& operator<<(std::ostream& out, Instruction inst);
@@ -48,14 +46,10 @@ c10::IValue Module::run_method(const std::string& method_name, Stack stack) {
4846
at::DebugInfoGuard guard(at::DebugInfoKind::MOBILE_RUNTIME_INFO, debug_info);
4947
#endif
5048

51-
c10::IValue result;
52-
{
53-
torch::autograd::profiler::RecordFunctionGuard g;
54-
auto m = find_method(method_name);
55-
stack.insert(stack.begin(), object_);
56-
m->run(stack);
57-
result = stack.front();
58-
}
49+
auto m = find_method(method_name);
50+
stack.insert(stack.begin(), object_);
51+
m->run(stack);
52+
c10::IValue result = stack.front();
5953

6054
#if defined(PYTORCH_MOBILE_OBSERVER)
6155
if (observer) {

0 commit comments

Comments
 (0)