Conversation
There was a problem hiding this comment.
this file changed
There was a problem hiding this comment.
this file changed
There was a problem hiding this comment.
Pull Request Overview
This PR updates the UME code to support modality‐specific embeddings and renames the latent generator tokenizer to use 3D coordinates. Key changes include updating reserved token naming conventions in tests and tokenizers, modifying post‐processors to use modality-specific CLS tokens, and extending the UME model to optionally use modality-specific embedding layers.
Reviewed Changes
Copilot reviewed 21 out of 21 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/lobster/tokenization/test__ume_tokenizers.py | Adjusted expected vocab keys and outputs to reflect renamed tokens and reserved token naming changes. |
| tests/lobster/model/test__ume.py | Updated test cases and parameterization for training/validation steps with modality-specific changes. |
| src/lobster/tokenization/_ume_tokenizers.py | Renamed latent generator tokenizer to coordinates_3d_tokenizer and refactored reserved token handling and post‐processor function. |
| src/lobster/model/modern_bert/_modern_bert.py | Replaced tokens_to_latents with new helper methods; added an assert to validate token IDs; uses .cuda() when creating cu_seqlens. |
| src/lobster/model/_ume.py | Updated embedding logic to conditionally use modality-specific embedding layers and adjusted logits computation accordingly. |
| Asset/tokenizer JSON/config files | Updated special tokens maps to reflect the new modality-specific token names. |
Comments suppressed due to low confidence (1)
tests/lobster/tokenization/test__ume_tokenizers.py:85
- The function name 'test_ume_aminio_acid_tokenizer' contains a typo. It should be renamed to 'test_ume_amino_acid_tokenizer' for clarity and consistency.
def test_ume_aminio_acid_tokenizer():
| # Compute cumulative sequence lengths | ||
| # input_ids and attention_mask are expected to be of shape (batch_size, 1, length) | ||
| batch_size, length = input_ids.shape[0], input_ids.shape[2] | ||
| cu_seqlens = torch.tensor([0] + [(i + 1) * length for i in range(batch_size)], dtype=torch.int32).cuda() |
There was a problem hiding this comment.
Using .cuda() here may force the tensor onto GPU even when running on CPU. Consider using input_ids.device (e.g., device=input_ids.device) to ensure compatibility with both CPU and GPU.
| cu_seqlens = torch.tensor([0] + [(i + 1) * length for i in range(batch_size)], dtype=torch.int32).cuda() | |
| cu_seqlens = torch.tensor([0] + [(i + 1) * length for i in range(batch_size)], dtype=torch.int32).to(input_ids.device) |
| modalities = batch["metadata"]["modality"] if "metadata" in batch else batch["modality"] | ||
|
|
||
| for modality in set(modalities): | ||
| modality_mask = torch.tensor([m == modality for m in modalities], device=self.device, dtype=torch.bool) |
There was a problem hiding this comment.
[nitpick] Consider replacing the list comprehension with a vectorized operation if possible to create modality_mask, which may improve performance especially for larger batch sizes.
| modality_mask = torch.tensor([m == modality for m in modalities], device=self.device, dtype=torch.bool) | |
| modalities_tensor = torch.tensor(modalities, device=self.device) | |
| modality_mask = (modalities_tensor == modality) |
There was a problem hiding this comment.
this won't work because modalities is a list of strings and can't be a tensor
…to ume-modality-embeddings
taylormjs
left a comment
There was a problem hiding this comment.
lgtm! gave a cursory pass
* add <cls_modality> tokens * add <cls_modality> tokens * modality embeddings * module dict * embeddings * tests * modality and device * rank zero only * rank zero * fix back modality mask * sync dist
* peer fixes, add evaluate method * dataloader checkpoint callback (#60) * dataloader callback * utils * ume * gitignore dev * tests * update flash attention wheels (#61) * lock * torch 2.5 * torch 2.5 * part * .env * unpin flash attn (#62) * fix scheduler params (#64) * scheduler * fix scheduler * fix scheduler * Add AtomicaDataset (#63) Processed Atomica interactions dataset * Ume conversion/interaction tokenizer + fix SMILES and nucleotide tokenizers (#65) add two special tokens: <convert> and <interact> for later stages of Ume training: will be used as this: (or something like that) [CLS] PROT_SEQ [SEP] <convert> PROT_STRUCT(masked) [SEP] [CLS] PROT_SEQ [SEP] <interact> SMILES(masked) [SEP] extend functionality of UmeTokenizerTransform to handle dual modalities change the name of Ume embedding method and allow embedding from existing input_ids fix existing tokenizers: add lowercase normalized to nucleotide tokenizer (OG2 dataset contains a mix of upper and lowercase letters) BPE handled SMILES tokenization incorrectly, switch to WordLevel * Ume SMILES tokenizer fix (#66) * tokenizer * fix tests * lowercase normalizer for nt * tests * remove mod conv dataset * embed * Test * merge 2mod into UmeTokenizerTransform * fix tests * all * type hints * docstrings * tests * fix SMILES tokenizer * switch all tokenizer to BPE * Revert "switch all tokenizer to BPE" This reverts commit 367e77d. * tok * fix SMILES tokenizer * remove print statement * Ume perplexity logging (#67) * pplx * tests * src * ignore torchmetrics warnings * docstrings * docstrings * Update README.md (#69) * Ume fix perplexity device (#68) * pplx as attr * pplx as attr * pplx * comments * on step * comment * update tests, fix ruff * ruff * ruff ruff * Add <cls_modality> to Ume tokenizers (#71) * add <cls_modality> tokens * add <cls_modality> tokens * docstring * RNS metric implementation (#73) * add <cls_modality> tokens * add <cls_modality> tokens * modality embeddings * module dict * embeddings * tests * modality and device * rank zero only * rank zero * fix back modality mask * sync dist * RNS implementation * restore from main * restore * docstrings * docstrings * review * test * Ume modality-specific embeddings (#72) * add <cls_modality> tokens * add <cls_modality> tokens * modality embeddings * module dict * embeddings * tests * modality and device * rank zero only * rank zero * fix back modality mask * sync dist * add conversion transforms (#74) * add initial smiles to peptide and peptide to smiles transforms * remove smiles -> * transforms and touch up conversion functions * rename * add option to randomize smiles and caps --------- Co-authored-by: Colin Grambow <grambowc@gene.com> * fix def pad token, replace process_and_embed w/ ume.embed * update tests w -100 pad token --------- Co-authored-by: Taylor Joren <joren.taylor@gene.com> Co-authored-by: Karina Zadorozhny <karina.zadorozhny@gmail.com> Co-authored-by: Nathan Frey <ncfrey@users.noreply.github.com> Co-authored-by: Colin Grambow <17198155+cgrambow@users.noreply.github.com> Co-authored-by: Colin Grambow <grambowc@gene.com>
Add an option to pass inputs through modality-specific embeddings in Ume
Branches of !71 tokenization MR