
Fix Hiera global attention to use 4D tensors for efficient SDPA dispatch#2680

Merged
rwightman merged 1 commit into huggingface:main from Raiden129:fix/hiera-flash-attention-global-4d
Mar 9, 2026

Conversation

@Raiden129

Problem

Hiera's MaskUnitAttention.forward() takes a shortcut for global attention by setting num_windows=1 and funneling everything through the same 5D reshape as windowed attention. This produces Q/K/V tensors shaped [B, heads, 1, N, head_dim] (5D, non-contiguous after permute).

PyTorch's F.scaled_dot_product_attention silently rejects these for all efficient backends (FlashAttention, Memory-Efficient, CuDNN) and falls back to the math backend, which materializes the full N x N attention matrix. At high resolutions (e.g. 2048x2048, 16384 tokens), this allocates several GB of VRAM per layer and OOMs on consumer GPUs.
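To make the scale concrete, a back-of-envelope check in plain Python (assuming fp32 and a single head; real models multiply this by the number of heads and attention layers, so "several GB per layer" follows):

```python
# Memory cost of materializing the full N x N attention matrix that the
# math backend allocates. 16384 tokens is the PR's 2048x2048 example,
# i.e. a 128 x 128 token grid.
tokens = 128 * 128              # -> 16384 tokens
bytes_per_elem = 4              # fp32; fp16 would halve this
attn_matrix_bytes = tokens * tokens * bytes_per_elem
gib = attn_matrix_bytes / 2**30
print(f"{tokens} tokens -> {gib:.1f} GiB per head per layer")
```

At fp32 this is exactly 1 GiB per head for a single layer's attention matrix, before counting gradients or other activations.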

Fix

Branch the forward pass so that:

  • Windowed attention (use_mask_unit_attn=True): Unchanged 5D path.
  • Global attention (use_mask_unit_attn=False): Reshapes directly to 4D [B, heads, N, head_dim], calls .contiguous() on Q/K/V, and adjusts downstream indexing (amax dim, output transpose) accordingly.

This is a minimal, non-breaking change. The mathematical output is identical to the original implementation, but SDPA can now dispatch to fused kernels with O(N) memory instead of falling back to the math backend's O(N^2) memory.

Changes

  • Branch forward() into windowed (5D) and global (4D) paths
  • Reshape global QKV directly to [B, N, 3, heads, head_dim] and permute to 4D
  • Adjust q_stride pooling from amax(dim=3) to amax(dim=2) on the global path
  • Enforce .contiguous() on Q, K, V for the global path
  • Adapt output transpose: transpose(1, 3) for windowed, transpose(1, 2) for global
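The changes above can be sketched at the shape level with NumPy standing in for the torch tensor ops (a hedged sketch following the PR description, not the timm source verbatim; `np.transpose` plays the role of `permute`, `np.ascontiguousarray` the role of `.contiguous()`, and the toy sizes B=2, N=64, heads=4, head_dim=8, q_stride=4 are arbitrary):

```python
import numpy as np

B, N, heads, head_dim = 2, 64, 4, 8
dim = heads * head_dim
qkv = np.zeros((B, N, 3 * dim), dtype=np.float32)  # stand-in for self.qkv(x)

# Windowed path (unchanged): 5D with a window axis; num_windows=1 was the
# old global-attention shortcut.
num_windows = 1
qkv5 = (qkv.reshape(B, -1, num_windows, 3, heads, head_dim)
           .transpose(3, 0, 4, 2, 1, 5))
q5, k5, v5 = qkv5[0], qkv5[1], qkv5[2]
# -> [B, heads, num_windows, N, head_dim]: 5D, rejected by fused SDPA kernels

# Global path (new): reshape straight to 4D [B, heads, N, head_dim]
qkv4 = qkv.reshape(B, N, 3, heads, head_dim).transpose(2, 0, 3, 1, 4)
q4 = np.ascontiguousarray(qkv4[0])  # .contiguous() analogue, for q/k/v alike

# q_stride pooling analogue: window axis shifts the token axis by one,
# so amax is over dim 3 in the 5D path but dim 2 in the 4D path
q_stride = 4
p5 = q5.reshape(B, heads, num_windows, q_stride, -1, head_dim).max(axis=3)
p4 = q4.reshape(B, heads, q_stride, -1, head_dim).max(axis=2)
```

The same off-by-one shift explains the output transposes: with the extra window axis, heads sit at dim 1 and tokens at dim 3 (hence `transpose(1, 3)`), while in the 4D layout tokens are at dim 2 (hence `transpose(1, 2)`).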

…ntion dispatch

The global attention path in MaskUnitAttention.forward() used a 5D tensor
reshape with num_windows=1 as a shortcut. This caused PyTorch SDPA to
silently fall back from efficient backends (FlashAttention, Memory-Efficient,
CuDNN) to the O(N^2) math backend, as all efficient kernels require 4D
contiguous tensors.

At high resolutions (e.g. 2048x2048 -> 16384 tokens), the math backend
materializes the full N*N attention matrix, causing catastrophic VRAM usage
and OOM on consumer GPUs.

Changes:
- Branch forward() into windowed (5D, unchanged) and global (4D) paths
- Global path reshapes directly to [B, N, 3, heads, head_dim] -> 4D QKV
- Adjust q_stride pooling dim from amax(dim=3) to amax(dim=2) for global
- Add .contiguous() on q, k, v to guarantee FlashAttention compatibility
- Split output transpose: transpose(1,3) for windowed, transpose(1,2) for global
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@rwightman rwightman merged commit 0c90043 into huggingface:main Mar 9, 2026
22 checks passed