Enable mx fp8 fp4 support on ROCm#2046

Merged
pruthvistony merged 19 commits into rocm6.5_internal_testing from rocm6.5_internal_testing_mx_f8_fp4
May 29, 2025
Conversation

@jagadish-amd jagadish-amd commented Apr 23, 2025

This PR enables mx data type support on ROCm.

Current test stats (accounting only blockwise scale tests)
PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v
Ran 452 tests in 17.470s
FAILED (failures=2, errors=2, skipped=337)
111 tests passed

fp8 mx data type sample test case.
PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py TestFP8MatmulCudaCUDA.test_blockwise_mxfp8_nvfp4_numerics_test_case_name_a_eye_b_eye_fast_accum_False_128_128_128_recipe_mxfp8_cuda -v

hipBLASLt log: hipblaslt-bench --api_method c -m 128 -n 128 -k 128 --lda 128 --ldb 128 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1 --beta 0 --transA T --transB N --batch_count 1 --scaleA 3 --scaleB 3 --a_type f8_r --b_type f8_r --c_type bf16_r --d_type bf16_r --compute_type f32_r --algo_method index --solution_index -2147220478 --rotating 0 --cold_iters 0 --iters 0

fp4 mx data type sample test case.
PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py TestFP8MatmulCudaCUDA.test_blockwise_mxfp8_nvfp4_numerics_test_case_name_a_eye_b_eye_fast_accum_False_128_128_128_recipe_nvfp4_cuda -v
hipBLASLt log: hipblaslt-bench --api_method c -m 128 -n 128 -k 128 --lda 128 --ldb 128 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1 --beta 0 --transA T --transB N --batch_count 1 --scaleA 3 --scaleB 3 --a_type f4_r --b_type f4_r --c_type bf16_r --d_type bf16_r --compute_type f32_r --algo_method index --solution_index -2147220478 --rotating 0 --cold_iters 0 --iters 0
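The bench invocations above exercise blockwise-scaled GEMMs. As a rough illustration of the semantics only (a numpy sketch under toy row-major conventions; not the hipBLASLt kernel or this PR's code), an MX-format matmul applies one shared scale to each 32-element K-block of A and B before accumulating in higher precision:

```python
import numpy as np

def blockwise_scaled_matmul(a_q, b_q, scale_a, scale_b, block=32):
    # Dequantize: each 32-wide K-block of A (and of B) shares one scale.
    a = a_q * np.repeat(scale_a, block, axis=1)
    b = b_q * np.repeat(scale_b, block, axis=1)
    return a @ b.T  # toy convention; the bench runs transA=T, transB=N

m = k = 128
a_q = np.eye(m, k)                    # stand-in for fp8/fp4 payloads
b_q = np.eye(m, k)
scale_a = np.full((m, k // 32), 2.0)  # one scale per 32-element K-block
scale_b = np.full((m, k // 32), 0.5)
c = blockwise_scaled_matmul(a_q, b_q, scale_a, scale_b)
print(np.allclose(c, np.eye(m)))      # scales 2.0 * 0.5 cancel -> identity
```

This mirrors the a_eye_b_eye test shapes (m = n = k = 128), where the product of the A and B block scales determines the output magnitude.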

Commits:

  1. ROCm MX-FP8 Gemm (PR from @petrex )
    Ported the patch from ROCm MX-FP8 Gemm pytorch/pytorch#147553
    Commented out a few lines to avoid compilation errors (check the TODO comments)

  2. Refine _platform_supports_mx_gemm check

  3. For mx fp8, A and B need not be of type kFloat8_e8m0fnu

  4. Add fp4 support (PR from @petrex )
    Ported the patch from AMD/ROCm OCP Micro-scaling Format (mx-fp8/mx-fp4) support pytorch/pytorch#151360
    Added the fp4 type in aten/src/ATen/cuda/CUDADataType.h
    Added more mappings in aten/src/ATen/cuda/CUDADataType.h
    Used the e8m0 scaling dtype for the fp4 test case on ROCm in test/test_matmul_cuda.py

  5. test_matmul: change code to correctly skip

  6. test_matmul: skip if nv format
    Skip tests that require matrix dimensions to be multiples of 32.
    Skip conversion to the swizzled format

  7. add fp4 support for data_to_mx_scale

  8. test_matmul: Add mxfp4 test case
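Commits 4 and 7 both revolve around e8m0 block scales. A minimal sketch of the idea (my own numpy illustration of the OCP MX convention of one power-of-two scale per 32-element block; the helper name is hypothetical and this is not necessarily the PR's exact data_to_mx_scale logic):

```python
import numpy as np

def data_to_e8m0_scale(x, block_size=32):
    """Hypothetical helper: one e8m0-style (power-of-two) scale per
    block of `block_size` elements along the last axis."""
    blocks = x.reshape(*x.shape[:-1], -1, block_size)
    amax = np.abs(blocks).max(axis=-1)
    # e8m0 stores only an exponent: round each block amax down to a power of two.
    exp = np.floor(np.log2(np.where(amax == 0, 1.0, amax)))
    return np.exp2(exp)

x = np.array([[1.5, 3.0, 0.25, 7.0] * 8])  # one 32-element block
print(data_to_e8m0_scale(x))  # block amax 7.0 -> scale 4.0
```

Because e8m0 carries no mantissa, the scale is always an exact power of two, which is why the fp4 test case switches to an e8m0 scaling dtype on ROCm rather than a float scale.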


Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
@jagadish-amd jagadish-amd changed the title Enable mx f8 fp4 support on ROCm [In Progress] Enable mx f8 fp4 support on ROCm Apr 23, 2025
petrex
petrex previously requested changes Apr 24, 2025

@petrex petrex left a comment


Thanks @jagadish-amd!
Left a few comments; let's discuss offline as well.


@rocm-repo-management-api

rocm-repo-management-api bot commented Apr 25, 2025

Jenkins build for 1b8bf596817e0403f7038b90b5c656aa30b6df82 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@pruthvistony pruthvistony marked this pull request as draft April 28, 2025 05:37
@rocm-repo-management-api

rocm-repo-management-api bot commented Apr 28, 2025

Jenkins build for 1b8bf596817e0403f7038b90b5c656aa30b6df82 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api

rocm-repo-management-api bot commented Apr 29, 2025

Jenkins build for 1b8bf596817e0403f7038b90b5c656aa30b6df82 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
@rocm-repo-management-api

rocm-repo-management-api bot commented Apr 29, 2025

Jenkins build for f9038f6947de578ac45da35984f5a080c3301438 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api

rocm-repo-management-api bot commented Apr 30, 2025

Jenkins build for f9038f6947de578ac45da35984f5a080c3301438 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

unittest.skip was used as a decorator, not called from inside the function
to skip, so skipped tests were wrongly reported as passed.
Use unittest.SkipTest to correctly skip a test from inside its body.
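The fix described in this commit can be sketched as a self-contained toy test (the platform_supported flag is a hypothetical stand-in for the real capability check, not the PR's actual test code):

```python
import unittest

class BlockwiseTests(unittest.TestCase):
    def test_mx_gemm(self):
        platform_supported = False  # hypothetical stand-in for the real check
        if not platform_supported:
            # Raising SkipTest inside the body reports the test as SKIPPED.
            # A bare `unittest.skip("...")` call here merely builds a
            # decorator and discards it, so the test would continue and
            # wrongly report PASSED.
            raise unittest.SkipTest("mx format not supported on this platform")
        self.fail("unreachable when the platform lacks mx support")

suite = unittest.defaultTestLoader.loadTestsFromTestCase(BlockwiseTests)
result = unittest.TextTestRunner(verbosity=0).run(suite)
print(len(result.skipped), result.wasSuccessful())  # 1 True
```

Skips count toward a successful run, so the status line shows the skip explicitly instead of a misleading pass.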

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
@rocm-repo-management-api

rocm-repo-management-api bot commented Apr 30, 2025

Jenkins build for 8250784fca89bc3b107d02066f28b33e6c864410 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

Detected error during the PyTorch build:

[1/4] Generating ATen declarations_yaml
[2/4] Generating ATen headers
[3/4] Generating ATen sources
[1/7991] Creating directories for 'aotriton_external'
[2/7991] Performing download step (download, verify and extract) for 'aotriton_external'
FAILED: aotriton_external-prefix/src/aotriton_external-stamp/aotriton_external-download /var/lib/jenkins/pytorch/build/aotriton_external-prefix/src/aotriton_external-stamp/aotriton_external-download 
cd /var/lib/jenkins/pytorch/build && /opt/conda/envs/py_3.12/bin/cmake -DCMAKE_MESSAGE_LOG_LEVEL=VERBOSE -P /var/lib/jenkins/pytorch/build/aotriton_external-prefix/src/aotriton_external-stamp/download-aotriton_external.cmake && /opt/conda/envs/py_3.12/bin/cmake -DCMAKE_MESSAGE_LOG_LEVEL=VERBOSE -P /var/lib/jenkins/pytorch/build/aotriton_external-prefix/src/aotriton_external-stamp/verify-aotriton_external.cmake && /opt/conda/envs/py_3.12/bin/cmake -DCMAKE_MESSAGE_LOG_LEVEL=VERBOSE -P /var/lib/jenkins/pytorch/build/aotriton_external-prefix/src/aotriton_external-stamp/extract-aotriton_external.cmake && /opt/conda/envs/py_3.12/bin/cmake -E touch /var/lib/jenkins/pytorch/build/aotriton_external-prefix/src/aotriton_external-stamp/aotriton_external-download
-- Downloading...
   dst='/var/lib/jenkins/pytorch/build/aotriton_external-prefix/src/aotriton-0.9.2b_612896439f-manylinux_2_28_x86_64-rocm6.5-shared.tar.gz'
   timeout='none'
   inactivity timeout='none'

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
@rocm-repo-management-api

rocm-repo-management-api bot commented May 3, 2025

Jenkins build for b3ed59b1b38c5008fa295ebbbb6d9f6d3cb73bf4 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
@rocm-repo-management-api

rocm-repo-management-api bot commented May 5, 2025

Jenkins build for b0bf24a8d2c0bdc481d8be8f10fd2750163eabef commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api

rocm-repo-management-api bot commented May 6, 2025

Jenkins build for b0bf24a8d2c0bdc481d8be8f10fd2750163eabef commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api

rocm-repo-management-api bot commented May 8, 2025

Jenkins build for b0bf24a8d2c0bdc481d8be8f10fd2750163eabef commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api

rocm-repo-management-api bot commented May 14, 2025

Jenkins build for 89a10007566ee17896bb51c2a3100f8da20f48d5 commit finished as ABORTED
Links: Blue Ocean view / Build artifacts

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
@rocm-repo-management-api

rocm-repo-management-api bot commented May 15, 2025

Jenkins build for 89a10007566ee17896bb51c2a3100f8da20f48d5 commit finished as ABORTED
Links: Blue Ocean view / Build artifacts

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
@jagadish-amd jagadish-amd changed the title [In Progress] Enable mx f8 fp4 support on ROCm Enable mx f8 fp4 support on ROCm May 20, 2025
@jagadish-amd jagadish-amd marked this pull request as ready for review May 20, 2025 22:48

@pruthvistony pruthvistony left a comment


Can you please change the ROCm 6.5 version check to ROCm 7.0, since 6.5 release is dropped.


@jeffdaily jeffdaily left a comment


Need questions resolved before approving.

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
@jagadish-amd jagadish-amd changed the title Enable mx f8 fp4 support on ROCm Enable mx fp8 fp4 support on ROCm May 28, 2025
Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
@pruthvistony pruthvistony merged commit 4634421 into rocm6.5_internal_testing May 29, 2025
@pruthvistony pruthvistony deleted the rocm6.5_internal_testing_mx_f8_fp4 branch May 29, 2025 16:46
pruthvistony pushed a commit that referenced this pull request Jun 4, 2025
Ported mx fp8 part from #2046

Current test stats (accounting only blockwise scale tests)
PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k
test_blockwise -v

Ran 225 tests in 8.256s
FAILED (failures=1, skipped=150)
74 tests passed

**fp8 mx data type sample test case.**

test_blockwise_mxfp8_numerics_test_case_name_data_random_scales_one_fast_accum_True_512_128_256_cuda
(__main__.TestFP8MatmulCudaCUDA)
hipblaslt-bench --api_method c -m 256 -n 512 -k 128 --lda 128 --ldb 128
--ldc 256 --ldd 256 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0
--alpha 1 --beta 0 --transA T --transB N --batch_count 1 --scaleA 3
--scaleB 3 --a_type f8_r --b_type f8_r --c_type bf16_r --d_type bf16_r
--compute_type f32_r --algo_method index --solution_index -2146957310
--rotating 0 --cold_iters 0 --iters 0

---------

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
pragupta pushed a commit that referenced this pull request Jul 21, 2025
pragupta pushed a commit to pragupta/pytorch that referenced this pull request Jul 21, 2025
pragupta pushed a commit that referenced this pull request Jul 22, 2025
jithunnair-amd pushed a commit that referenced this pull request Jul 22, 2025
pragupta pushed a commit that referenced this pull request Jul 29, 2025