make flash-attn optional for Ume, enable CPU-only inference and testing#90
Conversation
ncfrey
commented
May 23, 2025
- Add use_flash_attn parameter to Ume and FlexBERT to control flash-attn usage.
- Fall back to padded attention when flash-attn is unavailable, enabling CPU-only operation.
- Fix input shape handling in Ume for both padded and unpadded attention modes.
- Add a test for Ume inference on CPU without flash-attn.
- All existing tests pass, ensuring backward compatibility.
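The fallback behavior described above can be sketched roughly as follows (hypothetical helper name; the PR's actual implementation lives in the Ume/FlexBERT classes):

```python
import warnings


def resolve_use_flash_attn(requested: bool) -> bool:
    """Decide whether flash-attn can actually be used.

    Falls back to padded (standard) attention when the flash_attn package
    is not importable, e.g. in CPU-only environments.
    """
    if not requested:
        return False
    try:
        import flash_attn  # noqa: F401  -- availability check only
    except ImportError:
        warnings.warn(
            "flash_attn is not available; falling back to padded attention "
            "(use_flash_attn=False)."
        )
        return False
    return True
```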
Pull Request Overview
This PR makes flash-attn optional in Ume and FlexBERT, adds a CPU-only fallback for attention, fixes input shape handling in Ume’s embed method, and adds a test for Ume inference on CPU.
- Introduce `use_flash_attn` parameter to Ume and propagate to FlexBERT via `use_fa2`
- Fall back to padded attention when flash-attn is unavailable and adjust embed input shape handling
- Add `test_embed_sequences_cpu` to verify CPU-only `embed_sequences` functionality
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| tests/lobster/model/test__ume.py | Add test_embed_sequences_cpu to cover embed_sequences without flash-attn |
| src/lobster/model/modern_bert/_modern_bert.py | Warn and disable flash-attn fallback when unavailable; remove strict assert |
| src/lobster/model/_ume.py | Add use_flash_attn param, set padding mode, and update embed shape handling |
Comments suppressed due to low confidence (2)
src/lobster/model/_ume.py:357
- The unpadded attention code path in `embed` is not covered by existing tests. Consider adding a test for `use_flash_attn=True` (when flash-attn is available) to exercise this branch and verify output shapes.
input_ids, attention_mask, cu_seqlens = self.model._prepare_inputs(x["input_ids"], x["attention_mask"])
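For context on the unpadded path, here is a minimal pure-Python illustration of flash-attn-style unpadding (a hypothetical helper, not the repo's `_prepare_inputs`): tokens from a padded batch are concatenated, and `cu_seqlens` records cumulative sequence lengths that delimit each sequence.

```python
def unpad(input_ids, attention_mask):
    """Flatten a padded batch into (tokens, cu_seqlens) for variable-length
    attention. Pure-Python sketch of the idea, not production code."""
    tokens, cu_seqlens = [], [0]
    for ids, mask in zip(input_ids, attention_mask):
        kept = [tok for tok, m in zip(ids, mask) if m]  # drop padding
        tokens.extend(kept)
        cu_seqlens.append(cu_seqlens[-1] + len(kept))  # running offset
    return tokens, cu_seqlens


tokens, cu_seqlens = unpad([[5, 6, 0], [7, 8, 9]], [[1, 1, 0], [1, 1, 1]])
# tokens == [5, 6, 7, 8, 9]; cu_seqlens == [0, 2, 5]
```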
src/lobster/model/modern_bert/_modern_bert.py:121
- [nitpick] The warning refers to the internal `use_fa2` flag; consider referencing the user-facing `use_flash_attn` parameter in the message for clearer guidance.
warnings.warn("flash_attn not available but use_fa2=True. Setting use_fa2=False. "
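A sketch of the suggested rewording (hypothetical function name; the actual warning lives in `_modern_bert.py`):

```python
import warnings


def warn_flash_attn_unavailable() -> None:
    # Reference the user-facing `use_flash_attn` parameter rather than the
    # internal `use_fa2` flag, so users know which argument to change.
    warnings.warn(
        "flash_attn is not installed but use_flash_attn=True was requested. "
        "Falling back to padded attention (use_flash_attn=False)."
    )
```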
Changes in `_ume.py` look good to me, though I don't see any changed files for modern_bert?
Realized this when testing the code out: we might want to enable passing the `use_flash_attn` flag when loading a checkpoint. I think the most common scenario is that we train a model with FA and then run inference without it. For that, we might need to override Lightning's `load_from_checkpoint` to pass it there?
Edit: you actually already handle this when flash-attn is not installed (though it still might be nice to optionally disable it when loading checkpoints even when it's installed in the environment; I opened a similar issue last week for ESM). That being said, it looks like I still get the same error when trying to embed sequences without FA?
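For the checkpoint-loading scenario: Lightning forwards extra keyword arguments of `load_from_checkpoint` as hyperparameter overrides (when the module used `save_hyperparameters()`), so something like `Ume.load_from_checkpoint(ckpt_path, use_flash_attn=False)` may already suffice. The underlying pattern is just a kwargs override; a minimal framework-free sketch:

```python
def load_with_overrides(saved_hparams: dict, **overrides) -> dict:
    """Merge hyperparameters stored in a checkpoint with user overrides,
    letting inference-time settings (like disabling flash-attn) win."""
    merged = dict(saved_hparams)
    merged.update(overrides)
    return merged


# A checkpoint trained with flash-attn, loaded for CPU-only inference:
hparams = load_with_overrides({"use_flash_attn": True}, use_flash_attn=False)
# hparams["use_flash_attn"] is now False
```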