
[WIP] Make some hard-coded changes for activation sharding#6161

Closed
wonjoo-wj wants to merge 2 commits into master from wonjoo/activation-sharding-update

Conversation

@wonjoo-wj
Collaborator

[WIP] Make some hard-coded changes for activation sharding

TODOs

  • Implement a way to properly remove the @xr.requires_pjrt decorator
    • Possibly by caching the device type after the first call in runtime.py
  • Replace asserts with Torch Dynamo conditionals (see the sketch below)
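
A minimal sketch of the last TODO, assuming torch._dynamo.is_compiling() is available in the targeted torch version; the helper name is hypothetical and not part of this PR:

    import torch

    def _check_num_devices(num_devices: int, expected: int) -> None:
        # Skip the eager-only assertion while Dynamo is tracing, since
        # data-dependent asserts tend to cause graph breaks.
        if torch._dynamo.is_compiling():
            return
        assert num_devices == expected, (
            f"expected {expected} devices, got {num_devices}")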

Collaborator Author

@wonjoo-wj wonjoo-wj left a comment

Offline sync with @JackCaoG:

  • np.unique -> rewrite it in pure Python? Or just use torch.unique? Same for all similar calls.
  • requires_pjrt -> cache the env var once in the runtime init.
  • Is num_devices only used for the assertion? If so, can we skip this check only in the Dynamo code flow?
  • global_runtime_device_count -- can we also cache this? (See the sketch after this list.)
  • To check for correctness, look at the HLO sharding.
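
A minimal sketch of the caching and torch.unique ideas, assuming functools.lru_cache is acceptable here; the wrapper name is hypothetical and this is not the actual torch_xla change:

    import functools

    import torch
    import torch_xla.runtime as xr

    @functools.lru_cache(maxsize=None)
    def cached_global_runtime_device_count() -> int:
        # Query the runtime once and reuse the result on later calls.
        return xr.global_runtime_device_count()

    # torch.unique as a drop-in replacement for np.unique on tensor inputs:
    unique_devices = torch.unique(torch.tensor([0, 1, 1, 2]))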

TODOs

  • Make these changes and update the PR (by Friday or Monday). Make sure it does not break the existing SPMD implementation or the LLaMa 2 inference run (w/ activation sharding).
  • Then check correctness throughout next week.

cc @yeounoh

@wonjoo-wj
Collaborator Author

While making some of the mentioned fixes and re-running the tests and LLaMa 2, I saw a bunch of complaints from Dynamo saying it can't trace large portions of mark_sharding.py, such as:

torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(Mesh) _get_op_sharding_args [TupleVariable()] {}
...
x86_64.egg/torch_xla/distributed/spmd/xla_sharding.py", line 518, in mark_sharding
    unwrap_sharded_tensor(t), *mesh._get_op_sharding_args(partition_spec))

And also:

torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [TensorVariable()] {}
...
x86_64.egg/torch_xla/distributed/spmd/xla_sharding.py", line 558, in wrap_as_sharded_tensor
    return XLAShardedTensor(t)

Most importantly, I saw this error:

torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(_xla_mark_sharding_dynamo_custom_op) __call__ [TensorVariable(), ListVariable(), ListVariable(), ListVariable(), ConstantVariable(int)] {}
...
  File "/workspace/llama/llama/model.py", line 212, in forward
    xs.mark_sharding(output, data_model_mesh, (0, 1, 2), use_dynamo_custom_op=True)
  File "/usr/local/lib/python3.8/site-packages/torch_xla-2.2.0+git5a149b8-py3.8-linux-x86_64.egg/torch_xla/distributed/spmd/xla_sharding.py", line 538, in mark_sharding
    annotate_func(

This made me doubt that the Dynamo custom op is being traced correctly. After some digging, I believe the correct way to invoke this custom Dynamo op is directly through torch.ops, like:

    torch.ops.xla.xla_mark_sharding_dynamo_custom_op(
        unwrap_sharded_tensor(t), [[0, 1]], [], [], 1)

Shown here: https://github.com/pytorch/xla/pull/6161/files#diff-3dcff2b7395bbf1f8a09170775388ef686a1e5f593b3c3889996d78c93a9c394R540-R541.

But now we need to make some bigger changes in mark_sharding.py, especially around the return type XLAShardedTensor, since it is unknown to Dynamo. Again, why wasn't this caught in the unit tests?

I'll continue to look into it tomorrow.

@JackCaoG
Collaborator

I'm starting to wonder if it would be easier to just let Dynamo treat mark_sharding itself as a custom op, instead of the C++ pybind.

@wonjoo-wj
Collaborator Author

> I'm starting to wonder if it would be easier to just let Dynamo treat mark_sharding itself as a custom op, instead of the C++ pybind.

That could be one option -- we would still need to figure out how to deal with the return type of mark_sharding, though. As of now, mark_sharding returns a custom type, XLAShardedTensor, which would not be traceable by Dynamo. One option would be to represent the shard with default Python types. Thoughts on this, @yeounoh?

@yeounoh
Contributor

yeounoh commented Dec 15, 2023

> I'm starting to wonder if it would be easier to just let Dynamo treat mark_sharding itself as a custom op, instead of the C++ pybind.

> That could be one option -- we would still need to figure out how to deal with the return type of mark_sharding, though. As of now, mark_sharding returns a custom type, XLAShardedTensor, which would not be traceable by Dynamo. One option would be to represent the shard with default Python types. Thoughts on this, @yeounoh?

XLAShardedTensor is already a torch.Tensor, which is the same mechanism DTensor relies on to be compatible with Dynamo. So I would expect, or at least hope, that it is captured and treated as a regular tensor.
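
An illustrative sketch of that mechanism (not the real XLAShardedTensor; the class name here is hypothetical): a subclass created through torch.Tensor._make_subclass is still an instance of torch.Tensor, which is the property Dynamo relies on:

    import torch

    class WrappedTensor(torch.Tensor):
        # Hypothetical wrapper subclass; the real XLAShardedTensor also carries
        # sharding metadata alongside the wrapped tensor.
        @staticmethod
        def __new__(cls, elem: torch.Tensor):
            return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

    t = WrappedTensor(torch.randn(4, 4))
    assert isinstance(t, torch.Tensor)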

@wonjoo-wj
Collaborator Author

Coming from offline sync with @yeounoh:

  • XLAShardedTensor should be traceable by Dynamo since it is an extension of torch.Tensor. DTensor, which is also an extension of torch.Tensor, is traceable by Dynamo, so XLAShardedTensor should be too. TODO: open an issue on PyTorch to find out.
  • This PR's updated way of calling torch.ops.xla.xla_mark_sharding_dynamo_custom_op seems to be the correct way. We still need to do some refactoring on this code path (we need to register the existing pybind as the dispatch function for our custom op; see the sketch after this list).
  • Given the uncertainty around XLAShardedTensor traceability and the amount of refactoring, we probably won't be able to make this for 2.2, but I'll see what I can do.
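
A hedged sketch of that registration pattern, using a placeholder namespace and op name rather than the actual torch_xla registration:

    import torch
    from torch.library import Library

    # Placeholder namespace and schema; the real op lives under torch.ops.xla.
    lib = Library("xla_sketch", "DEF")
    lib.define(
        "mark_sharding_sketch(Tensor t, int[] tile_assignment, int sharding_type) -> Tensor")

    def mark_sharding_sketch(t, tile_assignment, sharding_type):
        # The real dispatch function would forward to the existing pybind;
        # returning the input keeps this sketch runnable without torch_xla.
        return t

    lib.impl("mark_sharding_sketch", mark_sharding_sketch, "CompositeExplicitAutograd")

    # Dynamo then sees the op through torch.ops, e.g.:
    #   torch.ops.xla_sketch.mark_sharding_sketch(t, [0, 1], 1)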

cc @yeounoh @JackCaoG

@wonjoo-wj
Collaborator Author

Opened an issue upstream at pytorch/pytorch#115970. I'll also reach out in PyTorch Slack on Monday.

@wonjoo-wj
Collaborator Author

This is completed with #6524, closing.

@wonjoo-wj wonjoo-wj closed this Aug 15, 2024