
Support activation sharding #53

Closed
wonjoo-wj wants to merge 1 commit into optimize_spmd_sharding from wonjoo/activation-sharding

Conversation

Collaborator

@wonjoo-wj wonjoo-wj commented Feb 15, 2024

With pytorch/xla#6524, we can now support activation sharding.

This PR is an example of the changes needed to call the new dynamo Python custom op.

Note that this requires pytorch/xla#6524 to be merged first. Tested locally, llama 2 inference with dynamo+spmd (with activation sharding) is successful: https://gist.github.com/wonjoolee95/a290a68f29c52bd395b16ae6df651531.

Also note that performance has not been tested; this change only verifies functionality, so we still need to find the optimal sharding strategy.
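For context, here is a minimal end-to-end sketch of the annotation call. The module paths (torch_xla.runtime, torch_xla.distributed.spmd) and the 8-device mesh are assumptions for illustration; only the mark_sharding(..., use_dynamo_custom_op=True) call itself comes from this PR.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs  # assumed module path for `xs`

xr.use_spmd()  # enable SPMD execution mode

num_devices = xr.global_runtime_device_count()
device_ids = torch.arange(num_devices)
# (data, sequence, model) mesh; the (4, 1, 2) shape assumes 8 devices.
data_model_mesh = xs.Mesh(device_ids, (4, 1, 2))

# Stand-in for the attention output activation in llama/model.py.
output = torch.randn(4, 128, 4096).to(xm.xla_device())

# use_dynamo_custom_op=True routes the annotation through the Python
# custom op from pytorch/xla#6524 so that dynamo can trace it.
xs.mark_sharding(output, data_model_mesh, (0, 1, 2), use_dynamo_custom_op=True)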

Comment thread: llama/model.py
# Build a (data, sequence, model) mesh over all devices and annotate the activation.
num_devices = xr.global_runtime_device_count()
device_ids = torch.arange(num_devices)
data_model_mesh = xs.Mesh(device_ids, (4, 1, 2))
xs.mark_sharding(output, data_model_mesh, (0, 1, 2), use_dynamo_custom_op=True)

@miladm miladm Mar 5, 2024


@wonjoolee95 can you add a discussion of why we keep output = self.wo(output) at line 204 when we want to make the dynamo_mark_sharding call after it?

Collaborator

@yeounoh yeounoh Mar 20, 2024


This is the original implementation; we are just adding the annotations. We have to keep the output projection as-is.
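In other words, the intended ordering around line 204 is projection first, annotation after (a sketch; data_model_mesh as in the diff above):

output = self.wo(output)  # original output projection, unchanged
# The sharding annotation is added after the projection, not in place of it.
xs.mark_sharding(output, data_model_mesh, (0, 1, 2), use_dynamo_custom_op=True)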

@yeounoh yeounoh self-requested a review March 20, 2024 18:25
Comment thread: llama/model.py
# custom python dynamo mark sharding
import torch_xla.experimental.dynamo_mark_sharding
device_ids = [0]
mesh_shape = [1, 1, 1]
Collaborator


Wait, why is the mesh shape [1, 1, 1]?

Collaborator

@yeounoh yeounoh left a comment


This is for testing -- @wonjoolee95

  • Could you remove the "[WONJOO]" debugging probes and make this look more formal, if useful?
  • Also, you have to change the mesh shape. I don't think we are sharding the activation in the current impl? (See the sketch below.)
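For example, a mesh that actually shards would have to span all devices rather than a single one (a sketch, assuming 8 devices; variable names follow the diff above):

num_devices = xr.global_runtime_device_count()  # 8 in this sketch
device_ids = list(range(num_devices))
# The product of the mesh axes must equal num_devices; [1, 1, 1] over
# device_ids = [0] replicates on one device and shards nothing.
mesh_shape = [4, 1, 2]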

@wonjoo-wj
Collaborator Author

Moving this over to #55

@wonjoo-wj wonjoo-wj closed this Mar 20, 2024