Skip to content

[Feature] use pytest for sgl-kernel#4697

Closed
adarshxs wants to merge 17 commits intosgl-project:mainfrom
adarshxs:pytest_transition
Closed

[Feature] use pytest for sgl-kernel#4697
adarshxs wants to merge 17 commits intosgl-project:mainfrom
adarshxs:pytest_transition

Conversation

@adarshxs
Copy link
Copy Markdown
Collaborator

Motivation

For: #4690

Modifications

test files in: https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests

Checklist

Comment thread sgl-kernel/tests/test_activation.py Outdated
@adarshxs adarshxs requested a review from zhyncs March 26, 2025 17:37
@zhyncs
Copy link
Copy Markdown
Collaborator

zhyncs commented Mar 26, 2025

speculative/test_eagle_utils.py and speculative/test_speculative_sampling.py should also be updated

Comment thread sgl-kernel/tests/test_trt_allreduce.py
@adarshxs
Copy link
Copy Markdown
Collaborator Author

@zhyncs one test is failing. Increasing tolerance should help?

@zhyncs zhyncs requested a review from FlamingoPg as a code owner March 29, 2025 17:02
@FlamingoPg
Copy link
Copy Markdown
Collaborator

nice work~

Copy link
Copy Markdown
Collaborator

@FlamingoPg FlamingoPg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM~

)
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
def test_grouped_gemm_accuracy(out_dtype):
Ms = [1, 16, 32, 256, 1024]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For different shapes you should use @pytest.mark.parametrize instead of for loop

Comment thread sgl-kernel/tests/test_deep_gemm.py Outdated
)
def test_gemm():
print("Testing GEMM:")
for m in (64, 128, 4096):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

Comment thread sgl-kernel/tests/test_deep_gemm.py Outdated

def test_m_grouped_gemm_contiguous():
print("Testing grouped contiguous GEMM:")
for num_groups, m, k, n in (
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same



def test_accuracy():
Ms = [1, 128, 512, 1024, 4096]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

Comment thread sgl-kernel/tests/test_fp8_gemm.py Outdated


def test_accuracy():
Ms = [1, 128, 512, 1024, 4096]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

Comment thread sgl-kernel/tests/test_int8_gemm.py Outdated


def test_accuracy():
Ms = [1, 16, 32, 64, 128, 512, 1024, 4096, 8192]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

@adarshxs adarshxs closed this Mar 29, 2025
@adarshxs
Copy link
Copy Markdown
Collaborator Author

adarshxs commented Mar 29, 2025

apologies messed up. opening a new PR

@adarshxs adarshxs deleted the pytest_transition branch April 19, 2025 14:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants