[Pallas] Refactor the gmm kernel #7099

Merged
alanwaketan merged 10 commits into master from alanwaketan/gmm
May 23, 2024

@alanwaketan
Collaborator

Summary:
This PR refactors the code from #6940 and removes the unneeded code in that part, cutting it from ~400 lines to ~50 lines. One bummer, however: the original gmm kernel doesn't work at all. It assumes groups_sizes is a CPU tensor, which means we need to materialize this input in order to use the gmm kernel, and that introduces graph breaks in the computation. Yet another follow-up will be needed to make this code actually functional. The good news is that the test cases seem functional.

Test Plan:
python test/test_megablox.py
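
To make the groups_sizes constraint concrete, here is a minimal plain-Python sketch of what a grouped matrix multiply computes. It is not the Pallas kernel from this PR. The semantics are an assumption (the usual megablox-style layout: consecutive rows of lhs are split into groups, and each group is multiplied by its own rhs matrix), and the name `reference_gmm` and its signature are hypothetical.

```python
def reference_gmm(lhs, rhs, group_sizes):
    """Grouped matrix multiply on plain Python lists (illustrative only).

    lhs: m x k matrix, given as a list of rows.
    rhs: one k x n matrix per group.
    group_sizes: number of consecutive lhs rows owned by each group.
    """
    if sum(group_sizes) != len(lhs):
        raise ValueError("group_sizes must sum to the number of lhs rows")
    if len(group_sizes) != len(rhs):
        raise ValueError("need one rhs matrix per group")

    out = []
    start = 0
    for size, weights in zip(group_sizes, rhs):
        # The slice bounds are concrete Python ints. This is the step that
        # needs the actual values of group_sizes on the host.
        for row in lhs[start:start + size]:
            out.append([
                sum(a * b for a, b in zip(row, col))
                for col in zip(*weights)  # iterate over columns of weights
            ])
        start += size
    return out


lhs = [[1, 2], [3, 4], [5, 6]]
rhs = [
    [[1, 0], [0, 1]],  # group 0: identity
    [[2, 0], [0, 2]],  # group 1: scale by 2
]
result = reference_gmm(lhs, rhs, group_sizes=[1, 2])
print(result)
```

Because the row offsets are derived from the values in group_sizes, a kernel written this way has to read those values on the host before it can set up its work, which is why a lazily evaluated device tensor would need to be materialized first.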

@JackCaoG JackCaoG added the tpuci label May 23, 2024
@alanwaketan
Collaborator Author

No need to try TPU CI, as I mentioned. This thing is not functional on TPU.

@alanwaketan alanwaketan merged commit 5327033 into master May 23, 2024
@alanwaketan alanwaketan deleted the alanwaketan/gmm branch May 23, 2024 01:55
Collaborator

@wonjoo-wj wonjoo-wj left a comment

LGTM!

qihqi pushed a commit that referenced this pull request May 29, 2024
3 participants