
cpu attn mask #149

Merged

ncfrey merged 4 commits into main from n/cpu-attn-mask on Jul 15, 2025

Conversation

@ncfrey (Contributor) commented Jul 15, 2025

Description

This pull request improves the embedding computation in the UME model and adds tests to ensure consistency across devices and configurations. The changes primarily address the handling of padding tokens during mean pooling and validate that embeddings generated by the model stay consistent across attention implementations.

Improvements to embedding computation:

  • src/lobster/model/_ume.py: Updated the embedding computation logic to use an attention mask for excluding padding tokens from mean pooling. This ensures accurate aggregation by masking out padding tokens, preventing them from affecting the computed embeddings. Additionally, safeguards were added to handle mismatched sequence lengths between the attention mask and embeddings.
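The masked mean pooling described above can be sketched as follows. This is a minimal illustration, not the exact `_ume.py` code; the function and tensor names are assumptions:

```python
import torch


def masked_mean_pool(embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool token embeddings, excluding padding positions from the average.

    embeddings: (batch, seq_len, hidden_dim)
    attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding.
    """
    mask = attention_mask.unsqueeze(-1).to(embeddings.dtype)  # (batch, seq_len, 1)
    sum_embeddings = (embeddings * mask).sum(dim=1)           # (batch, hidden_dim)
    token_counts = mask.sum(dim=1).clamp(min=1.0)             # guard against empty sequences
    return sum_embeddings / token_counts


emb = torch.randn(2, 4, 8)
attn = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
pooled = masked_mean_pool(emb, attn)
print(pooled.shape)  # torch.Size([2, 8])
```

Without the mask, padding positions would be averaged in and shift the pooled vectors; with it, only real tokens contribute.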

Enhanced test coverage:

  • tests/lobster/model/test__ume.py: Added a new test method test_flash_attention_consistency_across_devices to validate the consistency of embeddings generated with flash attention (GPU) and non-flash attention (CPU). The test includes various modalities (e.g., amino acid, nucleotide, SMILES) and checks cosine similarity to ensure embeddings are nearly identical (>99.9% similarity). It also tests consistency between padded and unpadded architectures.
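The core of the consistency check can be illustrated roughly like this; the tensors are hypothetical stand-ins for the real GPU/CPU embeddings, and the 0.999 threshold mirrors the ">99.9% similarity" requirement above:

```python
import torch

torch.manual_seed(0)

# Stand-ins for embeddings produced by the two backends; the "CPU" copy
# carries tiny numerical drift, as different kernels would produce.
embeddings_gpu = torch.randn(3, 16)
embeddings_cpu = embeddings_gpu + 1e-5 * torch.randn(3, 16)

cosine_sim = torch.nn.functional.cosine_similarity(embeddings_gpu, embeddings_cpu, dim=1)
assert (cosine_sim > 0.999).all(), f"Embeddings diverged: {cosine_sim.tolist()}"
```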

Type of Change

  • Bug fix
  • New feature
  • Documentation update
  • Performance improvement
  • Code refactoring

Testing

  • Tests pass locally
  • Added new tests for new functionality
  • Updated existing tests if needed

Checklist

  • Code follows style guidelines
  • Self-review completed
  • Documentation updated if needed
  • No breaking changes (or clearly documented)

@ncfrey ncfrey requested a review from karinazad July 15, 2025 18:02

assert embeddings_flash_agg.shape == embeddings_no_flash_agg.shape

def test_flash_attention_consistency_across_devices(self):
Collaborator:
just curious: what specifically was missing in the previous tests that they didn't catch this bug?

# Check similarity
cosine_sim = torch.nn.functional.cosine_similarity(embeddings_gpu, embeddings_cpu, dim=1)

print(f" Cosine similarities: {cosine_sim.tolist()}")
Collaborator:

might want to remove print statements from tests before merging?

Contributor Author:

I'll replace them with logging statements.

Collaborator:

here's what Claude thinks

For logging statements in Python tests, here's my guidance:
Avoid logging statements in test code itself. Tests should be clean and focused on verification, not producing logs.
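If log-based diagnostics are still wanted, pytest's built-in `caplog` fixture can capture them inside a test without polluting output. A minimal sketch, with hypothetical logger and function names:

```python
import logging

logger = logging.getLogger("lobster.tests.example")  # hypothetical logger name


def compute_similarity_summary(values):
    """Log a diagnostic line and return the minimum value."""
    minimum = min(values)
    logger.info("min cosine similarity: %.6f", minimum)
    return minimum


def test_similarity_summary(caplog):
    # caplog is pytest's built-in fixture for capturing log records.
    with caplog.at_level(logging.INFO, logger="lobster.tests.example"):
        result = compute_similarity_summary([0.9991, 0.9995])
    assert result == 0.9991
    assert "min cosine similarity" in caplog.text
```

This keeps the test itself assertion-driven: the diagnostic lives in the code under test, and the test only inspects captured records when it actually cares about them.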

attention_mask = attention_mask.squeeze(1) # Remove middle dimension: (batch, seq_len)

# Ensure attention mask matches embedding sequence length
if attention_mask.shape[1] != embeddings.shape[1]:
Collaborator:

how does it happen that the attention mask's shape is not the same as embeddings?

Contributor Author:

The attention mask comes from the unpadded inputs used for flash-attn, but the SDPA embeddings are padded, so the attention mask has to be updated to match.

@ncfrey ncfrey requested a review from taylormjs July 15, 2025 18:15
@ncfrey ncfrey temporarily deployed to test.pypi.org July 15, 2025 18:15 — with GitHub Actions Inactive
@ncfrey ncfrey temporarily deployed to test.pypi.org July 15, 2025 18:23 — with GitHub Actions Inactive
@ncfrey ncfrey mentioned this pull request Jul 15, 2025
token_counts = mask.sum(dim=1) # (batch, 1)

# Avoid division by zero for empty sequences
token_counts = torch.clamp(token_counts, min=1.0)
Collaborator:

Have you seen any empty seqs or is this just for safety?

Contributor Author:

just safety

logger.info(f" Token-level min cosine similarity: {min_token_sim:.6f}")

# Token-level embeddings should also be highly consistent
assert min_token_sim > 0.995, (
Collaborator:

Shouldn't they be exactly equal, both when comparing all token embeddings and the aggregated ones?
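One likely reason they are not bitwise equal: flash-attention and SDPA kernels accumulate floating-point sums in different orders, so their outputs typically agree only up to rounding. A small illustration of comparing with explicit tolerances instead of exact equality (the tensors stand in for the two backends' outputs):

```python
import torch

torch.manual_seed(0)

# Two results that differ only at the level of floating-point rounding,
# standing in for flash-attention vs. SDPA outputs.
a = torch.randn(4, 8)
b = a + 1e-6 * torch.randn(4, 8)

assert not torch.equal(a, b)  # exact equality fails

# Tolerance-based check passes; raises AssertionError only if tensors diverge.
torch.testing.assert_close(a, b, rtol=1e-4, atol=1e-5)
```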

embeddings = embeddings.mean(dim=1)
# Apply mask and compute mean only over actual tokens
masked_embeddings = embeddings * mask
sum_embeddings = masked_embeddings.sum(dim=1) # (batch, hidden_dim)
Collaborator:

Thanks for adding this!

@karinazad (Collaborator):

Good fix! We can merge this for now, but I'm a bit worried that we're just patching the issue in UME. Maybe this should go into ModernBERT?

if attention_mask.dim() == 3:
attention_mask = attention_mask.squeeze(1) # Remove middle dimension: (batch, seq_len)

# Ensure attention mask matches embedding sequence length
Collaborator:

Claude suggested:

batch_size, mask_seq_len = attention_mask.shape
embed_seq_len = embeddings.shape[1]

if mask_seq_len != embed_seq_len:
    if mask_seq_len > embed_seq_len:
        attention_mask = attention_mask[:, :embed_seq_len]
    else:
        pad_length = embed_seq_len - mask_seq_len
        padding = torch.zeros(batch_size, pad_length, 
                            dtype=attention_mask.dtype, 
                            device=attention_mask.device)
        attention_mask = torch.cat([attention_mask, padding], dim=1)

Collaborator:

nit: could be moved to a helper func something like _align_attention_mask

@ncfrey ncfrey temporarily deployed to test.pypi.org July 15, 2025 19:52 — with GitHub Actions Inactive
@ncfrey ncfrey merged commit 152f709 into main Jul 15, 2025
5 checks passed
@ncfrey ncfrey deleted the n/cpu-attn-mask branch July 15, 2025 19:55