
Add Symile multi-modal contrastive loss to Ume#109

Merged
karinazad merged 24 commits into main from k/symile-loss-v2
Jun 17, 2025

Conversation

Collaborator

@karinazad karinazad commented Jun 16, 2025

  • Adds the Symile loss function from the symile package for multi-modal contrastive learning
  • Adds UmeStreamingDataset based on litdata which supports transforms that return multiple modality views/representations
  • Adds ModalityAwareTransform, which works with UmeStreamingDataset
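As a rough sketch of the modality-aware transform pattern (the class name, signature, and modality keys below are illustrative assumptions, not the actual API added in this PR):

```python
from typing import Callable


class ModalityAwareTransformSketch:
    """Illustrative wrapper: applies a transform to an input and tags
    both the input and the output with their modality names, so a
    multi-view dataset can route each representation correctly."""

    def __init__(self, transform: Callable[[str], str], input_modality: str, output_modality: str):
        self.transform = transform
        self.input_modality = input_modality
        self.output_modality = output_modality

    def __call__(self, x: str) -> dict[str, str]:
        # Return both views, keyed by modality name.
        return {self.input_modality: x, self.output_modality: self.transform(x)}


# Toy usage with a stand-in amino-acid-to-SMILES function
t = ModalityAwareTransformSketch(lambda seq: f"smiles({seq})", "amino_acid", "smiles")
print(t("MKT"))  # {'amino_acid': 'MKT', 'smiles': 'smiles(MKT)'}
```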

Notes
Credits to Omar Mahmood for suggesting this loss for contrastive learning with multiple modalities

References:
https://github.com/rajesh-lab/symile
https://arxiv.org/pdf/2411.01053


Here's how InfoNCE and Symile loss compare on 2 inputs:

>>> import torch
>>> import torch.nn.functional as F
>>> from symile import Symile
>>> 
>>> # Create random embeddings (batch_size=32, embedding_dim=256)
>>> embeddings_a = torch.rand(32, 256)
>>> embeddings_b = torch.rand(32, 256)
>>> 
>>> # Normalize embeddings
>>> embeddings_a = F.normalize(embeddings_a, p=2.0, dim=1)
>>> embeddings_b = F.normalize(embeddings_b, p=2.0, dim=1)
>>> 
>>> # Temperature parameter
>>> temperature = 0.07
>>> 
>>> # Compute InfoNCE loss
>>> logits = embeddings_a @ embeddings_b.T / temperature
>>> labels = torch.arange(embeddings_a.shape[0])
>>> infonce_loss = F.cross_entropy(logits, labels)
>>> 
>>> # Compute Symile loss
>>> symile_loss = Symile()([embeddings_a, embeddings_b], temperature)
>>> 
>>> print(f"InfoNCE Loss: {infonce_loss.item():.4f}")
InfoNCE Loss: 3.5327
>>> print(f"Symile Loss: {symile_loss.item():.4f}")
Symile Loss: 3.4659
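Symile's advantage over pairwise InfoNCE shows up with three or more modalities: it scores a multilinear inner product over all views jointly rather than summing pairwise similarities. Below is a minimal sketch of that idea in plain torch; this is not the symile package's actual implementation, just an illustration of the scoring rule.

```python
import torch
import torch.nn.functional as F


def symile_style_loss(embeddings: list[torch.Tensor], logit_scale: float) -> torch.Tensor:
    """Contrastive loss over a multilinear inner product (sketch).

    For each anchor modality, the remaining modalities are fused by an
    elementwise product; matching rows are positives and the rest of
    the batch are negatives, as in InfoNCE.
    """
    batch_size = embeddings[0].shape[0]
    labels = torch.arange(batch_size)
    total = torch.tensor(0.0)
    for i, anchor in enumerate(embeddings):
        rest = [e for j, e in enumerate(embeddings) if j != i]
        fused = rest[0]
        for e in rest[1:]:
            fused = fused * e  # elementwise product across the other views
        logits = anchor @ fused.T * logit_scale  # [batch, batch] scores
        total = total + F.cross_entropy(logits, labels)
    return total / len(embeddings)


torch.manual_seed(0)
views = [F.normalize(torch.randn(8, 16), dim=1) for _ in range(3)]
loss = symile_style_loss(views, logit_scale=1.0 / 0.07)
print(loss.item())
```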

@karinazad karinazad requested a review from ncfrey June 17, 2025 02:45
@ncfrey ncfrey requested a review from Copilot June 17, 2025 12:56
Contributor

Copilot AI left a comment


Pull Request Overview

This PR adds multi-modal contrastive learning functionality using Symile loss to Ume, along with supporting transforms and a new streaming dataset. Key changes include:

  • Integration of a new Symile loss function in the Ume model and updated contrastive loss scaling.
  • Addition of UmeStreamingDataset, leveraging litdata for multi-modal tokenization and data loading.
  • Renaming and updating several transform classes (e.g. AminoAcidToNucleotidePairTransform, AminoAcidToSmilesPairTransform) to reflect modality-specific processing.

Reviewed Changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 1 comment.

File — Description
  • tests/lobster/transforms/test_modality_aware_transform.py — Added tests for modality-aware transforms.
  • tests/lobster/transforms/test__equivalence_transforms.py — Updated transform test cases and renamed tests to reflect amino-acid-based transforms.
  • src/lobster/transforms/functional/_convert_seqs.py — Updated probabilistic conversion with an additional skip_unknown parameter.
  • src/lobster/transforms/_modality_aware_transform.py — Introduced modality-aware transform wrappers and composition.
  • src/lobster/transforms/_equivalence_transforms.py — Renamed and updated equivalence transforms to amino-acid-centric versions with new parameters.
  • src/lobster/model/_ume.py — Integrated Symile loss, updated contrastive loss scaling, and enhanced batch splitting for multi-view inputs.
  • src/lobster/model/_symile_loss.py — Added a new Symile loss implementation supporting two negative sampling strategies.
  • src/lobster/hydra_config/trainer.yaml — Updated trainer configuration with new dependencies and settings.
  • src/lobster/datasets/_ume_streaming_dataset.py — Added a new streaming dataset class supporting modality-specific tokenization via litdata.
  • pyproject.toml — Added required dependencies (litdata and polars) for the new functionality.
Comments suppressed due to low confidence (2)

src/lobster/transforms/_equivalence_transforms.py:340

  • Consider updating the init docstring for AminoAcidToNucleotidePairTransform to document the new 'skip_unknown' parameter.
        skip_unknown: bool = False,

src/lobster/model/_ume.py:566

  • Confirm that switching the scaling from division to multiplication with self.contrastive_temperature is intentional and consistent with the logit_scale initialization in SymileLoss.
        similarities = embeddings_a @ embeddings_b.T * self.contrastive_temperature

"B007", # unused-loop-control-variable
"E741", # ambiguous-variable-name
"E902", # file not found error
"UP038", # Use X | Y in isinstance call instead of (X, Y)
Collaborator Author

@karinazad karinazad Jun 17, 2025


ruff complains about the use of isinstance(item, (str, float)), which is the standard form here

self.contrastive_temperature = contrastive_temperature

# Initialize SymileLoss with the correct logit scale
self.symile_loss_fn = SymileLoss(logit_scale=1.0 / contrastive_temperature)
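One way to sanity-check the scaling question raised above: dividing logits by a temperature is the same as multiplying by its reciprocal (the logit scale), so multiplying by the temperature attribute itself is only equivalent if that attribute already stores the reciprocal. A quick sketch of the identity (variable names here are illustrative, not Ume's actual attributes):

```python
import torch

torch.manual_seed(0)
contrastive_temperature = 0.07
# Mirrors SymileLoss(logit_scale=1.0 / contrastive_temperature) above
logit_scale = 1.0 / contrastive_temperature

a = torch.randn(4, 8)
b = torch.randn(4, 8)

divided = a @ b.T / contrastive_temperature
scaled = a @ b.T * logit_scale

print(torch.allclose(divided, scaled))  # True
```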
Contributor


Thinking about whether we should make the loss configurable. Probably fine for now, since we're pretty set on losses for this version.

Collaborator Author


I was planning to refactor the Ume model class since it's a bit bloated right now, but didn't want to include that in this PR since there is already a lot of code change.

@karinazad karinazad merged commit 7a844d7 into main Jun 17, 2025
5 checks passed
@karinazad karinazad deleted the k/symile-loss-v2 branch June 17, 2025 15:04