Ume modality-specific embeddings #72

Merged: taylormjs merged 12 commits into main from ume-modality-embeddings on May 13, 2025

Conversation

@karinazad (Collaborator) commented May 8, 2025:

Add an option to pass inputs through modality-specific embeddings in Ume.
Branches off the !71 tokenization MR.

karinazad (Collaborator, Author) commented: this file changed

karinazad (Collaborator, Author) commented: this file changed

@ncfrey ncfrey requested a review from Copilot May 8, 2025 17:48
Copilot AI (Contributor) left a comment:
Pull Request Overview

This PR updates the UME code to support modality-specific embeddings and renames the latent generator tokenizer to use 3D coordinates. Key changes include updating reserved token naming conventions in tests and tokenizers, modifying post-processors to use modality-specific CLS tokens, and extending the UME model to optionally use modality-specific embedding layers.

Reviewed Changes

Copilot reviewed 21 out of 21 changed files in this pull request and generated 2 comments.

File | Description
--- | ---
tests/lobster/tokenization/test__ume_tokenizers.py | Adjusted expected vocab keys and outputs to reflect renamed tokens and reserved-token naming changes.
tests/lobster/model/test__ume.py | Updated test cases and parameterization for training/validation steps with modality-specific changes.
src/lobster/tokenization/_ume_tokenizers.py | Renamed the latent generator tokenizer to coordinates_3d_tokenizer and refactored reserved-token handling and the post-processor function.
src/lobster/model/modern_bert/_modern_bert.py | Replaced tokens_to_latents with new helper methods; added an assert to validate token IDs; uses .cuda() when creating cu_seqlens.
src/lobster/model/_ume.py | Updated embedding logic to conditionally use modality-specific embedding layers and adjusted logits computation accordingly.
Asset/tokenizer JSON/config files | Updated special-tokens maps to reflect the new modality-specific token names.
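The "modality-specific embedding layers" in src/lobster/model/_ume.py could be sketched roughly as follows. This is a hedged illustration only: the class name, shapes, and modality names here are assumptions for exposition, not the actual UME implementation (the commit log does mention a "module dict", which motivates the nn.ModuleDict choice).

```python
import torch
import torch.nn as nn

class ModalityEmbeddings(nn.Module):
    """Sketch: one embedding table per modality, selected at forward time."""

    def __init__(self, vocab_size: int, hidden_size: int, modalities: list[str]):
        super().__init__()
        # nn.ModuleDict registers each per-modality table as a proper submodule,
        # so parameters are tracked and moved to the right device automatically.
        self.embeddings = nn.ModuleDict(
            {m: nn.Embedding(vocab_size, hidden_size) for m in modalities}
        )

    def forward(self, input_ids: torch.Tensor, modality: str) -> torch.Tensor:
        # Route the whole batch through the embedding table for its modality.
        return self.embeddings[modality](input_ids)

emb = ModalityEmbeddings(vocab_size=128, hidden_size=16, modalities=["protein", "smiles"])
out = emb(torch.randint(0, 128, (2, 5)), modality="protein")
print(tuple(out.shape))  # (2, 5, 16)
```

In a mixed-modality batch this would be combined with a per-modality boolean mask (as in the reviewed code below) so each subset of rows is embedded by its own table.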
Comments suppressed due to low confidence (1)

tests/lobster/tokenization/test__ume_tokenizers.py:85

  • The function name 'test_ume_aminio_acid_tokenizer' contains a typo. It should be renamed to 'test_ume_amino_acid_tokenizer' for clarity and consistency.
def test_ume_aminio_acid_tokenizer():

# Compute cumulative sequence lengths
# input_ids and attention_mask are expected to be of shape (batch_size, 1, length)
batch_size, length = input_ids.shape[0], input_ids.shape[2]
cu_seqlens = torch.tensor([0] + [(i + 1) * length for i in range(batch_size)], dtype=torch.int32).cuda()
Copilot AI commented May 8, 2025:

Using .cuda() here may force the tensor onto the GPU even when running on CPU. Consider using input_ids.device (e.g., device=input_ids.device) to ensure compatibility with both CPU and GPU.

Suggested change:

- cu_seqlens = torch.tensor([0] + [(i + 1) * length for i in range(batch_size)], dtype=torch.int32).cuda()
+ cu_seqlens = torch.tensor([0] + [(i + 1) * length for i in range(batch_size)], dtype=torch.int32).to(input_ids.device)
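As a standalone illustration of the device-agnostic version suggested above (the batch shape here is made up for the example), cu_seqlens can be created directly on whatever device input_ids lives on:

```python
import torch

# Hypothetical input with the shape the reviewed code expects: (batch_size, 1, length)
input_ids = torch.randint(0, 100, (3, 1, 8))
batch_size, length = input_ids.shape[0], input_ids.shape[2]

# Cumulative sequence lengths, placed on the same device as input_ids
# instead of unconditionally calling .cuda()
cu_seqlens = torch.tensor(
    [0] + [(i + 1) * length for i in range(batch_size)],
    dtype=torch.int32,
    device=input_ids.device,
)
print(cu_seqlens.tolist())  # [0, 8, 16, 24]
```

Passing device= at construction also avoids the extra copy that a .to(...) call after construction would incur.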
modalities = batch["metadata"]["modality"] if "metadata" in batch else batch["modality"]

for modality in set(modalities):
modality_mask = torch.tensor([m == modality for m in modalities], device=self.device, dtype=torch.bool)
Copilot AI commented May 8, 2025:

[nitpick] Consider replacing the list comprehension with a vectorized operation to create modality_mask, which may improve performance, especially for larger batch sizes.

Suggested change:

- modality_mask = torch.tensor([m == modality for m in modalities], device=self.device, dtype=torch.bool)
+ modalities_tensor = torch.tensor(modalities, device=self.device)
+ modality_mask = (modalities_tensor == modality)
karinazad (Collaborator, Author) replied:

This won't work because modalities is a list of strings, which torch.tensor cannot convert to a tensor.
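The author's objection is correct: torch.tensor rejects a list of strings. If a vectorized mask were still desired, one workaround (a hedged sketch, not code from this PR; the modality names are invented) is to map the strings to integer codes once, after which each per-modality mask is a pure tensor comparison:

```python
import torch

# Hypothetical mixed-modality batch of strings
modalities = ["protein", "smiles", "protein", "nucleotide"]

# One-time string -> integer mapping; only this step needs Python iteration
codes = {name: i for i, name in enumerate(sorted(set(modalities)))}
modality_ids = torch.tensor([codes[m] for m in modalities])

# Each mask is now a vectorized comparison on the integer tensor
masks = {m: modality_ids == codes[m] for m in set(modalities)}
print(masks["protein"].tolist())  # [True, False, True, False]
```

This still pays one comprehension to build modality_ids, but the per-modality comparisons inside the loop become tensor ops, so the cost no longer scales with the number of modalities times the batch size in Python.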

@karinazad karinazad temporarily deployed to test.pypi.org May 8, 2025 18:48 — with GitHub Actions Inactive
@karinazad karinazad temporarily deployed to test.pypi.org May 9, 2025 00:01 — with GitHub Actions Inactive
@karinazad karinazad temporarily deployed to test.pypi.org May 9, 2025 00:06 — with GitHub Actions Inactive
@karinazad karinazad temporarily deployed to test.pypi.org May 9, 2025 00:25 — with GitHub Actions Inactive
@karinazad karinazad temporarily deployed to test.pypi.org May 9, 2025 00:30 — with GitHub Actions Inactive
@karinazad karinazad temporarily deployed to test.pypi.org May 9, 2025 13:53 — with GitHub Actions Inactive
@taylormjs taylormjs self-assigned this May 13, 2025
@taylormjs (Collaborator) left a comment:
lgtm! gave a cursory pass

@taylormjs taylormjs merged commit 47f1755 into main May 13, 2025
5 checks passed
@taylormjs taylormjs deleted the ume-modality-embeddings branch May 13, 2025 19:36
taylormjs pushed a commit that referenced this pull request May 14, 2025
* add <cls_modality> tokens

* add <cls_modality> tokens

* modality embeddings

* module dict

* embeddings

* tests

* modality and device

* rank zero only

* rank zero

* fix back modality mask

* sync dist
karinazad added a commit that referenced this pull request May 14, 2025
* peer fixes, add evaluate method

* dataloader checkpoint callback (#60)

* dataloader callback

* utils

* ume

* gitignore dev

* tests

* update flash attention wheels (#61)

* lock

* torch 2.5

* torch 2.5

* part

* .env

* unpin flash attn (#62)

* fix scheduler params (#64)

* scheduler

* fix scheduler

* fix scheduler

* Add AtomicaDataset (#63)

Processed Atomica interactions dataset

* Ume conversion/interaction tokenizer + fix SMILES and nucleotide tokenizers (#65)

Add two special tokens, <convert> and <interact>, for later stages of Ume training.
They will be used like this (or something like that):
[CLS]  PROT_SEQ  [SEP] <convert> PROT_STRUCT(masked)  [SEP]
[CLS]  PROT_SEQ  [SEP] <interact> SMILES(masked)  [SEP]
Extend UmeTokenizerTransform to handle dual modalities.
Rename the Ume embedding method and allow embedding from existing input_ids.
Fix existing tokenizers:

Add a lowercase normalizer to the nucleotide tokenizer (the OG2 dataset contains a mix of upper- and lowercase letters).
BPE handled SMILES tokenization incorrectly; switch to WordLevel.

* Ume SMILES tokenizer fix (#66)

* tokenizer

* fix tests

* lowercase normalizer for nt

* tests

* remove mod conv dataset

* embed

* Test

* merge 2mod into UmeTokenizerTransform

* fix tests

* all

* type hints

* docstrings

* tests

* fix SMILES tokenizer

* switch all tokenizer to BPE

* Revert "switch all tokenizer to BPE"

This reverts commit 367e77d.

* tok

* fix SMILES tokenizer

* remove print statement

* Ume perplexity logging (#67)

* pplx

* tests

* src

* ignore torchmetrics warnings

* docstrings

* docstrings

* Update README.md (#69)

* Ume fix perplexity device (#68)

* pplx as attr

* pplx as attr

* pplx

* comments

* on step

* comment

* update tests, fix ruff

* ruff

* ruff ruff

* Add <cls_modality> to Ume tokenizers (#71)

* add <cls_modality> tokens

* add <cls_modality> tokens

* docstring

* RNS metric implementation  (#73)

* add <cls_modality> tokens

* add <cls_modality> tokens

* modality embeddings

* module dict

* embeddings

* tests

* modality and device

* rank zero only

* rank zero

* fix back modality mask

* sync dist

* RNS implementation

* restore from main

* restore

* docstrings

* docstrings

* review

* test

* Ume modality-specific embeddings (#72)

* add <cls_modality> tokens

* add <cls_modality> tokens

* modality embeddings

* module dict

* embeddings

* tests

* modality and device

* rank zero only

* rank zero

* fix back modality mask

* sync dist

* add conversion transforms (#74)

* add initial smiles to peptide and peptide to smiles transforms

* remove smiles -> * transforms and touch up conversion functions

* rename

* add option to randomize smiles and caps

---------

Co-authored-by: Colin Grambow <grambowc@gene.com>

* fix def pad token, replace process_and_embed w/ ume.embed

* update tests w -100 pad token

---------

Co-authored-by: Taylor Joren <joren.taylor@gene.com>
Co-authored-by: Karina Zadorozhny <karina.zadorozhny@gmail.com>
Co-authored-by: Nathan Frey <ncfrey@users.noreply.github.com>
Co-authored-by: Colin Grambow <17198155+cgrambow@users.noreply.github.com>
Co-authored-by: Colin Grambow <grambowc@gene.com>
3 participants