feat: add sageattention #2823
Conversation
📝 Walkthrough

Support for SageAttention, a new attention implementation, has been integrated. This includes configuration schema updates, a monkeypatch for Hugging Face transformers, conditional patch application logic, and internal model loader changes to select SageAttention. Validation prevents incompatible use with sample packing and enforces GPU compute capability requirements. No public APIs were changed; all modifications are internal or configuration-related.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Config
    participant ModelLoader
    participant PatchManager
    participant Transformers
    participant SageAttention
    User->>Config: Set sage_attention=True
    Config->>Config: Validate config (disallow sample_packing + sage_attention, check GPU capability)
    ModelLoader->>Config: Read sage_attention flag
    ModelLoader->>PatchManager: Apply SageAttention patch if enabled
    PatchManager->>Transformers: Register sage_attention_forward
    ModelLoader->>Transformers: Set attn_implementation="sage_attention"
    Transformers->>SageAttention: Use SageAttention for attention calls
```
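The GPU compute capability validation mentioned in the walkthrough can be sketched as below. This is illustrative only: the minimum `(8, 0)` threshold is an assumption (SageAttention's kernels target Ampere-class GPUs and newer), and the actual check in axolotl may use different names and limits.

```python
# Assumed minimum (major, minor) compute capability for SageAttention kernels.
MIN_CAPABILITY = (8, 0)

def supports_sage_attention(capability: tuple[int, int]) -> bool:
    """Return True if the GPU's (major, minor) capability meets the minimum.

    Tuple comparison handles the (major, minor) ordering naturally.
    """
    return capability >= MIN_CAPABILITY

# In real code the capability would come from torch.cuda.get_device_capability().
```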
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/axolotl/monkeypatch/attention/sageattn.py (1)
41-112: Well-implemented attention forward function with minor style improvements. The function correctly handles SageAttention's limitations, GQA/MQA support, and tensor layout transformations. The extensive validation ensures clear error messages for unsupported features.
Apply these minor style improvements suggested by static analysis:
```diff
 if (
     kwargs.get("output_attentions", False)
-    or kwargs.get("head_mask", None) is not None
+    or kwargs.get("head_mask") is not None
 ):
```

```diff
-    if kwargs.get("position_ids", None) is not None:
+    if kwargs.get("position_ids") is not None:
```
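The Ruff SIM910 suggestion applied in the diff above rests on the fact that `dict.get` already defaults to `None`, so the explicit second argument is redundant. A quick self-contained illustration:

```python
def demo(**kwargs):
    # Both spellings return None when the key is absent; the explicit
    # ", None" default adds nothing.
    explicit = kwargs.get("head_mask", None)
    implicit = kwargs.get("head_mask")
    return explicit, implicit
```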
📒 Files selected for processing (4)
- src/axolotl/loaders/model.py (1 hunks)
- src/axolotl/loaders/patch_manager.py (1 hunks)
- src/axolotl/monkeypatch/attention/sageattn.py (1 hunks)
- src/axolotl/utils/schemas/config.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.11.9)
src/axolotl/monkeypatch/attention/sageattn.py
60-60: Use kwargs.get("head_mask") instead of kwargs.get("head_mask", None)
Replace kwargs.get("head_mask", None) with kwargs.get("head_mask")
(SIM910)
77-77: Use kwargs.get("position_ids") instead of kwargs.get("position_ids", None)
Replace kwargs.get("position_ids", None) with kwargs.get("position_ids")
(SIM910)
🔇 Additional comments (5)
src/axolotl/loaders/model.py (1)
550-554: LGTM! Consistent attention implementation pattern. The SageAttention integration follows the same pattern as other attention implementations and correctly sets both the model kwargs and config attributes.
src/axolotl/utils/schemas/config.py (2)
497-502: LGTM! Well-documented configuration field. The SageAttention configuration field follows the established pattern and includes a helpful description with a link to the source repository.
886-894: Good validation logic for incompatible features. The validator correctly prevents using SageAttention with sample packing, which aligns with the current limitations documented in the monkeypatch implementation.
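The mutual-exclusion check the reviewer praises can be sketched as follows. This is a plain-function sketch, not the actual validator (which lives in the Pydantic config schema); the field names `sage_attention` and `sample_packing` come from the PR, but the function name and error message are illustrative.

```python
def validate_attention_config(cfg: dict) -> dict:
    # SageAttention cannot consume the packed attention masks that
    # sample packing produces, so the two options are mutually exclusive.
    if cfg.get("sage_attention") and cfg.get("sample_packing"):
        raise ValueError("sage_attention is not compatible with sample_packing")
    return cfg
```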
src/axolotl/monkeypatch/attention/sageattn.py (2)
17-39: LGTM! Robust availability checking and import logic. The implementation properly handles the optional dependency with clear error messages and appropriate fallback behavior.
115-122: LGTM! Clean patch registration. The patch function properly checks for SageAttention availability and registers the forward function with transformers.
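The registration pattern referred to above can be sketched like this. In transformers, `ALL_ATTENTION_FUNCTIONS` behaves like a mapping from an `attn_implementation` string to a forward callable; a plain dict stands in for it here so the sketch is self-contained, and the function bodies are placeholders.

```python
# Stand-in for transformers' global attention registry (a mapping from
# implementation name to forward callable).
ALL_ATTENTION_FUNCTIONS = {}

def sage_attention_forward(module, query, key, value, attention_mask, **kwargs):
    raise NotImplementedError  # real kernel call omitted in this sketch

def patch_sage_attention():
    # Registering under the same key later passed as attn_implementation
    # lets models dispatch attention calls to the custom forward.
    ALL_ATTENTION_FUNCTIONS["sage_attention"] = sage_attention_forward

patch_sage_attention()
```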
Actionable comments posted: 0
🧹 Nitpick comments (2)
src/axolotl/monkeypatch/attention/sageattn.py (2)
32-39: Consider making installation instructions more maintainable. The hardcoded commit hash in the installation command may become outdated. Consider either:
- Using a version tag instead of a commit hash
- Adding a comment to remind maintainers to keep this updated
- Referring to the official installation documentation
```diff
-"`pip install git+https://github.com/thu-ml/SageAttention.git@1718ddc06dbc694bcf3c6b49ac28c1921aa2d8bd`"
+"`pip install git+https://github.com/thu-ml/SageAttention.git` or follow installation instructions at https://github.com/thu-ml/SageAttention/"
```
58-82: Address static analysis hints and document significant limitations. The error handling is thorough, but there are two style improvements suggested by static analysis tools, and the limitations should be prominently documented.
Apply these style improvements:
```diff
-    or kwargs.get("head_mask", None) is not None
+    or kwargs.get("head_mask") is not None
```

```diff
-    if kwargs.get("position_ids", None) is not None:
+    if kwargs.get("position_ids") is not None:
```

Important: The lack of support for attention_mask and position_ids significantly limits this integration's applicability. Consider adding a prominent warning in the docstring about these constraints.
📒 Files selected for processing (1)
- src/axolotl/monkeypatch/attention/sageattn.py (1 hunks)
🔇 Additional comments (3)
src/axolotl/monkeypatch/attention/sageattn.py (3)
1-30: LGTM! Well-structured conditional import pattern. The import section properly handles the optional SageAttention dependency with clear documentation and a standard conditional import pattern.
115-123: LGTM! Proper integration with transformers attention registry. The patch function correctly registers SageAttention with transformers' global attention function registry following the established pattern.
83-112: Verify causal mask inference logic and confirm tensor layout assumptions. The GQA/MQA handling and tensor operations look correct, but the causal mask inference should be verified.
Please verify that the causal mask inference logic matches transformers' behavior:

```shell
#!/bin/bash
# Search for similar causal mask inference patterns in the transformers codebase
rg -A 5 -B 5 "is_causal.*query.*shape" --type py
rg -A 5 -B 5 "getattr.*is_causal" --type py
```

The tensor layout conversion from "HND" (batch, heads, seq_len, dim) to transformers format (batch, seq_len, heads, dim) using transpose(1, 2) appears mathematically correct.
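The causal-mask inference the reviewer asks about typically follows the sdpa-style heuristic in transformers: treat the call as causal only when there is more than one query token and no explicit mask was supplied. A stdlib-only sketch of that heuristic, operating on bare shape tuples (an assumption about the actual code, which works on tensors):

```python
def infer_is_causal(query_shape, attention_mask, module_is_causal=True):
    # Single-token decode steps (seq_len == 1) need no causal mask, and an
    # explicit attention_mask overrides causal inference entirely.
    seq_len = query_shape[2]  # (batch, heads, seq_len, head_dim) layout
    return bool(module_is_causal and attention_mask is None and seq_len > 1)
```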
@NanoCode012 what's the sage vs flash attn VRAM usage?
@winglian, weirdly I'm not seeing the VRAM savings from the benchmarks. Current early wandb results show it is about 20% faster with the same VRAM usage. However, kernel benchmarking showed it using less VRAM (at least when <32k context). More runs still need to be done.
Updated the PR from main and added more validation/docs on attention. It is a bit faster than FA in adapter mode. I added a warning that this is not recommended for FFT due to unstable loss. I did not add a test, as I didn't want to install another module by default.

Edit Feb 2026: We will merge this PR as it is an external dependency, to let users try it out. We understand that the metrics are not stable and are always open to help with fixing it.
📖 Documentation Preview: https://6912b00334daf4d5637097f2--resonant-treacle-0fd729.netlify.app (deployed on Netlify from commit 12a2f62)
Description
Adds SageAttention https://github.com/thu-ml/SageAttention/
Since it has a similar interface to sdpa_attention, I used that implementation and flash attention in transformers to cross-check.
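Given the stated similarity to sdpa_attention, the forward presumably follows transformers' attention-function calling convention. The stub below sketches that assumed signature; parameter names and the `(output, attn_weights)` return convention mirror the sdpa style, but none of this is confirmed by the PR body itself.

```python
from typing import Optional, Tuple

def sage_attention_forward(
    module,
    query,                      # assumed (batch, num_heads, seq_len, head_dim)
    key,
    value,
    attention_mask=None,
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    **kwargs,
) -> Tuple[object, None]:
    # Sketch only: a real implementation would call the sageattn kernel and
    # transpose the output back to (batch, seq_len, num_heads, head_dim).
    # Returning (attn_output, None) matches the sdpa-style convention of
    # (output, attn_weights) with weights unavailable.
    raise NotImplementedError
```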
How has this been tested?
No test yet!