
[2/2] support dp mla #5001

Closed
xu-yfei wants to merge 1 commit into sgl-project:main from xu-yfei:mla_dp

Conversation


@xu-yfei xu-yfei commented Apr 2, 2025

Motivation

Based on the dp_mla_kernel PR #5000.

Description:
On an 8*H20(96GB), weight mem usage=87.19 GB when --dp-size 4 --enable-dp-attention, not enough memory left.

This optimization is similar to data-parallel attention, but it applies only to the MLA core rather than the entire attention block. Unlike data-parallel attention, it does not increase the memory occupied by weights, so it significantly reduces per-rank KV-cache size and enables larger batch sizes at the cost of only 1-3 ms of additional decode latency.

On an 8×H20 (96GB) node with data-parallel MLA enabled, decoding throughput improves by up to 1.85x (dp=4) to 2.34x (dp=8) over the previous version, and KV-cache capacity increases by 3.3x (dp=4) and 6.6x (dp=8).

On 8*H20 (96GB), with input length 4000, output length 1500, and 128 concurrent requests, per-card decode output throughput reaches 266 tokens/s.

gsm8k:

#  dp=4
Accuracy: 0.958
Invalid: 0.000
Latency: 166.961 s
Output throughput: 765.964 token/s

#  dp=4, fa3
Accuracy: 0.958
Invalid: 0.000
Latency: 147.353 s
Output throughput: 863.553 token/s

#  dp=8
Accuracy: 0.955
Invalid: 0.000
Latency: 165.833 s
Output throughput: 781.248 token/s

#  tp=8
Accuracy: 0.954
Invalid: 0.000
Latency: 177.669 s
Output throughput: 719.521 token/s
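The gsm8k runs above translate into modest but consistent throughput gains over tp=8; a quick check (numbers taken from the results above):

```python
# Output throughput (token/s) from the gsm8k runs above.
tp8, dp4, dp4_fa3, dp8 = 719.521, 765.964, 863.553, 781.248

for name, tput in [("dp4", dp4), ("dp4+fa3", dp4_fa3), ("dp8", dp8)]:
    # Ratios come out to roughly 1.06x, 1.20x, and 1.09x vs tp=8.
    print(f"{name}: {tput / tp8:.2f}x vs tp8")
```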

Usage:

Restrictions:

  • The custom all-to-all operation currently implemented works only within a single node; cross-node operation is not yet supported.

On an 8×H20 (96GB) node, on top of --attention-backend flashinfer or --attention-backend fa3, remove --enable-dp-lm-head and add the following parameters:

# dp=4:
--max-running-requests 128 --mem-fraction-static 0.94 --chunked-prefill-size 8192 --enable-dp-mla --dp-size 4
# dp=8:
--max-running-requests 128 --mem-fraction-static 0.94 --chunked-prefill-size 8192 --enable-dp-mla --dp-size 8

We force the switch flashinfer_mla_disable_ragged to True when --enable-dp-mla is set, to avoid the following issues:

  • Requests return incorrect results after running sglang.bench_one_batch_server. Testing with the model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct, we found NaN values in the KV cache after the warmup process, which propagate into the results of subsequent requests. We traced the NaNs to the ragged path of flashinfer-mla. The reproduction parameters are:
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct/  \
--host 0.0.0.0 --port 8188 --trust-remote-code --tp-size 8 --dp-size 8 --log-level info \
--enable-flashinfer-mla --enable-dp-mla --mem-fraction-static 0.9 --disable-cuda-graph 
  • The AOT build of the flashinfer package reports an error.

Background:
For DeepSeek R1 on a single machine with 8*H20 (96GB), we aimed to deploy enable_dp_attention and found:

  1. enable_dp_attention duplicates the attention weights, which costs both memory and time.
  • A single eight-card machine cannot accommodate the extra memory usage (10+ GB).
  • The kv_u/q_u/o_proj weight computations are very time-consuming, slowing decode by 8 ms in a dp=4, tp=2 setup compared with tp=8 for short-sequence requests. With dp=8, the slowdown is expected to be worse.
  • The latest dp-attention update does not split M (the sequence dimension).
  • Keeping the linear-split strategy inside attention may be the better approach.
  2. The row count M in the wgmma computation is fixed at 64. For Flash MLA and FlashInfer MLA, when TP=8 the head count is 16, so only 1/4 of the compute capacity is utilized. With MLA DP=4, TP is 2 and the head count is 64 (128/2), which maintains good latency. With MLA DP=8, TP is 1 and the head count is 128; the MLA core may take 50% to 100% longer, but far more KV cache becomes available.
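The head-count argument above can be sketched with a back-of-the-envelope calculation, assuming the fixed wgmma row tile M=64 and 128 total attention heads (DeepSeek V3/R1):

```python
# Toy estimate of wgmma row-tile utilization in the MLA core, assuming a
# fixed M (row) tile of 64 and 128 attention heads in total.
WGMMA_M = 64
TOTAL_HEADS = 128

def heads_per_device(tp_size: int) -> int:
    """Heads handled by one device when the MLA core is sharded over tp_size."""
    return TOTAL_HEADS // tp_size

def wgmma_row_utilization(tp_size: int) -> float:
    """Fraction of the M=64 wgmma tile filled by this device's heads (capped at 1)."""
    return min(heads_per_device(tp_size), WGMMA_M) / WGMMA_M

print(wgmma_row_utilization(8))  # TP=8 -> 16 heads -> 0.25 (1/4 of the tile)
print(wgmma_row_utilization(2))  # dp=4 -> TP=2 -> 64 heads -> 1.0 (full tile)
print(wgmma_row_utilization(1))  # dp=8 -> TP=1 -> 128 heads -> 1.0 (latency grows instead)
```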

TODO:

  • Accuracy and performance testing and optimization for various scenarios.
  • Performance optimization of the custom all-to-all.
  • Fix the ragged-path issue in FlashInfer MLA described above.

Memory:

                                          default, mem=0.95   dp mla, dp-size=4   dp mla, dp-size=8
mem-fraction-static                       0.95                0.94                0.94
distributed ends: mem usage               1.81                1.95                1.95
load weight end: avail mem/mem usage      11.23/81.64         11.08/81.65         11.08/81.65
capture cuda graph end: avail mem/usage   3.65/0.53           3.89/1.20           3.89/1.19
max_total_num_tokens                      98696               82258*4             82258*8
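The max_total_num_tokens row implies the KV-cache scaling quoted in the description; a quick check with the table's numbers:

```python
# KV-cache capacity from the memory table above: with DP-MLA each dp rank
# holds its own 82258-token pool, vs a single 98696-token pool at tp=8.
baseline_tokens = 98696   # default, mem-fraction-static 0.95
per_rank_tokens = 82258   # dp mla, mem-fraction-static 0.94

for dp in (4, 8):
    ratio = per_rank_tokens * dp / baseline_tokens
    # Roughly 3.3x and 6.7x, matching the 3.3x / 6.6x quoted above.
    print(f"dp={dp}: {ratio:.1f}x KV cache")
```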

Performance Comparison:

python -m sglang.bench_one_batch_server --model None --base-url http://localhost:8188 --batch-size {bs} --input-len {input_seq} --output-len 1024
input_len  batch   tp8-latency  tp8-throughput  dp4-latency  dp4-throughput  dp8-latency  dp8-throughput
256 1 14.16 72.31 16.63 61.59 16.84 60.80
256 4 19.23 212.98 20.53 199.55 20.82 196.75
256 8 22.81 359.21 24.72 331.33 23.74 345.02
256 16 27.13 603.92 29.47 555.90 28.76 569.68
256 32 34.04 962.55 35.94 911.62 35.19 931.11
256 48 40.10 1225.70 40.21 1222.45 39.65 1239.79
256 64 67.45 971.64 43.91 1492.39 44.49 1473.02
256 96 79.39 1238.24 52.35 1877.66 52.69 1865.58
256 128 113.39 1155.93 57.36 2285.00 57.60 2275.67
1024 1 14.48 70.73 16.84 60.82 17.17 59.64
1024 4 18.85 217.25 20.74 197.46 21.89 187.12
1024 8 23.45 349.36 25.67 319.18 25.52 320.97
1024 16 28.34 578.11 29.66 552.47 30.88 530.57
1024 32 37.57 872.30 38.25 856.76 38.45 852.20
1024 48 46.68 1053.04 43.97 1117.75 44.24 1110.92
1024 64 76.28 859.21 49.02 1336.82 49.90 1313.24
1024 96 93.56 1050.74 59.77 1644.84 60.78 1617.42
1024 128 131.59 996.06 67.76 1934.37 67.93 1929.56
4096 1 15.29 66.97 18.09 56.61 19.16 53.45
4096 4 21.03 194.81 22.96 178.41 24.27 168.76
4096 8 28.45 287.95 29.70 275.87 28.72 285.19
4096 16 39.68 412.88 36.95 443.44 37.39 438.23
4096 32 82.49 397.26 52.17 628.10 53.22 615.65
4096 48 124.72 394.11 66.07 743.90 67.61 727.04
4096 64 167.71 390.77 79.40 825.36 79.90 820.26
4096 96 233.19 421.55 144.59 679.88 106.63 921.92
4096 128 317.57 412.73 170.99 766.54 132.86 986.54
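Reading off the last row of the table (input 4096, batch 128) gives decode speedups close to the 1.85x-2.34x headline quoted in the description:

```python
# Output throughput (last benchmark row above: input_len=4096, batch=128).
tp8_tput, dp4_tput, dp8_tput = 412.73, 766.54, 986.54

print(f"dp4 vs tp8: {dp4_tput / tp8_tput:.2f}x")  # ~1.86x
print(f"dp8 vs tp8: {dp8_tput / tp8_tput:.2f}x")  # ~2.39x
```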

Modifications

Use dp for the MLA core computations and tp for everything else, keeping the memory and compute time of the linear layers in attention unchanged.

  1. Convert model input from dp to tp using dp_gather to obtain all input_ids.
  2. Use all_to_all operators before and after each mla layer for tp and dp conversion.
  3. Implement a custom all_to_all for dynamic input/output splitting (split_sizes) based on sglang allreduce IPC, currently limited to single-machine use.
  4. Process model output using dp_scatter to obtain results for the specified dp.
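Steps 2-3 can be illustrated with a single-process toy (hypothetical helper names; the real implementation is a custom IPC kernel): the all_to_all with per-rank split sizes converts between the tp layout (all tokens, a shard of heads) and the dp layout (a slice of tokens, all heads):

```python
# Toy single-process simulation of the tp<->dp all_to_all around the MLA core.
# Names and shapes are illustrative only.

def all_to_all(per_rank_chunks):
    """per_rank_chunks[src][dst] -> received[dst][src] (a transpose of chunks)."""
    world = len(per_rank_chunks)
    return [[per_rank_chunks[src][dst] for src in range(world)] for dst in range(world)]

# 2 ranks, 4 tokens, 4 heads. tp layout: every rank has all tokens, half the heads.
# Each element is a (token_id, head_id) pair.
tp_data = [
    [(t, h) for t in range(4) for h in rank_heads]
    for rank_heads in ([0, 1], [2, 3])
]

# Before the MLA core: each rank splits its tokens by destination dp rank
# (tokens 0-1 go to rank 0, tokens 2-3 to rank 1), keeping its head shard.
send = [
    [[(t, h) for (t, h) in chunk if t // 2 == dst] for dst in range(2)]
    for chunk in tp_data
]
recv = all_to_all(send)

# After the exchange, dp rank 0 owns tokens 0-1 with ALL heads (dp layout).
dp_rank0 = sorted(sum(recv[0], []))
print(dp_rank0)  # tokens 0 and 1, each with heads 0..3
```

The reverse all_to_all after the MLA core restores the tp layout for the subsequent linear layers.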

Checklist

@xu-yfei xu-yfei changed the title [DRFT] support dp mla [DRFT] support dp mla[2/2] Apr 2, 2025
@xu-yfei xu-yfei changed the title [DRFT] support dp mla[2/2] [DRAFT] support dp mla[2/2] Apr 2, 2025
@xu-yfei xu-yfei changed the title [DRAFT] support dp mla[2/2] [WIP] support dp mla[2/2] Apr 2, 2025
@xu-yfei xu-yfei changed the title [WIP] support dp mla[2/2] [WIP] [2/2] support dp mla Apr 2, 2025

xu-yfei commented Apr 2, 2025

MLA kernel performance results from early March. The latest numbers may have changed, but the pattern should be similar.

                flashinfer mla                          deepseek flash mla
kv_len  batch   h16     h32     h64     h128            h16     h32     h64     h128
128     1       15.7    16.8    20.6    20.8    19.1    20.0    21.1    21.3
128     2       16.2    17.0    20.6    21.1    19.1    20.1    21.2    21.2
128     4       16.4    17.2    21.2    24.3    19.6    20.1    21.2    21.8
128     8       18.7    20.0    24.3    22.4    19.7    20.4    21.9    24.3
128     16      21.3    21.5    22.8    23.7    23.1    23.5    24.6    26.0
128     32      22.6    22.6    24.4    25.0    23.9    24.8    26.4    28.8
128     64      23.2    24.0    26.2    39.0    25.5    27.0    29.7    48.4
128     128     36.4    37.1    40.2    65.6    41.2    44.1    50.2    87.1
128     256     61.0    62.8    67.8    104.8   73.2    78.9    90.3    146.9
128     512     98.0    99.4    107.1   194.9   120.7   130.5   150.8   266.7
1024    1       20.3    20.5    24.7    28.1    20.1    20.7    22.0    22.3
1024    2       23.5    25.4    28.1    34.8    20.9    21.7    22.9    27.7
1024    4       29.8    31.5    35.1    39.4    26.1    26.7    28.1    33.9
1024    8       34.9    36.1    39.8    40.6    32.1    32.9    34.7    46.1
1024    16      36.0    36.8    41.1    60.7    44.1    44.7    47.3    69.7
1024    32      55.1    57.3    61.8    93.9    64.7    67.1    70.7    104.0
1024    64      92.3    93.2    94.7    176.1   101.7   102.9   105.6   195.6
1024    128     172.9   174.1   177.4   339.5   185.6   189.2   197.2   340.1
1024    256     334.9   336.0   340.5   584.7   324.5   330.4   342.8   649.7
1024    512     577.9   579.4   587.2   1155.2  626.8   636.4   656.2   1252.5
4096    1       31.4    31.6    37.7    40.0    25.3    26.0    27.5    30.1
4096    2       36.3    36.3    41.0    41.4    29.0    29.6    30.9    39.9
4096    4       37.5    37.5    42.1    59.2    38.3    39.1    40.8    61.0
4096    8       55.1    57.4    59.8    99.1    59.2    60.4    61.9    100.5
4096    16      94.3    95.7    100.0   205.8   98.1    98.9    102.2   179.2
4096    32      196.5   199.3   205.9   391.3   173.6   176.3   180.5   334.0
4096    64      375.7   381.2   393.4   646.1   325.6   328.7   334.9   620.2
4096    128     643.7   646.2   646.8   1278.1  609.6   612.1   620.3   1188.1
4096    256     1272.5  1274.0  1282.3  2227.1  1169.4  1176.2  1193.1  2338.6
4096    512     2219.5  2221.5  2230.0  4446.7  2311.5  2321.5  2342.1  4628.5
8192    1       39.1    39.4    40.5    41.0    31.1    31.7    33.1    40.7
8192    2       40.7    40.4    41.6    60.7    39.4    40.2    41.5    57.0
8192    4       56.8    57.0    61.7    98.6    55.5    56.1    58.0    94.1
8192    8       94.2    96.3    99.4    180.6   93.2    93.9    95.6    170.7
8192    16      233.6   235.4   241.8   381.4   167.7   168.5   172.0   318.7
8192    32      372.1   375.4   384.1   765.5   313.0   315.7   319.3   617.8
8192    64      748.5   754.3   768.4   1271.6  610.3   613.3   620.7   1185.3
8192    128     1268.8  1271.4  1275.6  2529.5  1174.7  1179.6  1194.6  2313.5
8192    256     2528.4  2530.0  2535.4  4420.8  2299.4  2306.4  2312.7  4595.1
8192    512     4417.4  4422.1  4421.8  8828.5  4576.8  4586.7  4597.3  9146.2
16384   1       47.0    46.8    47.9    60.1    41.5    42.3    43.7    57.6
16384   2       60.2    60.5    61.4    99.6    56.4    57.1    58.7    94.5
16384   4       95.9    95.9    100.5   159.8   93.8    94.4    96.2    164.9
16384   8       176.2   177.8   182.2   338.2   162.9   163.8   165.5   311.3
16384   16      448.5   451.5   457.8   754.1   308.0   308.6   312.6   603.7
16384   32      744.0   748.5   758.2   1512.7  597.6   600.2   604.7   1182.5
16384   64      1490.4  1496.4  1513.0  2530.6  1171.0  1174.1  1185.2  2308.3
16384   128     2519.7  2521.0  2528.0  5039.7  2295.7  2301.5  2319.0  4564.7
16384   256     5027.5  5039.4  5047.6  8806.0  4543.7  4560.4  4561.0  9098.3
16384   512     8791.1  8793.9  8806.0  17598.9 9063.5  9074.8  9100.5  18139.4
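The h-column scaling in the table above supports the "50%-100% MLA-core slowdown at dp=8" estimate from the background section; e.g. the kv_len=8192, batch=64 row of the flashinfer column (h=16 corresponds to tp=8, h=64 to dp=4/tp=2, h=128 to dp=8/tp=1):

```python
# Latencies from the kernel table above (kv_len=8192, batch=64, flashinfer mla).
lat_h16, lat_h64, lat_h128 = 748.5, 768.4, 1271.6

print(f"h128 vs h16: {lat_h128 / lat_h16:.2f}x")  # ~1.70x, within the 1.5-2x estimate
print(f"h64  vs h16: {lat_h64 / lat_h16:.2f}x")   # ~1.03x, nearly free for dp=4
```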


ispobock commented Apr 3, 2025

Could you add DP attention in the benchmarks?

@davidsyoung

Attempted to load this with AWQ, got this error:

[2025-04-05 10:10:33 DP13 TP13] Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1999, in run_scheduler_process
    scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 249, in __init__
    self.tp_worker = TpWorkerClass(
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 63, in __init__
    self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 74, in __init__
    self.model_runner = ModelRunner(
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 171, in __init__
    self.initialize(min_per_gpu_memory)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 181, in initialize
    self.load_model()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 385, in load_model
    self.model = get_model(
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/__init__.py", line 22, in get_model
    return loader.load_model(
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 370, in load_model
    model.load_weights(self._get_all_weights(model_config, model))
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 1632, in load_weights
    weigth_data = self_attn.kv_b_proj.weight.data
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1931, in __getattr__
    raise AttributeError(
AttributeError: 'ColumnParallelLinear' object has no attribute 'weight'. Did you mean: 'qweight'?

command:

python3 -m sglang.launch_server --model /models/wanzhenchn_DeepSeek-R1-AWQ/ --tp 16 --context-length 16384 --trust-remote-code  --mem-fraction-static 0.95 --chunked-prefill-size 2048 --max-running-requests 1  --max-total-tokens 2048 --enable-flashinfer-mla --enable-dp-mla --dp-size=16 --cuda-graph-max-bs 1 --flashinfer-mla-disable-ragged --disable-cuda-graph

16x3090. I'm possibly doing something wrong with the commands, or maybe it's just not supported for AWQ yet. Thanks for any help.


xu-yfei commented Apr 7, 2025

> Attempted to load this with AWQ, got this error: AttributeError: 'ColumnParallelLinear' object has no attribute 'weight'. Did you mean: 'qweight'? (full traceback and launch command quoted above)

You can remove this code in deepseek_v2.py. However, the custom all-to-all we currently implement supports at most 8 devices; the relevant kernel would need to be changed to 16, and other code that depends on the device count may also need adapting.

                if (
                    self_attn.enable_dp_mla
                ):  # Use kv_b_proj's space to avoid double memory
                    # 512, 128/tp*(128+128)
                    weigth_data = self_attn.kv_b_proj.weight.data
                    kv_stack = [
                        w_kc.to(torch.bfloat16),
                        w_vc.view(*w_kc.shape).to(torch.bfloat16),
                    ]
                    new_weight = torch.cat(kv_stack).to(w_kc.dtype).contiguous()
                    weigth_data.copy_(new_weight.view(weigth_data.shape))
                    weight_data = weigth_data.view(2, *w_kc.shape)
                    w_kc = weight_data[0].view(*w_kc.shape)
                    w_vc = weight_data[1].view(*w_vc.shape)
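The snippet above packs w_kc and w_vc into kv_b_proj's existing buffer and views the halves back out, so the absorbed weights add no extra memory. A numpy stand-in for the torch buffers (shapes are illustrative, not the real model's):

```python
import numpy as np

# Toy version of the trick above: reuse an existing buffer for both halves,
# then re-derive zero-copy views of w_kc and w_vc from it.
kc_shape = (4, 8)
w_kc = np.arange(32, dtype=np.float32).reshape(kc_shape)
w_vc = (np.arange(32, dtype=np.float32) + 100).reshape(kc_shape)

# Pre-existing "kv_b_proj" buffer, big enough for both halves.
kv_b_proj = np.empty((2,) + kc_shape, dtype=np.float32)
kv_b_proj[0] = w_kc
kv_b_proj[1] = w_vc

# The views alias kv_b_proj's storage: no copy, no extra weight memory.
w_kc_view, w_vc_view = kv_b_proj[0], kv_b_proj[1]
assert np.shares_memory(w_kc_view, kv_b_proj)
print(bool((w_kc_view == w_kc).all() and (w_vc_view == w_vc).all()))  # True
```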

@xu-yfei xu-yfei changed the title [WIP] [2/2] support dp mla [2/2] support dp mla Apr 29, 2025

xu-yfei commented Apr 29, 2025

Could you add DP attention in the benchmarks?

On 8*H20 (96GB), weight memory usage is 87.19 GB with --dp-size 4 --enable-dp-attention, so there is not enough memory left.

@xu-yfei xu-yfei force-pushed the mla_dp branch 2 times, most recently from dfa4a39 to d6544f5 Compare April 30, 2025 06:50
@xu-yfei xu-yfei requested a review from ch-wan as a code owner April 30, 2025 06:50
@xu-yfei xu-yfei force-pushed the mla_dp branch 4 times, most recently from d0d455b to 1d569b5 Compare May 12, 2025 06:21

jokerwyt commented Jun 6, 2025

@xu-yfei
Hi, I cannot compile your sgl-kernel even after merging #5000 into the mla_dp branch. Could you please help me fix that?

I use the image lmsysorg/sglang:dev. I suspect the problem is related to recent changes in the CUTLASS version; the sgl-kernel on the main branch compiles smoothly.


xu-yfei commented Jun 9, 2025

Could you provide the specific error in detail?

> Hi, I cannot compile your sgl-kernel even I merged #5000 into the branch mla_dp. (quoted above)


jokerwyt commented Jun 9, 2025

@xu-yfei

For example:

-- Generating done (0.0s)
-- Build files have been written to: /tmp/sglang/sgl-kernel/build
*** Building project with Ninja...
[1/210] Building CUDA object CMakeFiles/flash_ops.dir/_deps/repo-flash-attention-src/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu.o
FAILED: CMakeFiles/flash_ops.dir/_deps/repo-flash-attention-src/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu.o 
/usr/local/cuda/bin/nvcc -forward-unknown-to-host-compiler -DFLASHATTENTION_DISABLE_BACKWARD -DFLASHATTENTION_DISABLE_DROPOUT -DFLASHATTENTION_DISABLE_UNEVEN_K -DFLASHATTENTION_VARLEN_ONLY -DPy_LIMITED_API=0x03090000 -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_DISTRIBUTED -DUSE_RPC -DUSE_TENSORPIPE -Dflash_ops_EXPORTS -I/tmp/sglang/sgl-kernel/include -I/tmp/sglang/sgl-kernel/csrc -I/tmp/sglang/sgl-kernel/build/_deps/repo-cutlass-src/include -I/tmp/sglang/sgl-kernel/build/_deps/repo-cutlass-src/tools/util/include -I/tmp/sglang/sgl-kernel/build/_deps/repo-flashinfer-src/include -I/tmp/sglang/sgl-kernel/build/_deps/repo-flashinfer-src/csrc -I/tmp/sglang/sgl-kernel/build/_deps/repo-flash-attention-src/hopper -isystem /root/miniconda3/include/python3.12 -isystem /usr/local/lib/python3.10/dist-packages/torch/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda/targets/x86_64-linux/include -DONNX_NAMESPACE=onnx_c2 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -O3 -DNDEBUG -std=c++17 -Xcompiler=-fPIC -DNDEBUG -DOPERATOR_NAMESPACE=sgl-kernel -O3 -Xcompiler -fPIC -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_90a,code=sm_90a -std=c++17 -DCUTE_USE_PACKED_TUPLE=1 -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 -DCUTLASS_VERSIONS_GENERATED -DCUTLASS_TEST_LEVEL=0 -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1 -DCUTLASS_DEBUG_TRACE_LEVEL=0 --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math -Xcompiler=-Wconversion -Xcompiler=-fno-strict-aliasing -D_GLIBCXX_USE_CXX11_ABI=0 -MD -MT CMakeFiles/flash_ops.dir/_deps/repo-flash-attention-src/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu.o -MF 
CMakeFiles/flash_ops.dir/_deps/repo-flash-attention-src/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu.o.d -x cu -c /tmp/sglang/sgl-kernel/build/_deps/repo-flash-attention-src/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu -o CMakeFiles/flash_ops.dir/_deps/repo-flash-attention-src/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu.o
/tmp/sglang/sgl-kernel/build/_deps/repo-flash-attention-src/hopper/epilogue_fwd.hpp(95): error: no instance of function template "cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator" matches the argument list
          decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element, TileShape_MNK_PV>()),
                   ^
/tmp/sglang/sgl-kernel/build/_deps/repo-cutlass-src/include/cutlass/epilogue/collective/builders/sm90_common.inl(42): note #3323-D: substituting explicit template arguments "<cute::tuple<int64_t, cute::_1, int64_t, int64_t, int64_t>, cutlass::half_t, cute::tuple<cute::C<128>, cute::C<256>, cute::C<112>>>" for function template "cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator" failed
  sm90_get_smem_store_op_for_accumulator() {


xu-yfei commented Jun 9, 2025

> For example: the CUTLASS build failure in sm90_get_smem_store_op_for_accumulator (quoted in full above)

This may depend on an older version of repo-flash-attention rather than the latest one, e.g. GIT_TAG bf1e4ce51fc85370b9489976e8ccbe72eb71fefe.


xu-yfei commented Jun 11, 2025

> For example: the CUTLASS build failure in sm90_get_smem_store_op_for_accumulator (quoted in full above)

This PR has been rebased onto the latest main.

@xu-yfei xu-yfei closed this Jun 26, 2025
