Implement mark_sharding as a custom dynamo op #6524

Merged
wonjoo-wj merged 7 commits into master from wonjoo/dynamo-mark-sharding-python-custom-op on Mar 5, 2024

Conversation

@wonjoo-wj
Collaborator

@wonjoo-wj wonjoo-wj commented Feb 13, 2024

Follow-up to #6161, this PR implements mark_sharding as a custom dynamo op.

The newly introduced custom op's function signature is:

dynamo_mark_sharding(Tensor input, int[] device_ids, int[] mesh_shape, str axis_names, str partition_spec) -> Tensor

As per https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md, the variable types that PyTorch accepts when registering a function are very limited. It does not recognize Tuple, Str[], or any of the custom types that we want for mark_sharding. So when defining our custom op dynamo_mark_sharding, we had to make it a little hacky:

  1. mesh_shape: Instead of a tuple, we just define this as a list, then do a manual conversion to a tuple.
  2. axis_names/partition_spec: PyTorch does not provide a way to declare a type that spans multiple types (e.g. the type of partition_spec is Tuple[Union[Tuple, int, str, None]]). So we just make these strings, then use Python's ast.literal_eval (which is an injection-safe eval) to convert the string representations back into their "real" types. Example of how ast.literal_eval works:
>>> import ast
>>> x = "(0, None)"
>>> type(x)
<class 'str'>
>>> x_eval = ast.literal_eval(x)
>>> x_eval
(0, None)
>>> type(x_eval)
<class 'tuple'>
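
For illustration, here is a minimal sketch (not the actual implementation in this PR; the registration style and the helper logic are assumptions) of how such an op could be registered with torch.library and how the string-encoded arguments can be decoded before calling the existing mark_sharding:

# Minimal sketch, assuming torch.library registration; names and details are illustrative.
import ast
import torch
import torch_xla.distributed.spmd as xs

XLA_LIB = torch.library.Library("xla", "FRAGMENT")
XLA_LIB.define(
    "dynamo_mark_sharding(Tensor input, int[] device_ids, int[] mesh_shape, "
    "str axis_names, str partition_spec) -> Tensor")

@torch.library.impl(XLA_LIB, "dynamo_mark_sharding", "XLA")
def dynamo_mark_sharding(input, device_ids, mesh_shape, axis_names, partition_spec):
  # mesh_shape arrives as a plain int list; Mesh expects a tuple.
  mesh_shape_tuple = tuple(mesh_shape)
  # axis_names / partition_spec arrive as strings such as "('x', 'y')" or
  # "(0, None)"; ast.literal_eval safely restores the original Python objects.
  axis_names_eval = ast.literal_eval(axis_names)
  partition_spec_eval = ast.literal_eval(partition_spec)
  mesh = xs.Mesh(device_ids, mesh_shape_tuple, axis_names_eval)
  xs.mark_sharding(input, mesh, partition_spec_eval)
  return input

A caller would then pass the tuple-valued arguments as their string representations, e.g. torch.ops.xla.dynamo_mark_sharding(t, list(range(8)), [2, 4], "('x', 'y')", "(0, 1)").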

Companion PR in LLaMa 2: pytorch-tpu/llama#53


TODO

  • Clean up the unnecessary C++ dynamo_custom_op in a follow-up PR

@vanbasten23
Collaborator

I wonder why we want to implement mark_sharding as a custom dynamo op

@wonjoo-wj
Collaborator Author

I wonder why we want to implement mark_sharding as a custom dynamo op

This is a follow-up PR to #6161, which originally aimed to support Dynamo + SPMD activation sharding (i.e. allow mark_sharding in a torch.compile call). In order to allow this, we need to make mark_sharding traceable, and making the entire mark_sharding call a Python custom op was the easiest way to do that.
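
As a rough illustration (a sketch only; the model shape, mesh layout, and backend choice are assumptions, not code from this PR), activation sharding inside a torch.compile'd function would look roughly like:

# Hypothetical sketch: sharding an activation inside a compiled function via the
# traceable custom op. Shapes, mesh layout, and backend are assumptions.
import torch
import torch_xla.runtime as xr

NUM_DEVICES = xr.global_runtime_device_count()
DEVICE_IDS = list(range(NUM_DEVICES))

def forward(x, w1, w2):
  h = x @ w1
  # The custom op only takes dynamo-friendly types (Tensor, int lists, strings),
  # so this call can be traced, unlike xs.mark_sharding(t, Mesh(...), ...).
  h = torch.ops.xla.dynamo_mark_sharding(
      h, DEVICE_IDS, [1, NUM_DEVICES], "('data', 'model')", "(0, 1)")
  return h @ w2

compiled_forward = torch.compile(forward, backend='openxla')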

@wonjoo-wj
Collaborator Author

wonjoo-wj commented Feb 15, 2024

To reviewers, I have some refactoring I want to do here based on my offline sync with Jack and Jiewen. Please don't review this yet, I'll tag you here once it's ready for review (ETA tomorrow). Thanks!

@wonjoo-wj
Collaborator Author

Based on our conversation yesterday, we ideally want to keep the same existing xs.mark_sharding API so customers' UX stays the same. However, one blocker I saw while trying to do this is that if we want to enable mark_sharding inside torch.compile (i.e. support activation sharding), we cannot keep the existing xs.mark_sharding because it takes in Mesh as an input. And Dynamo does not know how to trace a custom type like Mesh. So it seems like we have to introduce a new API to make it traceable.

cc @JackCaoG

@JackCaoG
Collaborator

And Dynamo does not know how to trace a custom type like Mesh

Can you post the error message here? Dynamo should be able to understand basic custom Python types.

@wonjoo-wj
Collaborator Author

wonjoo-wj commented Feb 21, 2024

By adding xs.Mesh into the forward function that is being torch.compile'd, like so:

         # # TODO(yeounoh) remove this after activation sharding support is enabled.
         # num_devices = xr.global_runtime_device_count()
         # device_ids = torch.arange(num_devices)
-        # data_model_mesh = xs.Mesh(device_ids, (4, 1, 2))
+        data_model_mesh = xs.Mesh(device_ids, (4, 1, 2))
         # xs.mark_sharding(output, data_model_mesh, (0, 1, 2), use_dynamo_custom_op=True)

The error we get is:

  File "/home/wonjoo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 190, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: undefined LOAD_FAST

from user code:
   File "/home/wonjoo/llama/llama/generation.py", line 183, in _generate_one_token
    logits = self.model(input_tokens, input_pos_tensor, output_pos_tensor)
  File "/home/wonjoo/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wonjoo/miniconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/wonjoo/llama/llama/model.py", line 374, in forward
    h = layer(h, freqs_cis, mask, input_indexes)
  File "/home/wonjoo/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wonjoo/llama/llama/model.py", line 310, in forward
    h = x + self.attention.forward(
  File "/home/wonjoo/llama/llama/model.py", line 210, in forward
    data_model_mesh = xs.Mesh(device_ids, (4, 1, 2))

The full error log is at https://gist.github.com/wonjoolee95/937575e2f5498a2179554974e0a52264.

This error is similar to what we saw in our prior attempt to make mark_sharding traceable; similar errors were thrown with Dynamo complaining that it was unable to trace code like os.environ.

@JackCaoG
Collaborator

  File "/home/wonjoo/miniconda3/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 839, in LOAD_FAST
    unimplemented("undefined LOAD_FAST")
        if name not in self.symbolic_locals:
            unimplemented("undefined LOAD_FAST")

Can you print the name here? Mesh is a relatively simple Python type; I wonder if there is an easy way to make Dynamo understand it.

@wonjoo-wj
Collaborator Author

Posting an update after an offline sync with Jack: the issue with xs.Mesh was an issue with my user code (llama/model.py).

The current issue under discussion is in the path torch_xla.distributed.spmd.mark_sharding -> torch_xla.experimental.dynamo_mark_sharding. The op does not seem to like the input type (full error logs):

torch._dynamo.exc.TorchRuntimeError: Failed running call_function xla.dynamo_mark_sharding(*(FakeTensor(..., device='xla:0', size=(1, 128, 512)), [u0], '(4, 1, 2)', 'None', '(0, 1, 2)'), **{}):
xla::dynamo_mark_sharding() Expected a value of type 'List[int]' for argument 'device_ids' but instead found type 'immutable_list'.
Position: 1
Value: [u0]
Declaration: xla::dynamo_mark_sharding(Tensor t, int[] device_ids, str mesh_shape, str axis_names, str partition_spec) -> Tensor
Cast error details: Unable to cast Python instance of type <class 'torch.fx.immutable_collections.immutable_list'> to C++ type 'std::vector<long, std::allocator<long> >'

This error is happening at exactly this line in the new xs.mark_sharding:

return torch.ops.xla.dynamo_mark_sharding(t, device_ids_tensor_list, mesh_shape_str, axis_names_str, partition_spec_str)

mesh_shape = list(mesh.mesh_shape)
axis_names = str(mesh.axis_names)
partition_spec = '(1, 0)'
torch.ops.xla.dynamo_mark_sharding(linear.fc2.weight, device_ids,
Collaborator


Shouldn't dynamo_mark_sharding be inside the linear? What we are trying to verify is whether the custom op can be captured by Dynamo.

Collaborator Author


Yep, that is correct -- this was to see if the "normal" case can be successful with the new custom op. Added a new test to reflect that.

@wonjoo-wj
Collaborator Author

wonjoo-wj commented Feb 28, 2024

Due to Dynamo capturing a native Python list as a torch.fx.immutable_collections.immutable_list, as described in #6524 (comment), using the existing xs.mark_sharding to directly call the custom op does not seem possible for now.

I can follow up with PyTorch to see if we can get this resolved, but for now we can stick to having a separate custom op to enable dynamo mark sharding.

With these changes, I can confirm that LLaMa 2 Dynamo + SPMD with activation sharding is passing: pytorch-tpu/llama#53.

@JackCaoG, this should be ready to be reviewed now. Thanks!

@wonjoo-wj wonjoo-wj requested a review from JackCaoG February 28, 2024 02:15
@wonjoo-wj wonjoo-wj force-pushed the wonjoo/dynamo-mark-sharding-python-custom-op branch from a9d827f to d79ca73 on February 28, 2024 07:50
@wonjoo-wj
Collaborator Author

The test_parallel_cow_materialize_error_xla_bfloat16 test is failing on the CPU/GPU CI and seems unrelated to this change. Re-running the CI.

@wonjoo-wj wonjoo-wj force-pushed the wonjoo/dynamo-mark-sharding-python-custom-op branch from ade4c02 to f8de76e on February 28, 2024 20:42
@wonjoo-wj wonjoo-wj force-pushed the wonjoo/dynamo-mark-sharding-python-custom-op branch from f8de76e to b7078af on February 28, 2024 23:30
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.debug.metrics as met
import torch_xla.experimental.dynamo_mark_sharding
Contributor


We can move this under torch_xla.distributed.spmd.

Contributor

@yeounoh yeounoh left a comment


LGTM, thanks!

@wonjoo-wj
Collaborator Author

Thanks for the review, Yeounoh. I'll go ahead and merge this for now and open a follow-up PR soon.
