Enable mx fp8 fp4 support on ROCm #2046
pruthvistony merged 19 commits into rocm6.5_internal_testing
Conversation
Ported the patch from pytorch#147553. Commented out a few lines to avoid a compilation error (check the TODO comments). Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
Thanks @jagadish-amd!
I left a few comments; let's discuss offline as well.
Jenkins build for 1b8bf596817e0403f7038b90b5c656aa30b6df82 commit finished as FAILURE
Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
Jenkins build for f9038f6947de578ac45da35984f5a080c3301438 commit finished as FAILURE
unittest.skip is meant to be used as a decorator, not called inside a test function to skip it. Calling it inside the body raises nothing, so the test is wrongly reported as passed. Raise unittest.SkipTest instead to correctly skip the test from inside the function. Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
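The pitfall described in this commit can be sketched as follows: `unittest.skip("...")` called inside a test body merely returns a decorator object, so the test finishes without an exception and counts as passed, while raising `unittest.SkipTest` records a real skip. A minimal standalone sketch (not the PR's actual test code):

```python
import unittest

class Demo(unittest.TestCase):
    def test_wrong(self):
        # WRONG: unittest.skip() inside the body just returns a decorator
        # object; nothing is raised, so this test is reported as PASSED.
        unittest.skip("mx types unsupported")

    def test_right(self):
        # RIGHT: raising unittest.SkipTest marks the test as skipped.
        raise unittest.SkipTest("mx types unsupported")

result = unittest.TestResult()
suite = unittest.defaultTestLoader.loadTestsFromTestCase(Demo)
suite.run(result)
print(result.testsRun, len(result.skipped))  # 2 tests run, only 1 actually skipped
```

Inside a `TestCase` method, `self.skipTest("reason")` is the idiomatic equivalent of raising `unittest.SkipTest` directly.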
Jenkins build for 8250784fca89bc3b107d02066f28b33e6c864410 commit finished as FAILURE. Detected error during PyTorch building:
Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
Jenkins build for b3ed59b1b38c5008fa295ebbbb6d9f6d3cb73bf4 commit finished as FAILURE
Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
Jenkins build for b0bf24a8d2c0bdc481d8be8f10fd2750163eabef commit finished as FAILURE
Jenkins build for 89a10007566ee17896bb51c2a3100f8da20f48d5 commit finished as ABORTED
Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
pruthvistony
left a comment
Can you please change the ROCm 6.5 version check to ROCm 7.0, since the 6.5 release has been dropped.
jeffdaily
left a comment
Need questions resolved before approving.
Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
Ported the mx fp8 part from #2046.

Current test stats (counting only the blockwise scale tests):
PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v
Ran 225 tests in 8.256s
FAILED (failures=1, skipped=150)
_74 tests pass_

**fp8 mx data type sample test case:**
test_blockwise_mxfp8_numerics_test_case_name_data_random_scales_one_fast_accum_True_512_128_256_cuda (__main__.TestFP8MatmulCudaCUDA)
hipblaslt-bench --api_method c -m 256 -n 512 -k 128 --lda 128 --ldb 128 --ldc 256 --ldd 256 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1 --beta 0 --transA T --transB N --batch_count 1 --scaleA 3 --scaleB 3 --a_type f8_r --b_type f8_r --c_type bf16_r --d_type bf16_r --compute_type f32_r --algo_method index --solution_index -2146957310 --rotating 0 --cold_iters 0 --iters 0

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
(cherry picked from commit d17e222)
This PR enables mx data type support on ROCm.
Current test stats (counting only the blockwise scale tests):
PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v
Ran 452 tests in 17.470s
FAILED (failures=2, errors=2, skipped=337)
111 tests pass
fp8 mx data type sample test case.
PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py TestFP8MatmulCudaCUDA.test_blockwise_mxfp8_nvfp4_numerics_test_case_name_a_eye_b_eye_fast_accum_False_128_128_128_recipe_mxfp8_cuda -v
HipblasLT log hipblaslt-bench --api_method c -m 128 -n 128 -k 128 --lda 128 --ldb 128 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1 --beta 0 --transA T --transB N --batch_count 1 --scaleA 3 --scaleB 3 --a_type f8_r --b_type f8_r --c_type bf16_r --d_type bf16_r --compute_type f32_r --algo_method index --solution_index -2147220478 --rotating 0 --cold_iters 0 --iters 0
fp4 mx data type sample test case.
PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py TestFP8MatmulCudaCUDA.test_blockwise_mxfp8_nvfp4_numerics_test_case_name_a_eye_b_eye_fast_accum_False_128_128_128_recipe_nvfp4_cuda -v
HipblasLT log hipblaslt-bench --api_method c -m 128 -n 128 -k 128 --lda 128 --ldb 128 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1 --beta 0 --transA T --transB N --batch_count 1 --scaleA 3 --scaleB 3 --a_type f4_r --b_type f4_r --c_type bf16_r --d_type bf16_r --compute_type f32_r --algo_method index --solution_index -2147220478 --rotating 0 --cold_iters 0 --iters 0
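In the MX (microscaling) formats exercised by these tests, every 32 consecutive elements share one power-of-two scale stored in the e8m0 format, which holds only an exponent. A small illustrative sketch of computing such per-block scales for mx fp8 (the block size of 32 and the e4m3 max of 448 come from the OCP MX spec; the function itself is an assumption for illustration, not PyTorch's or hipBLASLt's implementation, and it assumes each block has a nonzero amax):

```python
import math

BLOCK = 32            # OCP MX spec: 32 elements share one scale
FP8_E4M3_MAX = 448.0  # largest normal float8_e4m3 value

def mx_block_scales(values):
    """Per-32-element power-of-two (e8m0-style) scales - illustrative only."""
    scales = []
    for i in range(0, len(values), BLOCK):
        amax = max(abs(v) for v in values[i:i + BLOCK])
        # e8m0 stores only an exponent, so round the ideal scale
        # amax / FP8_E4M3_MAX down to a power of two.
        scales.append(2.0 ** math.floor(math.log2(amax / FP8_E4M3_MAX)))
    return scales

data = [x / 32.0 for x in range(-32, 32)]  # 64 values -> 2 blocks
print(mx_block_scales(data))  # [0.001953125, 0.001953125], i.e. 2**-9 per block
```

Dividing each block by its scale before casting to fp8 keeps the block's largest magnitude inside the e4m3 representable range.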
Commits:
ROCm MX-FP8 Gemm (PR from @petrex )
Ported the patch from ROCm MX-FP8 Gemm pytorch/pytorch#147553
Commented out a few lines to avoid a compilation error (check the TODO comments)
Refine _platform_supports_mx_gemm check
For mx fp8, A and B need not be of the kFloat8_e8m0fnu type
Add fp4 support (PR from @petrex )
Ported the patch from AMD/ROCm OCP Micro-scaling Format (mx-fp8/mx-fp4) Support pytorch/pytorch#151360
Added fp4 type in aten/src/ATen/cuda/CUDADataType.h
Added more mappings in aten/src/ATen/cuda/CUDADataType.h
Use e8m0 scaling dtype for fp4 test case for ROCm in test/test_matmul_cuda.py
test_matmul: change code to correctly skip tests
test_matmul: skip if the NV format is requested
skip tests if matrix dimensions are not multiples of 32
skip conversion to the swizzled format
add fp4 support for data_to_mx_scale
test_matmul: Add mxfp4 test case
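For the fp4 (`f4_r`) path added in the commits above, values are stored in the e2m1 format, which can represent only eight magnitudes per sign: 0, 0.5, 1, 1.5, 2, 3, 4 and 6 (per the OCP MX spec). A sketch of nearest-value quantization into that set; this is illustrative only, and the kernels' actual rounding and saturation behavior may differ:

```python
import math

# The eight positive magnitudes representable in fp4 e2m1 (OCP MX spec)
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_e2m1(x):
    """Round x to the nearest e2m1 value, sign-magnitude style.
    Illustrative sketch; tie-breaking here just takes the first minimum."""
    mag = min(E2M1_VALUES, key=lambda v: abs(abs(x) - v))
    return math.copysign(mag, x) if mag else 0.0

print([quantize_e2m1(v) for v in (0.2, -0.9, 2.4, 7.0)])  # [0.0, -1.0, 2.0, 6.0]
```

With only eight magnitudes, the shared per-block scale carries most of the dynamic range, which is why the fp4 test case pairs `f4_r` inputs with e8m0 blockwise scales.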