make flash-attn optional for Ume, enable CPU-only inference and testing #90

Merged — ncfrey merged 6 commits into main from feat/optional-flash-attn-cpu-support, May 27, 2025
Conversation

@ncfrey (Contributor) commented May 23, 2025

  • Add use_flash_attn parameter to Ume and FlexBERT to control flash-attn usage.
  • Fallback to padded attention if flash-attn is unavailable, enabling CPU-only operation.
  • Fix input shape handling in Ume for both padded and unpadded attention modes.
  • Add a test for Ume inference on CPU without flash-attn.
  • All existing tests pass, ensuring backward compatibility.
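The fallback behavior described in the bullets above can be sketched as follows (a minimal illustration with a hypothetical helper name, not the PR's actual code):

```python
import importlib.util
import warnings

def resolve_use_flash_attn(requested: bool) -> bool:
    """Return whether flash-attn can actually be used.

    Falls back to standard (padded) attention with a warning when the
    flash_attn package is not importable, e.g. on a CPU-only machine.
    """
    if not requested:
        return False
    if importlib.util.find_spec("flash_attn") is None:
        warnings.warn(
            "flash_attn not available but use_flash_attn=True. "
            "Falling back to standard attention."
        )
        return False
    return True
```

The key point is that the user-facing flag is a request, not a guarantee: the model silently degrades to padded attention rather than failing at import time.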

… testing

- Add use_flash_attn param to Ume and FlexBERT
- Fallback to padded attention if flash-attn is unavailable
- Fix input shape handling for padded/unpadded attention
- Add test for CPU inference without flash-attn
- All existing tests pass
@ncfrey ncfrey requested a review from Copilot May 23, 2025 16:36
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR makes flash-attn optional in Ume and FlexBERT, adds a CPU-only fallback for attention, fixes input shape handling in Ume’s embed method, and adds a test for Ume inference on CPU.

  • Introduce use_flash_attn parameter to Ume and propagate to FlexBERT via use_fa2
  • Fallback to padded attention when flash-attn is unavailable and adjust embed input shape handling
  • Add test_embed_sequences_cpu to verify CPU-only embed_sequences functionality

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.

File | Description
tests/lobster/model/test__ume.py | Add test_embed_sequences_cpu to cover embed_sequences without flash-attn
src/lobster/model/modern_bert/_modern_bert.py | Warn and disable flash-attn when unavailable, falling back to standard attention; remove strict assert
src/lobster/model/_ume.py | Add use_flash_attn param, set padding mode, and update embed shape handling
Comments suppressed due to low confidence (2)

src/lobster/model/_ume.py:357

  • The unpadded attention code path in embed is not covered by existing tests. Consider adding a test for use_flash_attn=True (when flash-attn is available) to exercise this branch and verify output shapes.
input_ids, attention_mask, cu_seqlens = self.model._prepare_inputs(x["input_ids"], x["attention_mask"])

src/lobster/model/modern_bert/_modern_bert.py:121

  • [nitpick] The warning refers to the internal use_fa2 flag; consider referencing the user-facing use_flash_attn parameter in the message for clearer guidance.
warnings.warn("flash_attn not available but use_fa2=True. Setting use_fa2=False. "

@ncfrey ncfrey temporarily deployed to test.pypi.org via GitHub Actions five times between May 23, 2025 16:48 and 18:04
@ncfrey ncfrey requested a review from karinazad May 23, 2025 16:53
@karinazad (Collaborator) commented:

Changes in _ume.py look good to me, though I don't see any changed files for modern_bert?

@ncfrey ncfrey marked this pull request as ready for review May 23, 2025 21:21
@karinazad (Collaborator) commented May 24, 2025

Realized this when testing the code out: we might want to enable passing use_flash_attn=False to load_from_checkpoint.

I think the most common scenario is that we train a model with FA and then run inference without it. For that, we might need to override Lightning's load_from_checkpoint to pass it there?

Edit: you actually already handle this when flash-attn is not installed:

/Users/zadorozk/Desktop/code/lobster/src/lobster/model/modern_bert/_modern_bert.py:122: UserWarning: flash_attn not available but use_fa2=True. Setting use_fa2=False. This will use standard attention instead of flash-attn.

(though it still might be nice to optionally disable it when loading checkpoints even when it's installed in the environment, I opened a similar issue last week for ESM)
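The override pattern being discussed here could look like the following sketch (illustration only, with a toy `UmeLike` class standing in for the Lightning module; Lightning's real load_from_checkpoint accepts extra kwargs that override saved hyperparameters, which is the mechanism mimicked below):

```python
class UmeLike:
    """Toy stand-in for the Lightning module (illustration only)."""

    def __init__(self, use_flash_attn: bool = True):
        self.use_flash_attn = use_flash_attn

    @classmethod
    def load_from_checkpoint(cls, checkpoint: dict, **overrides):
        # Merge saved hyperparameters with caller overrides, so a model
        # trained with flash-attn can be loaded for CPU inference with
        # use_flash_attn=False.
        hparams = {**checkpoint.get("hyper_parameters", {}), **overrides}
        return cls(**hparams)

# A checkpoint saved from a flash-attn training run, loaded for CPU use:
ckpt = {"hyper_parameters": {"use_flash_attn": True}}
model = UmeLike.load_from_checkpoint(ckpt, use_flash_attn=False)
```

Since the override wins over the checkpointed hyperparameter, no separate code path is needed for the train-with-FA, infer-without-FA scenario.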

That being said, it looks like I still get the same error when trying to embed sequences without FA:

>>> model.embed_sequences(["AA"], "amino_acid")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/zadorozk/Desktop/code/lobster/src/lobster/model/_ume.py", line 428, in embed_sequences
    return self.embed(encoded, aggregate=aggregate)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/zadorozk/Desktop/code/lobster/src/lobster/model/_ume.py", line 364, in embed
    embeddings = self.model.model(
                 ^^^^^^^^^^^^^^^^^
...
  File "/Users/zadorozk/Desktop/code/lobster/src/lobster/model/modern_bert/_padding.py", line 140, in pad_input
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/zadorozk/Desktop/code/lobster/.venv/lib/python3.11/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/zadorozk/Desktop/code/lobster/src/lobster/model/modern_bert/_padding.py", line 63, in forward
    assert indices.ndim == 1
           ^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'ndim'
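The AttributeError above comes from the padded code path reaching pad_input with indices=None: only the unpadded (flash-attn) path produces token indices to scatter back. A minimal numpy sketch of the needed dispatch (hypothetical, mirroring the padded/unpadded distinction rather than the actual lobster code):

```python
import numpy as np

def repad_if_needed(hidden_states, indices, batch, seqlen, hidden):
    """Re-pad unpadded hidden states; pass padded ones through unchanged.

    In the unpadded (flash-attn) path, `indices` maps each kept token to
    its flat (batch * seqlen) slot. In the padded path no unpadding ever
    happened, so `indices` is None and the tensor already has shape
    (batch, seqlen, hidden).
    """
    if indices is None:  # padded path: nothing to undo
        return hidden_states.reshape(batch, seqlen, hidden)
    out = np.zeros((batch * seqlen, hidden), dtype=hidden_states.dtype)
    out[indices] = hidden_states  # scatter tokens back into padded slots
    return out.reshape(batch, seqlen, hidden)
```

Guarding on indices is None (instead of asserting indices.ndim == 1 unconditionally) is what lets the CPU-only padded path skip the unpad/repad round-trip entirely.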

@ncfrey ncfrey merged commit 0a122ca into main May 27, 2025
5 checks passed
@ncfrey ncfrey deleted the feat/optional-flash-attn-cpu-support branch May 27, 2025 14:42