Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# We export some functions and classes for test_jit.py directly from libtorch.so,
# it's not important to have BC for them
('_TorchScriptTesting.*', datetime.date(9999, 1, 1)),
('profiler::_call_end_callbacks_on_jit_fut*', datetime.date(9999, 1, 1)),
('aten::append*', datetime.date(2020, 4, 15)),
('aten::real*', datetime.date(2020, 4, 15)),
('aten::imag*', datetime.date(2020, 4, 15)),
Expand Down
8 changes: 7 additions & 1 deletion torch/autograd/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,11 @@ def _call_end_callbacks_on_future(self, fut):
Arguments:
fut: (torch._C.Future): future for which to schedule
callback for.

Returns:
A future that completes with the value of the passed-in future when
the profiling callbacks have run.

"""
# Throw if we have already attached a callback onto the future.
if not self.run_callbacks_on_exit:
Expand All @@ -399,7 +404,8 @@ def _call_end_callbacks_on_future(self, fut):
# We are scheduling to run this RecordFunction's end callbacks when the
# passed in future completes, so don't run end callbacks on exit.
self.run_callbacks_on_exit = False
torch.ops.profiler._call_end_callbacks_on_jit_fut(self.handle, fut)
profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut(self.handle, fut)
return profiled_future


class emit_nvtx(object):
Expand Down
28 changes: 18 additions & 10 deletions torch/csrc/autograd/record_function_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,15 @@ void record_function_exit(const at::Tensor& handle) {
rec._end();
}

void _call_end_callbacks_on_fut(
c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut(
const at::Tensor& handle,
const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
// Save and pass thread local state into the callback
at::ThreadLocalState tls_state;
// Add a callback onto the future to mark run RecordFunction's end callbacks
// when the future is completed.
fut->addCallback(
// Copy handle and tls_state by value to persist after the python
// context manager is exited.
[handle, tls_state = std::move(tls_state)]() {
// Profiling callback that ends the associated record_function
// and returns the value of the passed in future.
std::function<c10::IValue(void)> futureProfilingFunc =
[fut, handle, tls_state = std::move(tls_state)]() {
TORCH_INTERNAL_ASSERT(
handle.defined(),
"Undefined RecordFunction handle. This can happen if the handle is "
Expand All @@ -68,7 +66,15 @@ void _call_end_callbacks_on_fut(
at::ThreadLocalStateGuard g(tls_state);
auto& rec = getRecordFunctionFromTensor(handle);
rec._end();
});
// Note: this future is returned to the user to ensure that a call to wait()
// ensures that profiling callbacks have run. To ensure that this is
// transparent, we must make this future propagate the value of the RPC
// future.
return fut->constValue();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a comment to explain why we need to forward the return value here. IIUC, it is because _invoke_rpc actually returns this future?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we get the future which has the actual result from _invoke_rpc. Then, if profiling we run this, and we want to ensure that the correct value of the future from _invoke_rpc is propagated. Will add this to the comment

};
// Define a future that completes after the profiling callbacks are run.
auto profiledFut = fut->then(futureProfilingFunc, fut->type());
return profiledFut;
}

// Internal only, do not use directly, use Python's record_function()
Expand All @@ -84,12 +90,14 @@ c10::AliasAnalysisKind aliasAnalysisFromSchema() {

jit::RegisterOperators reg_fut_ops({
jit::Operator(
"profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> ()",
"profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> Future(t)",
[](jit::Stack& stack) {
// Pop inputs, which should be a future and a tensor
auto fut = jit::pop(stack).toFuture();
auto tensor = jit::pop(stack).toTensor();
_call_end_callbacks_on_fut(tensor, fut);
auto profiledFut = _call_end_callbacks_on_fut(tensor, fut);
// return future that completes when profiling callbacks have run.
jit::push(stack, std::move(profiledFut));
return 0;
},
aliasAnalysisFromSchema()),
Expand Down
23 changes: 23 additions & 0 deletions torch/csrc/distributed/rpc/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,29 @@ PyObject* rpc_init(PyObject* /* unused */) {
on the remote node. This is for internal use cases such as profiling
only.
)")
.def(
"_get_profiling_future",
[](const PyRRef& self) {
return std::make_shared<jit::PythonFutureWrapper>(
self.getProfilingFuture());
},
py::call_guard<py::gil_scoped_acquire>(),
R"(
Returns future that completes when the profiling event corresponding
to the creation of this RRef on the remote node has been recorded.
)")
.def(
"_set_profiling_future",
[](PyRRef& self,
const std::shared_ptr<jit::PythonFutureWrapper>&
wrappedFuture) {
self.setProfilingFuture(wrappedFuture->fut);
},
py::call_guard<py::gil_scoped_acquire>(),
R"(
Set future that is completed when the profiling event corresponding
to the creation of this RRef on the remote node has been recorded.
)")
// not releasing GIL to avoid context switch
.def("__str__", &PyRRef::str);

Expand Down
12 changes: 11 additions & 1 deletion torch/csrc/distributed/rpc/py_rref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ TypePtr tryInferTypeWithTypeHint(

/////////////////////////// PyRRef //////////////////////////////////

PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref) : rref_(std::move(rref)) {
PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref)
: rref_(std::move(rref)), profilingFuture_(c10::nullopt) {
TORCH_CHECK(rref_, "PyRRef must not wrap nullptr");
}

Expand All @@ -122,6 +123,15 @@ c10::intrusive_ptr<JitFuture> PyRRef::getFuture() const {
rref_->getOwnerCreationFuture(), false /* hasValue */);
}

c10::intrusive_ptr<JitFuture> PyRRef::getProfilingFuture() const {
TORCH_INTERNAL_ASSERT(profilingFuture_, "Profiling future has not been set!");
return *profilingFuture_;
}

void PyRRef::setProfilingFuture(c10::intrusive_ptr<JitFuture> profilingFuture) {
profilingFuture_ = std::move(profilingFuture);
}

bool PyRRef::isOwner() const {
return rref_->isOwner();
}
Expand Down
8 changes: 8 additions & 0 deletions torch/csrc/distributed/rpc/py_rref.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,21 @@ class PyRRef {
// This is only used to get the future corresponding to the rref for profiling
// use cases.
c10::intrusive_ptr<JitFuture> getFuture() const;
// Keeps track of the future responsible for profiling owner creation
// acknowledgement
c10::intrusive_ptr<JitFuture> getProfilingFuture() const;
// Sets the future responsible for profiling owner creation acknowledgement.
// This future is set from python to be a future that returns when profiling
// callbacks have been run.
void setProfilingFuture(c10::intrusive_ptr<JitFuture> profilingFuture);

// create a proxy on this RRef, which can be used to launch RPC on the owner
// of this RRef to run functions on the object referenced by this RRef.
py::object createRRefProxy(const RRefProxyType& mode) const;

private:
c10::intrusive_ptr<RRef> rref_;
c10::optional<c10::intrusive_ptr<JitFuture>> profilingFuture_;
};

} // namespace rpc
Expand Down
9 changes: 7 additions & 2 deletions torch/distributed/rpc/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,8 @@ def remote(to, func, args=None, kwargs=None):
if should_profile:
assert torch.autograd._profiler_enabled()
assert rf is not None
rf._call_end_callbacks_on_future(rref._get_future())
fut = rf._call_end_callbacks_on_future(rref._get_future())
rref._set_profiling_future(fut)

return rref

Expand Down Expand Up @@ -527,7 +528,11 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RP
assert torch.autograd._profiler_enabled()
assert rf is not None
# Schedule profiling callbacks to run when the future completes.
rf._call_end_callbacks_on_future(fut)
# This returns a future that is completed when the original future
# completes and the profiling callbacks have been completed as well,
# to guarantee that fut.wait() completes the profiling. This new
# future will contain the same value as the original future.
fut = rf._call_end_callbacks_on_future(fut)
return fut


Expand Down
4 changes: 0 additions & 4 deletions torch/testing/_internal/distributed/rpc/jit/rpc_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Dict, Tuple
import unittest

import torch
import time
Expand Down Expand Up @@ -994,7 +993,6 @@ def callback(fut):
with self.assertRaisesRegex(RuntimeError, "Another expected error"):
future.wait()

@unittest.skip("RPC profiling tests are flaky, see https://github.com/pytorch/pytorch/issues/37557")
@dist_init
def test_call_rpc_with_profiling(self):
# Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit
Expand All @@ -1017,7 +1015,6 @@ def test_call_rpc_with_profiling(self):
function_event = get_function_event(events, prof_key)
self.assertTrue(torch.jit._qualified_name(one_arg) in function_event.name)

@unittest.skip("RPC profiling tests are flaky, see https://github.com/pytorch/pytorch/issues/37557")
def test_record_function_jit_end_callbacks_with_fork(self):
# Ensures that we can call rf._call_end_callbacks_on_future on a jit
# future in python eager mode with torch.jit.fork
Expand All @@ -1035,7 +1032,6 @@ def test_record_function_jit_end_callbacks_with_fork(self):
# profiling event cpu time
self.assertGreaterEqual(sleep_event.cpu_time * 1e-6, sleep_interval)

@unittest.skip("RPC profiling tests are flaky, see https://github.com/pytorch/pytorch/issues/37557")
def test_call_fork_in_jit_with_profiling(self):
# Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit
# future from within a script function with torch.jit.fork
Expand Down
32 changes: 11 additions & 21 deletions torch/testing/_internal/distributed/rpc/rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,14 +844,11 @@ def _profiler_test_with_rpc(self, rpc_exec_mode, func, args, use_record_function
self.assertTrue(rpc_exec_mode == RPCExecMode.REMOTE)
rref = rpc.remote(worker_name(dst), func, args=args)
rref.to_here()
# We need to wait for the instance to be created on
# the owner, and get back a positive confirmation.
# Calling to_here does not ensure that we have finished
# processing the Owner's confirmation of this RRef. To do
# this, we wait until the current RRef context doesn't have
# any pending users, which indicates that the confirmation
# was processed on this worker.
wait_until_pending_users_flushed()
# To avoid flakiness, wait for the RRef to be profiled. This
# means that we received the acknowledgement of successful
# creation on the owner and ran the callbacks responsible
# for recording the profiling event.
rref._get_profiling_future().wait()
if use_record_function:
record_function.__exit__()

Expand Down Expand Up @@ -882,14 +879,12 @@ def _profiler_test_with_rpc(self, rpc_exec_mode, func, args, use_record_function
rpc_event_idx = next(i for i, event in enumerate(events) if rpc_exec_mode.value in event.name)
self.assertLess(foo_event_ix, rpc_event_idx)

@unittest.skip("RPC profiling tests are flaky, see https://github.com/pytorch/pytorch/issues/37557")
@dist_init
def test_profiler_with_sync_rpc_udf(self):
self._profiler_test_with_rpc(RPCExecMode.SYNC, my_sleep_func, args=(1,))
self._profiler_test_with_rpc(RPCExecMode.SYNC, my_sleep_func, args=(1,),
use_record_function=True)

@unittest.skip("RPC profiling tests are flaky, see https://github.com/pytorch/pytorch/issues/37557")
@dist_init
def test_profiler_with_sync_rpc_builtin(self):
self._profiler_test_with_rpc(
Expand All @@ -900,14 +895,12 @@ def test_profiler_with_sync_rpc_builtin(self):
use_record_function=True
)

@unittest.skip("RPC profiling tests are flaky, see https://github.com/pytorch/pytorch/issues/37557")
@dist_init
def test_profiler_with_async_rpc_udf(self):
self._profiler_test_with_rpc(RPCExecMode.ASYNC, my_sleep_func, args=(1,))
self._profiler_test_with_rpc(RPCExecMode.ASYNC, my_sleep_func, args=(1,),
use_record_function=True)

@unittest.skip("RPC profiling tests are flaky, see https://github.com/pytorch/pytorch/issues/37557")
@dist_init
def test_profiler_with_async_rpc_builtin(self):
self._profiler_test_with_rpc(
Expand All @@ -918,14 +911,12 @@ def test_profiler_with_async_rpc_builtin(self):
use_record_function=True
)

@unittest.skip("RPC profiling tests are flaky, see https://github.com/pytorch/pytorch/issues/37557")
@dist_init
def test_profiler_with_remote_udf(self):
self._profiler_test_with_rpc(RPCExecMode.REMOTE, my_sleep_func, args=(1,))
self._profiler_test_with_rpc(RPCExecMode.REMOTE, my_sleep_func, args=(1,),
use_record_function=True)

@unittest.skip("RPC profiling tests are flaky, see https://github.com/pytorch/pytorch/issues/37557")
@dist_init
def test_profiler_with_remote_builtin(self):
self._profiler_test_with_rpc(
Expand All @@ -936,7 +927,6 @@ def test_profiler_with_remote_builtin(self):
use_record_function=True
)

@unittest.skip("RPC profiling tests are flaky, see https://github.com/pytorch/pytorch/issues/37557")
@dist_init
def test_profiler_with_script_async_rpc(self):
self._profiler_test_with_rpc(
Expand All @@ -949,7 +939,6 @@ def test_profiler_with_script_async_rpc(self):
use_record_function=True,
)

@unittest.skip("RPC profiling tests are flaky, see https://github.com/pytorch/pytorch/issues/37557")
@dist_init
def test_profiler_with_script_sync_rpc(self):
self._profiler_test_with_rpc(
Expand All @@ -962,7 +951,6 @@ def test_profiler_with_script_sync_rpc(self):
use_record_function=True,
)

@unittest.skip("RPC profiling tests are flaky, see https://github.com/pytorch/pytorch/issues/37557")
@dist_init
def test_profiler_with_script_remote_rpc(self):
self._profiler_test_with_rpc(
Expand All @@ -975,7 +963,6 @@ def test_profiler_with_script_remote_rpc(self):
use_record_function=True,
)

@unittest.skip("RPC profiling tests are flaky, see https://github.com/pytorch/pytorch/issues/37557")
@dist_init
def test_async_record_function_double_end_callbacks(self):
num_sleep_seconds = 1
Expand All @@ -993,7 +980,6 @@ def test_async_record_function_double_end_callbacks(self):
rf._call_end_callbacks_on_future(fut)
fut.wait()

@unittest.skip("RPC profiling tests are flaky, see https://github.com/pytorch/pytorch/issues/37557")
@dist_init
def test_async_record_function_cbs_jit_call(self):
if self.rank == 1:
Expand All @@ -1009,8 +995,12 @@ def test_async_record_function_cbs_jit_call(self):
worker_name(0), my_script_func, args=(torch.tensor(1),)
)
# Intentionally calling record_function internals
torch.ops.profiler._call_end_callbacks_on_jit_fut(rf.handle, fut)
fut.wait()
fut = torch.ops.profiler._call_end_callbacks_on_jit_fut(rf.handle, fut)
result = fut.wait()
# Validate that the profiling future returns the same value as the RPC
# future.
expected = torch.add(torch.tensor(1), torch.tensor(1))
self.assertEqual(result, expected)
events = pf.function_events
rpc_event = get_function_event(
events, torch.jit._qualified_name(my_script_func)
Expand Down