[Distributed] Make xm.all_gather a single graph in Dynamo #4922
alanwaketan merged 10 commits into master
Conversation
```python
g_xrt_world_size = None


def xrt_world_size(defval=1):
```
@wconstab This is the Python function that I want to use in `allow_in_graph`.
Hmm, if you are going to manually cache the value of this anyway, then I think just using `allow_in_graph` without the caching is the same thing.
The issue with `allow_in_graph` is that if you expect the value to be updated on later iterations, `allow_in_graph` will prevent that from working. But if you expect the value to be a constant for the whole execution, then `allow_in_graph` will capture the value during compile and reuse it later (i.e., cache it).
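For reference, a minimal sketch of the manual-caching pattern under discussion (`_query_world_size` is a hypothetical stand-in for the real runtime query; only the `g_xrt_world_size` / `xrt_world_size` names come from the diff snippets in this review):

```python
# Minimal sketch: compute the value once, stash it in a module-level global,
# and return the cached plain Python int thereafter, so Dynamo can treat it
# as a constant instead of tracing the runtime call.
g_xrt_world_size = None


def xrt_world_size(defval=1):
  global g_xrt_world_size
  if g_xrt_world_size is None:
    g_xrt_world_size = _query_world_size(defval)  # hypothetical runtime query
  return g_xrt_world_size
```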
I tried to use `allow_in_graph`. However, it looks like the function I pass into `allow_in_graph` needs to return a tensor type? If the function returns a bool or an int, is there a workaround?
Here is how I use `allow_in_graph`:
```
ptxla@t1v-n-307ffe96-w-0:/workspaces/work/pytorch/xla$ git diff
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 6ff4a5a5..a07ff472 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -6,6 +6,7 @@ import time
 from typing import List, Optional
 import torch
 import torch.distributed._functional_collectives
+from torch._dynamo import allow_in_graph
 import torch.nn.functional as F
 import torch_xla
 from torch_xla.experimental import pjrt
@@ -1088,3 +1089,6 @@ def optimization_barrier_(tensors):
     tensors (List[torch.Tensor]): List of `torch.Tensor` to add barrier to.
   """
   torch_xla._XLAC._xla_optimization_barrier_(tensors)
+
+
+allow_in_graph(xrt_world_size)
```
And here is the error:
```
root@t1v-n-307ffe96-w-0:/workspaces/work/pytorch/xla# PJRT_DEVICE=TPU python test/test_mp_all_gather.py
concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
r = call_item.fn(*call_item.args, **call_item.kwargs)
File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 198, in _process_chunk
return [fn(*args) for args in chunk]
File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 198, in <listcomp>
return [fn(*args) for args in chunk]
File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 92, in wrapper
return fn(*args, **kwargs)
File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 245, in _run_thread_per_device
replica_results = list(
File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
yield fs.pop().result()
File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 444, in result
return self.__get_result()
File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
raise self._exception
File "/usr/local/lib/python3.8/concurrent/futures/thread.py", line 57, in run
result = self.fn(*self.args, **self.kwargs)
File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 238, in _thread_fn
return fn()
File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 341, in __call__
self.fn(global_ordinal(), *self.args, **self.kwargs)
File "/workspaces/work/pytorch/xla/test/test_mp_all_gather.py", line 32, in _mp_fn
result = compiled_all_gather(ordinal_tensor, dim=0)
File "/workspaces/work/pytorch/torch/_dynamo/eval_frame.py", line 252, in _fn
return fn(*args, **kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/eval_frame.py", line 405, in catch_errors
return callback(frame, cache_size, hooks, frame_state)
File "/workspaces/work/pytorch/torch/_dynamo/convert_frame.py", line 122, in _fn
return fn(*args, **kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/convert_frame.py", line 331, in _convert_frame_assert
return _compile(
File "/workspaces/work/pytorch/torch/_dynamo/utils.py", line 169, in time_wrapper
r = func(*args, **kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/convert_frame.py", line 401, in _compile
out_code = transform_code_object(code, transform)
File "/workspaces/work/pytorch/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
transformations(instructions, code_options)
File "/workspaces/work/pytorch/torch/_dynamo/convert_frame.py", line 386, in transform
tracer.run()
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1972, in run
super().run()
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 670, in run
and self.step()
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 630, in step
getattr(self, inst.opname)(inst)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 355, in wrapper
return inner_fn(self, inst)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1138, in CALL_FUNCTION_KW
self.call_function(fn, args, kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 521, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/workspaces/work/pytorch/torch/_dynamo/variables/functions.py", line 269, in call_function
return super().call_function(tx, args, kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/variables/functions.py", line 102, in call_function
return tx.inline_user_function_return(
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 557, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2077, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2155, in inline_call_
tracer.run()
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 670, in run
and self.step()
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 630, in step
getattr(self, inst.opname)(inst)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 355, in wrapper
return inner_fn(self, inst)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1138, in CALL_FUNCTION_KW
self.call_function(fn, args, kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 521, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/workspaces/work/pytorch/torch/_dynamo/variables/functions.py", line 269, in call_function
return super().call_function(tx, args, kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/variables/functions.py", line 102, in call_function
return tx.inline_user_function_return(
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 557, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2077, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 2155, in inline_call_
tracer.run()
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 670, in run
and self.step()
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 630, in step
getattr(self, inst.opname)(inst)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 355, in wrapper
return inner_fn(self, inst)
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1086, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/workspaces/work/pytorch/torch/_dynamo/symbolic_convert.py", line 521, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/workspaces/work/pytorch/torch/_dynamo/variables/torch.py", line 603, in call_function
tensor_variable = wrap_fx_proxy(
File "/workspaces/work/pytorch/torch/_dynamo/variables/builder.py", line 923, in wrap_fx_proxy
return wrap_fx_proxy_cls(
File "/workspaces/work/pytorch/torch/_dynamo/variables/builder.py", line 1098, in wrap_fx_proxy_cls
unimplemented(
File "/workspaces/work/pytorch/torch/_dynamo/exc.py", line 107, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function <function xrt_world_size at 0x7fb184e94ca0>
from user code:
File "/workspaces/work/pytorch/xla/torch_xla/core/xla_model.py", line 550, in all_gather
return _all_gather_using_all_reduce(
File "/workspaces/work/pytorch/xla/torch_xla/core/xla_model.py", line 511, in _all_gather_using_all_reduce
left, right = ordinal, xrt_world_size() - 1 - ordinal
Set torch._dynamo.config.verbose=True or TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "test/test_mp_all_gather.py", line 66, in <module>
xmp.spawn(_mp_fn, args=())
File "/workspaces/work/pytorch/xla/torch_xla/distributed/xla_multiprocessing.py", line 367, in spawn
return pjrt.spawn(fn, nprocs, start_method, args)
File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 365, in spawn
_run_multiprocess(spawn_fn, start_method=start_method)
File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 92, in wrapper
return fn(*args, **kwargs)
File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 322, in _run_multiprocess
replica_results = list(
File "/workspaces/work/pytorch/xla/torch_xla/experimental/pjrt.py", line 323, in <genexpr>
itertools.chain.from_iterable(
File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 484, in _chain_from_iterable_of_lists
for element in iterable:
File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
yield fs.pop().result()
File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 444, in result
return self.__get_result()
File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
raise self._exception
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function <function xrt_world_size at 0x7fb184e94ca0>
from user code:
File "/workspaces/work/pytorch/xla/torch_xla/core/xla_model.py", line 550, in all_gather
return _all_gather_using_all_reduce(
File "/workspaces/work/pytorch/xla/torch_xla/core/xla_model.py", line 511, in _all_gather_using_all_reduce
left, right = ordinal, xrt_world_size() - 1 - ordinal
Set torch._dynamo.config.verbose=True or TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
root@t1v-n-307ffe96-w-0:/workspaces/work/pytorch/xla#
```
```python
  return g_xrt_world_size


g_ordinal = None


def get_ordinal(defval=0):
```
@wconstab This is the Python function that I want to use in `allow_in_graph`.
| """ | ||
| if pjrt.using_pjrt(): | ||
| return pjrt.global_ordinal() | ||
| global g_ordinal |
I think this will break PJRT + v3 cases; the implementation we had checks the devices in:

```cpp
m.def("_xla_get_default_device_ordinal", []() {
  std::string device_str = GetCurrentThreadDevice();
  torch::lazy::BackendDevice device =
      bridge::AtenDeviceToXlaDevice(device_str);
  return device.ordinal();
});
```
Hmm, this is confusing. That call is in the C++ layer, so `allow_in_graph` won't work here.
But we could work around it by caching a map...
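For illustration, a rough sketch of the map-based workaround floated here (purely hypothetical; `_device_str` and `_query_ordinal` are stand-ins, and the PR did not end up taking this approach):

```python
# Hypothetical sketch: cache one ordinal per device string, so a process
# that drives two devices (as on TPU v3) keeps a correct value for each.
g_ordinal_by_device = {}


def get_ordinal(defval=0):
  device = _device_str()  # hypothetical stand-in for GetCurrentThreadDevice()
  if device not in g_ordinal_by_device:
    g_ordinal_by_device[device] = _query_ordinal(defval)  # hypothetical query
  return g_ordinal_by_device[device]
```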
I am not sure, actually; effectively this function won't return a constant in the v3 cases because there are two devices per process. This is a bit tricky.
I wonder if we can bypass the v3 cases for now. What's going to happen if you add a condition here to skip this cached value if we are on v3 + PJRT?
It will introduce graph breaks in Dynamo.
Should this use thread-local storage instead?
That's cool. I was not aware Python has this feature. Let me work on it.
Dynamo doesn't seem to compile in the same thread as the user code, so `threading.local` doesn't work here.
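For context, a minimal sketch of the `threading.local` idea that was tried (illustrative only; `_query_ordinal` is a hypothetical stand-in):

```python
import threading

_tls = threading.local()


def get_ordinal(defval=0):
  # Store the ordinal in thread-local storage so each per-device thread
  # sees its own value.
  if not hasattr(_tls, 'ordinal'):
    _tls.ordinal = _query_ordinal(defval)  # hypothetical runtime query
  return _tls.ordinal

# The catch noted above: Dynamo compiles on a different thread than the one
# running the user code, so the compiling thread sees an empty `_tls` and
# the cached value never makes it into the traced graph.
```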
```python
    participating replicas.
  """
  if pin_layout and xla_device_hw(
      value.device) in ('TPU', 'GPU', 'XPU') and output == None:
```
I think we had it because CPU was not supported at some point. Do you need to remove it because it will break Dynamo?
Thanks Jack for approving.
Summary:
This pull request makes `xm.all_gather` (the `_all_gather_using_all_reduce` path) a single graph in Dynamo. To do that, it caches the results of `xrt_world_size()` and `get_ordinal()` in module-level globals (`g_xrt_world_size`, `g_ordinal`), so Dynamo sees plain Python constants instead of breaking the graph on non-Tensor return values.
Test Plan:
```
PJRT_DEVICE=TPU python test/test_mp_all_gather.py
```
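For reference, a minimal sketch of what the test exercises, reconstructed from the traceback above (`ordinal_tensor`'s dtype/shape and the compile backend name are assumptions; the other names follow the frames shown):

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  device = xm.xla_device()
  # Each replica contributes a tensor holding its ordinal, matching the
  # traceback's `ordinal_tensor`; dtype and shape here are assumptions.
  ordinal_tensor = torch.tensor([index], dtype=torch.float32, device=device)
  compiled_all_gather = torch.compile(
      xm.all_gather, backend='openxla')  # backend name is an assumption
  result = compiled_all_gather(ordinal_tensor, dim=0)  # frame from traceback
  print(result)


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())  # frame from traceback
```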