
[SPMD][PoC] compile & execute with PjRt #3684

Merged
JackCaoG merged 48 commits into master from xla_spmd_pjrt_integration on Oct 17, 2022

Conversation

@yeounoh (Contributor) commented on Jul 6, 2022:

This is a follow-up to #3476 and contributes to #3871. The changes include:

  • Compile partitioned HLO computation graph with sharding annotations.
  • PjRtComputationClient integration to support SPMD sharded operations.
  • PjRtShardedData struct to represent sharded Data.
  • InputHandler for parameter sharding and sharded data transfer.
  • Remove duplicate copies of sharding annotations.
  • ExecuteReplicated for partitioned computation.

The PoC implementation supports replicated and tiled sharding annotations and the single-host xla:tpu backend. This enables a simple sharded computation on a v3-8, like:

    t1 = torch.randn(1, 128, device='cpu')
    t2 = torch.randn(1, 128, device='cpu')
    expected = t1 @ t2.T

    xt1 = t1.to(xm.xla_device())
    xt2 = t2.to(xm.xla_device())
    xs.mark_sharding(xt1, (1, 8), (0, 1))
    self.assertEqual('{devices=[1,8]0,1,2,3,4,5,6,7}',
                     torch_xla._XLAC._get_xla_sharding_spec(xt1))

    actual = (xt1 @ xt2.T).cpu()
    self.assertTrue(torch.allclose(expected, actual))

@yeounoh added the DO_NOT_MERGE (Not for merging.) and distributed (SPMD and other distributed things.) labels on Jul 6, 2022
@yeounoh self-assigned this on Jul 6, 2022
@yeounoh marked this pull request as a draft on July 6, 2022 01:14
@yeounoh force-pushed the xla_spmd_pjrt_integration branch 5 times, most recently from d38bd7d to 09f4640 on July 11, 2022 07:18
@yeounoh force-pushed the xla_spmd_pjrt_integration branch from 09f4640 to 5e07428 on July 13, 2022 20:13
@yeounoh force-pushed the xla_spmd_pjrt_integration branch 4 times, most recently from c9399ac to 91262bc on July 23, 2022 00:14
@yeounoh force-pushed the xla_spmd_pjrt_integration branch 15 times, most recently from 0ca964c to c26f94b on July 26, 2022 04:28
@yeounoh (Contributor, Author) commented on Oct 13, 2022:

The CPU test passes, but the GPU test fails with the following, seemingly unrelated (at least at first glance) error:

*** Begin stack trace ***
	tsl::CurrentStackTrace[abi:cxx11]()
	
	
	gsignal
	abort
	
	xla::XrtLocalService::XrtLocalService(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, int)
	xla::XrtComputationClient::MaybeCreateLocalService(xla::XrtComputationClient::Options const&)
	xla::XrtComputationClient::XrtComputationClient(xla::XrtComputationClient::Options, std::unique_ptr<tensorflow::tpu::TopologyProto, std::default_delete<tensorflow::tpu::TopologyProto> >)
	xla::ComputationClient::Create()
	
	
	xla::ComputationClient::Get()
	
	
	_PyMethodDef_RawFastCallKeywords
	
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallKeywords
	
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallDict
	
	_PyObject_GenericGetAttrWithDict
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_FastCallKeywords
	
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallDict
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallDict
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallKeywords
	
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallKeywords
	
	_PyEval_EvalFrameDefault
	_PyFunction_FastCallKeywords
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_FastCallKeywords
	
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	PyEval_EvalCode
	
	PyRun_StringFlags
	PyRun_SimpleStringFlags
	
	_Py_UnixMain
	__libc_start_main
	
*** End stack trace ***
Traceback (most recent call last):
  File "/tmp/pytorch/xla/test/test_torch_distributed_multi_all_reduce_xla_backend.py", line 38, in <module>
    xmp.spawn(_mp_fn, args=())
  File "/opt/conda/lib/python3.7/site-packages/torch_xla-1.14-py3.7-linux-x86_64.egg/torch_xla/distributed/xla_multiprocessing.py", line 399, in spawn
    start_method=start_method)
  File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 146, in join
    signal_name=name
torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with signal SIGABRT

The master branch is green, though. cc @JackCaoG

@JackCaoG (Collaborator) commented:
Seems irrelevant, let me just restart the GPU CI.

@JackCaoG (Collaborator) commented:
I will take another pass and try to merge it.

@yeounoh (Contributor, Author) commented on Oct 13, 2022:

> Seems irrelevant, let me just restart the GPU CI.

Yea, this one succeeded. Thanks @JackCaoG

Comment thread test/test_xla_sharding.py
expected = t + t

xt = t.to(xm.xla_device())
n_devices = xm.xrt_world_size()
Collaborator:

Does CI run this test, or do we only run it on TPU?

Contributor (Author):

We only run the C++ tests -- they cover the internal changes that affect the non-SPMD code paths -- and the Python API tests are disabled (link). I will re-enable them after debugging and adding the API unit tests.
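For reference, a minimal sketch of what such an API unit test could look like, assuming the xs.mark_sharding API shown in the PR description (the torch_xla.experimental.xla_sharding import path is an assumption here, not confirmed by this PR):

    import unittest

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.experimental.xla_sharding as xs  # assumed module path


    class BasicShardingAPITest(unittest.TestCase):

        def test_simple_sharded_matmul(self):
            n_devices = xm.xrt_world_size()
            t1 = torch.randn(1, 128)
            t2 = torch.randn(1, 128)
            expected = t1 @ t2.T

            xt1 = t1.to(xm.xla_device())
            xt2 = t2.to(xm.xla_device())
            # Tile xt1's second dimension across all available devices.
            xs.mark_sharding(xt1, (1, n_devices), (0, 1))
            # A sharding spec should now be attached to the annotated tensor.
            self.assertNotEqual('', torch_xla._XLAC._get_xla_sharding_spec(xt1))

            # The sharded computation should still match the CPU result.
            actual = (xt1 @ xt2.T).cpu()
            self.assertTrue(torch.allclose(expected, actual))


    if __name__ == '__main__':
        unittest.main()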

virtual void TransferToServer(absl::Span<const TensorSource> tensors,
absl::Span<const DataPtr> datas) = 0;

// Transfers local sharded tensor values to the TPU servers and returns a
Collaborator:

I would use "TPU device" instead of "TPU server"; there is no server in the PjRt context.

Contributor (Author):

Done

Comment thread torch_xla/csrc/tensor.cpp
void XLATensor::SetShardingSpec(const ShardingSpec& sharding_spec) {
XLA_CHECK(GetIrValue().node != nullptr) << "Tyring to access a null cursor";
dynamic_cast<XlaNode*>(data()->ir_value.node.get())
dynamic_cast<XlaNode*>(GetIrValue().node.get())
Collaborator:

Hmm, we should add an XlaNodeCast helper to replace dynamic_cast<XlaNode*> so it is cleaner.

Contributor (Author):

I see. I normally prefer more explicit type identifiers, especially for casting (similar to avoiding overuse of auto).

Comment thread torch_xla/csrc/tensor.cpp
// TODO(yeounoh): Sharding annotation must be removed by explicit call to
// ClearSharding.
ShardingSpecPtr sharding = sharding_spec();
if (sharding != nullptr) {
Collaborator:

We need a test for this. For example, when we deep copy a tensor with sharding, the resulting tensor should also carry the sharding. Something similar to:

y = copy.deepcopy(x)

@steventk-g can you add a test case?
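A minimal sketch of such a test, assuming the xs.mark_sharding API from the description (the torch_xla.experimental.xla_sharding import path is assumed, and the exact assertion is illustrative):

    import copy

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.experimental.xla_sharding as xs  # assumed module path

    xt = torch.randn(1, 128).to(xm.xla_device())
    xs.mark_sharding(xt, (1, xm.xrt_world_size()), (0, 1))

    # The deep copy should carry the same sharding annotation as the original.
    xt_copy = copy.deepcopy(xt)
    assert (torch_xla._XLAC._get_xla_sharding_spec(xt_copy) ==
            torch_xla._XLAC._get_xla_sharding_spec(xt))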

@steventk-g (Collaborator) commented on Oct 14, 2022:

Yep, I've created an issue to track it #4095

Contributor (Author):

Good point. @steventk-g, let me handle this if you haven't already started.

Comment thread torch_xla/csrc/tensor.cpp

auto cached_computation = std::make_shared<CachedComputation>(
std::move(compile_result.computation));
std::move(compile_result.computation), compile_result.is_sharded);
Collaborator:

Why do we need is_sharded separately in CachedComputation?

Contributor (Author):

We could either pass is_sharded around between APIs or wrap it inside the CachedComputation. is_sharded is needed later for execution (and is associated only with the cached computation), and the latter option doesn't require changing function APIs here and there.

@JackCaoG (Collaborator) left a review:

Mostly LGTM. I had a question regarding ExecuteReplicated in #3684 (comment). If we can align on that, this PR is ready to merge.

@JackCaoG (Collaborator) left a review:

Thanks @yeounoh! I will merge this PR to unblock @steventk-g.


Labels: distributed (SPMD and other distributed things.)

5 participants