Skip to content

Add a way to disable compile for debugging flex-attention#158534

Closed
drisspg wants to merge 3 commits intogh/drisspg/171/basefrom
gh/drisspg/171/head
Closed

Add a way to disable compile for debugging flex-attention#158534
drisspg wants to merge 3 commits intogh/drisspg/171/basefrom
gh/drisspg/171/head

Conversation

@drisspg
Copy link
Contributor

@drisspg drisspg commented Jul 17, 2025

Stack from ghstack (oldest at bottom):

Finally got around to doing this, this flag lets us do:

#!/usr/bin/env python3
"""
FlexAttention Debug: Using breakpoints and unwrap
"""

import torch
import torch.nn.attention.flex_attention as fa

unwrap = torch._C._functorch.get_unwrapped

def score_mod(score, batch, head, q_idx, kv_idx):
    # Set breakpoint here to debug
    breakpoint()
    
    # In debugger, unwrap to see actual tensor values:
    # >>> actual_score = unwrap(unwrap(unwrap(unwrap(score))))
    # >>> actual_batch = unwrap(batch)
    # >>> actual_head = unwrap(head)
    # >>> actual_q_idx = unwrap(q_idx)
    # >>> actual_kv_idx = unwrap(kv_idx)
    # >>> print(actual_score)
    # >>> print(f"q_idx: {actual_q_idx}, kv_idx: {actual_kv_idx}")
    
    return torch.where(q_idx >= kv_idx, score, torch.tensor(float('-inf')))

def main():
    # Enable debug mode
    fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True
    
    # Small example
    B, H, S, D = 1, 2, 4, 8
    q = torch.randn(B, H, S, D)
    k = torch.randn(B, H, S, D)
    v = torch.randn(B, H, S, D)
    
    # Run - will hit breakpoint
    output = fa.flex_attention(q, k, v, score_mod=score_mod)
    
    # Disable debug mode
    fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = False

if __name__ == "__main__":
    main()

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Chillee @yanboliang @BoyuanFeng

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Jul 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158534

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 6b42a9e with merge base 1839e8d (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

drisspg added a commit that referenced this pull request Jul 17, 2025
@drisspg drisspg added the topic: not user facing topic category label Jul 17, 2025
@drisspg drisspg requested review from BoyuanFeng, Chillee and zou3519 and removed request for albanD, jbschlosser and mikaylagawarecki July 17, 2025 04:08
Copy link
Collaborator

@Chillee Chillee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be nice if we could have an error message that suggests that people try this.

[ghstack-poisoned]
drisspg added a commit that referenced this pull request Jul 18, 2025
[ghstack-poisoned]
@drisspg
Copy link
Contributor Author

drisspg commented Jul 18, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 18, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions github-actions bot deleted the gh/drisspg/171/head branch August 18, 2025 02:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants