
Add flag to enable activation sharding #55

Merged
wonjoo-wj merged 4 commits into optimize_spmd_sharding from activation_sharding
Mar 20, 2024

Conversation

@wonjoo-wj (Collaborator)

With this PR, you can add the flag --enable_activation_sharding True to your command to enable activation sharding. By default, it is set to False.
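For context, a minimal sketch of what gating activation sharding behind such a flag could look like. This assumes the torch_xla SPMD API; the helper name, mesh axis names, and partition spec below are illustrative, not the exact code in llama/model.py.

import torch_xla.distributed.spmd as xs

def maybe_shard_activation(h, mesh, enable_activation_sharding):
    # Annotate the (batch, seq, hidden) activation with a sharding spec only
    # when the flag is set; ("data", None, None) is a placeholder choice.
    if enable_activation_sharding:
        xs.mark_sharding(h, mesh, ("data", None, None))
    return h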

Comment thread: llama/model.py (Outdated)

@yeounoh (Collaborator) left a comment

LGTM minor suggestion

Comment thread: llama/model.py (Outdated)

@wonjoo-wj (Collaborator, Author) commented Mar 20, 2024

@yeounoh, I've updated the code to dynamically fetch num_device, device_ids, and mesh_shape by calculating num_device outside of the torch.compile'd function. Verified that it now properly accepts the --enable_activation_sharding argument, e.g.:

PJRT_DEVICE=TPU XLA_USE_SPMD=1 python3 example_text_completion.py --ckpt_dir checkpointing/7B/ --tokenizer_path spiece.model --max_seq_len 2048 --max_gen_len 1000 --max_batch_size 1 --mp False --dynamo True --spmd True --enable_activation_sharding True

I'll go ahead and merge this for now and follow up if there is anything else to address.
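As a rough sketch of the pattern described above, computing the device count, device IDs, and mesh shape eagerly and letting the torch.compile'd function capture them by closure, assuming torch_xla's SPMD API (axis names and partition spec are again illustrative):

import numpy as np
import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Computed once, eagerly, outside any torch.compile'd region.
num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)
mesh_shape = (num_devices, 1)
mesh = xs.Mesh(device_ids, mesh_shape, ("data", "model"))

def forward_step(model, tokens):
    h = model(tokens)
    # The mesh is captured by closure, so the compiled graph does not have to
    # query the runtime for device counts or IDs itself.
    xs.mark_sharding(h, mesh, ("data", None, None))
    return h

compiled_step = torch.compile(forward_step, backend="openxla")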

@wonjoo-wj merged commit f7b9278 into optimize_spmd_sharding on Mar 20, 2024