Conversation
```python
assert embeddings_flash_agg.shape == embeddings_no_flash_agg.shape

def test_flash_attention_consistency_across_devices(self):
```
just curious: what specifically was missing in the previous tests that they didn't catch this bug?
tests/lobster/model/test__ume.py (outdated)
```python
# Check similarity
cosine_sim = torch.nn.functional.cosine_similarity(embeddings_gpu, embeddings_cpu, dim=1)

print(f" Cosine similarities: {cosine_sim.tolist()}")
```
might want to remove print statements from tests before merging?
i'll replace with logging statements
here's what Claude thinks:

> For logging statements in Python tests, here's my guidance: avoid logging statements in test code itself. Tests should be clean and focused on verification, not producing logs.
src/lobster/model/_ume.py (outdated)
```python
attention_mask = attention_mask.squeeze(1)  # Remove middle dimension: (batch, seq_len)

# Ensure attention mask matches embedding sequence length
if attention_mask.shape[1] != embeddings.shape[1]:
```
how does it happen that the attention mask's shape is not the same as embeddings?
the attention mask comes from the unpadded inputs used for flash-attn, but the SDPA embeddings are padded, so the attention mask has to be updated to match
```python
token_counts = mask.sum(dim=1)  # (batch, 1)

# Avoid division by zero for empty sequences
token_counts = torch.clamp(token_counts, min=1.0)
```
Have you seen any empty seqs or is this just for safety?
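For what it's worth, the clamp is cheap insurance: if a row of the mask is all zeros (e.g. an empty sequence after tokenization), the token count is 0 and the mean-pool would divide by zero. A minimal pure-Python sketch of the same masked-mean logic (hypothetical shapes, no torch dependency):

```python
def masked_mean(embeddings, mask):
    """Mean-pool each sequence over its real tokens only.

    embeddings: [batch][seq_len][hidden] floats; mask: [batch][seq_len] 0/1.
    """
    pooled = []
    for seq, row_mask in zip(embeddings, mask):
        hidden = len(seq[0])
        sums = [
            sum(tok[d] for tok, keep in zip(seq, row_mask) if keep)
            for d in range(hidden)
        ]
        # Analogue of torch.clamp(token_counts, min=1.0): never divide by zero.
        count = max(sum(row_mask), 1)
        pooled.append([s / count for s in sums])
    return pooled

# Second row is entirely padding; without the clamp it would be 0/0.
emb = [[[2.0, 4.0], [9.0, 9.0]], [[5.0, 5.0], [7.0, 7.0]]]
mask = [[1, 0], [0, 0]]
print(masked_mean(emb, mask))  # → [[2.0, 4.0], [0.0, 0.0]]
```

With the clamp analogue (`max(count, 1)`), an all-padding row yields a zero vector instead of NaNs.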
```python
logger.info(f" Token-level min cosine similarity: {min_token_sim:.6f}")

# Token-level embeddings should also be highly consistent
assert min_token_sim > 0.995, (
```
Shouldn't they be exactly equal? Both for comparing all token embeddings and aggregated
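One reason exact equality is usually too strict here: flash attention and SDPA kernels reduce over the sequence in different orders, and floating-point addition is not associative. A small plain-Python sketch (illustrative only, not the actual kernels) of how reordering alone perturbs the last bits:

```python
import math
import random

random.seed(0)
values = [random.uniform(-1.0, 1.0) for _ in range(100_000)]

# Same numbers, two different summation orders.
forward = sum(values)
backward = sum(reversed(values))

# Floating-point addition is not associative, so different reduction
# orders can disagree in the last few bits of the result.
print(forward == backward)  # frequently False despite identical inputs
print(math.isclose(forward, backward, rel_tol=1e-9, abs_tol=1e-6))  # agree to high precision
```

That said, a tolerance tighter than a 0.995 cosine-similarity threshold (e.g. `torch.testing.assert_close` with small tolerances) may be worth considering if the only expected difference is reduction order.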
```python
embeddings = embeddings.mean(dim=1)
# Apply mask and compute mean only over actual tokens
masked_embeddings = embeddings * mask
sum_embeddings = masked_embeddings.sum(dim=1)  # (batch, hidden_dim)
```
Thanks for adding this!
good fix! we can merge this for now but I'm a bit worried that we're just patching the issue in UME. maybe this should go to ModernBERT?
```python
if attention_mask.dim() == 3:
    attention_mask = attention_mask.squeeze(1)  # Remove middle dimension: (batch, seq_len)

# Ensure attention mask matches embedding sequence length
```
claude suggested:

```python
batch_size, mask_seq_len = attention_mask.shape
embed_seq_len = embeddings.shape[1]
if mask_seq_len != embed_seq_len:
    if mask_seq_len > embed_seq_len:
        attention_mask = attention_mask[:, :embed_seq_len]
    else:
        pad_length = embed_seq_len - mask_seq_len
        padding = torch.zeros(
            batch_size, pad_length,
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        attention_mask = torch.cat([attention_mask, padding], dim=1)
```
nit: could be moved to a helper func something like _align_attention_mask
Description

This pull request introduces significant improvements to the embedding computation in the UME model and adds comprehensive tests to ensure consistency across different configurations. The changes primarily address the handling of padding tokens during mean pooling and introduce new test cases to validate the consistency of embeddings generated by the model across devices and configurations.

Improvements to embedding computation:

- `src/lobster/model/_ume.py`: Updated the embedding computation logic to use an attention mask for excluding padding tokens from mean pooling. This ensures accurate aggregation by masking out padding tokens, preventing them from affecting the computed embeddings. Additionally, safeguards were added to handle mismatched sequence lengths between the attention mask and the embeddings.

Enhanced test coverage:

- `tests/lobster/model/test__ume.py`: Added a new test method `test_flash_attention_consistency_across_devices` to validate the consistency of embeddings generated with flash attention (GPU) and non-flash attention (CPU). The test covers various modalities (e.g., amino acid, nucleotide, SMILES) and checks cosine similarity to ensure embeddings are nearly identical (>99.9% similarity). It also tests consistency between padded and unpadded architectures.

Type of Change
Testing
Checklist