@@ -220,6 +220,25 @@ struct TORCH_API RecordFunction {
220220 uint64_t callbacks_version_ = 0 ;
221221};
222222
223+ class TORCH_API RecordFunctionGuard {
224+ public:
225+ explicit RecordFunctionGuard (bool is_enabled)
226+ : prev_value_(at::_tls_is_record_function_enabled()) {
227+ at::_tls_set_record_function_enabled (is_enabled);
228+ }
229+ virtual ~RecordFunctionGuard () {
230+ at::_tls_set_record_function_enabled (prev_value_);
231+ }
232+ private:
233+ bool prev_value_ = false ;
234+ };
235+
236+ class TORCH_API DisableRecordFunctionGuard : public RecordFunctionGuard {
237+ public:
238+ DisableRecordFunctionGuard () : RecordFunctionGuard(false ) {}
239+ virtual ~DisableRecordFunctionGuard () {}
240+ };
241+
223242// Returns whether there're callbacks registered with pushCallback
224243TORCH_API bool hasCallbacks ();
225244
@@ -293,32 +312,5 @@ TORCH_API void pushCallback(
293312 */
294313TORCH_API void popCallback ();
295314
296- // Enable observers thread locally
297- TORCH_API void enableObservers (bool enable = true );
298-
299- // Returns whether observers are enabled (thread locally)
300- TORCH_API bool observersEnabled ();
301-
302- class TORCH_API RecordFunctionGuard {
303- public:
304- explicit RecordFunctionGuard (bool is_enabled = true )
305- : prev_value_(observersEnabled()) {
306- enableObservers (is_enabled);
307- }
308-
309- virtual ~RecordFunctionGuard () {
310- enableObservers (prev_value_);
311- }
312-
313- private:
314- bool prev_value_ = false ;
315- };
316-
317- class TORCH_API DisableRecordFunctionGuard : public RecordFunctionGuard {
318- public:
319- DisableRecordFunctionGuard () : RecordFunctionGuard(false ) {}
320- virtual ~DisableRecordFunctionGuard () {}
321- };
322-
323315} // namespace profiler
324316}} // namespace torch::autograd
0 commit comments