@@ -404,14 +404,24 @@ def _gather_hidden_states_and_residual(
404404 if context .attn_dp_size != 1 :
405405 if context .attn_tp_rank == 0 :
406406 hidden_states += residual
407+
408+ # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
409+ use_layer_norm_before_gather = context .attn_tp_size == 1
410+ if use_layer_norm_before_gather :
411+ residual .copy_ (hidden_states )
412+ if hidden_states .shape [0 ] != 0 :
413+ hidden_states = layernorm (hidden_states )
414+
407415 hidden_states , local_hidden_states = (
408416 forward_batch .gathered_buffer ,
409417 hidden_states ,
410418 )
411419 dp_gather_partial (hidden_states , local_hidden_states , forward_batch )
412- dp_scatter (residual , hidden_states , forward_batch )
413- if hidden_states .shape [0 ] != 0 :
414- hidden_states = layernorm (hidden_states )
420+
421+ if not use_layer_norm_before_gather :
422+ dp_scatter (residual , hidden_states , forward_batch )
423+ if hidden_states .shape [0 ] != 0 :
424+ hidden_states = layernorm (hidden_states )
415425 else :
416426 # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
417427 # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
0 commit comments