Add Symile multi-modal contrastive loss to Ume #109
Pull Request Overview
This PR adds multi-modal contrastive learning functionality using Symile loss to Ume, along with supporting transforms and a new streaming dataset. Key changes include:
- Integration of a new Symile loss function in the Ume model and updating contrastive loss scaling.
- Addition of UmeStreamingDataset leveraging litdata for multi-modal tokenization and data loading.
- Renaming and updating several transform classes (e.g. AminoAcidToNucleotidePairTransform, AminoAcidToSmilesPairTransform) to reflect modality-specific processing.
Reviewed Changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| tests/lobster/transforms/test_modality_aware_transform.py | Added tests for modality-aware transforms. |
| tests/lobster/transforms/test__equivalence_transforms.py | Updated transform test cases and renamed tests to reflect amino-acid-based transforms. |
| src/lobster/transforms/functional/_convert_seqs.py | Updated probabilistic conversion with an additional skip_unknown parameter. |
| src/lobster/transforms/_modality_aware_transform.py | Introduced modality-aware transform wrappers and composition. |
| src/lobster/transforms/_equivalence_transforms.py | Renamed and updated equivalence transforms to amino acid centric versions with new parameters. |
| src/lobster/model/_ume.py | Integrated Symile loss, updated contrastive loss scaling, and enhanced batch splitting for multi-view inputs. |
| src/lobster/model/_symile_loss.py | Added a new Symile loss implementation supporting two negative sampling strategies. |
| src/lobster/hydra_config/trainer.yaml | Updated trainer configuration with new dependencies and settings. |
| src/lobster/datasets/_ume_streaming_dataset.py | Added a new streaming dataset class supporting modality-specific tokenization via litdata. |
| pyproject.toml | Added required dependencies (litdata and polars) for the new functionality. |
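As a rough illustration of the modality-aware transform idea summarized above — all names in this sketch are hypothetical, not the PR's actual API — a wrapper can tag a transform's output with its modality so downstream tokenizers can route each view to the right tokenizer:

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class ModalityAwareTransform:
    """Wrap a plain transform with input/output modality metadata (illustrative only)."""

    fn: Callable[[str], str]
    input_modality: str
    output_modality: str

    def __call__(self, x: str) -> tuple[str, str]:
        # Return the transformed value together with its output modality,
        # so a multi-modal dataset can pick the matching tokenizer.
        return self.fn(x), self.output_modality


# Hypothetical amino-acid -> SMILES conversion; the lambda is a placeholder,
# not a real chemical conversion.
to_smiles = ModalityAwareTransform(
    fn=lambda seq: f"SMILES({seq})",
    input_modality="amino_acid",
    output_modality="smiles",
)

value, modality = to_smiles("MKT")
```

A composed pipeline would then carry `(value, modality)` pairs through each stage rather than bare strings.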
Comments suppressed due to low confidence (2)
`src/lobster/transforms/_equivalence_transforms.py:340`
- Consider updating the init docstring for `AminoAcidToNucleotidePairTransform` to document the new `skip_unknown` parameter.

```python
skip_unknown: bool = False,
```
`src/lobster/model/_ume.py:566`
- Confirm that switching the scaling from division to multiplication with `self.contrastive_temperature` is intentional and consistent with the `logit_scale` initialization in `SymileLoss`.

```python
similarities = embeddings_a @ embeddings_b.T * self.contrastive_temperature
```
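For context on this comment: dividing similarities by a temperature τ is equivalent to multiplying by a logit scale of 1/τ (the convention used when constructing `SymileLoss` below), whereas multiplying by τ itself goes the other way unless τ = 1. A minimal sketch of the arithmetic, with illustrative values:

```python
import math

temperature = 0.07    # a typical CLIP-style contrastive temperature
raw_similarity = 0.5  # one cosine-similarity entry, for illustration

divided = raw_similarity / temperature                   # conventional scaling
via_logit_scale = raw_similarity * (1.0 / temperature)   # logit_scale = 1/τ convention
multiplied = raw_similarity * temperature                # the line flagged above

# Division and logit_scale multiplication agree; direct multiplication by τ does not.
print(math.isclose(divided, via_logit_scale))  # True
print(math.isclose(divided, multiplied))       # False (unless τ == 1)
```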
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
```toml
"B007",  # unused-loop-control-variable
"E741",  # ambiguous-variable-name
"E902",  # file not found error
"UP038", # Use X | Y in isinstance call instead of (X, Y)
```
ruff complains about the use of `isinstance(item, (str, float))`, which is the standard form
```python
self.contrastive_temperature = contrastive_temperature

# Initialize SymileLoss with the correct logit scale
self.symile_loss_fn = SymileLoss(logit_scale=1.0 / contrastive_temperature)
```
Thinking about whether we should make the loss configurable. Probably fine for now since we're pretty set on losses for this version.
I was planning on making a refactor of the Ume model class since it's a bit bloated right now but didn't want to include it in this MR since there is already a lot of code change
- `symile` package for multi-modal contrastive learning
- `litdata`, which supports transforms that return multiple modality views/representations

Notes
Credits to Omar Mahmood for suggesting this loss for contrastive learning with multiple modalities
Reference:
https://github.com/rajesh-lab/symile
https://arxiv.org/pdf/2411.01053
Here's how InfoNCE and Symile loss compare on 2 inputs:
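A minimal sketch of that comparison, assuming the formulation in the Symile paper (arXiv:2411.01053): Symile replaces the pairwise dot product with a multilinear inner product (MIP), which for two inputs reduces to the ordinary dot product — so on two modalities Symile coincides with InfoNCE — while three or more inputs are scored jointly by one product term per dimension. Function names here are illustrative, not the PR's implementation:

```python
import math


def mip(*vectors):
    """Multilinear inner product: sum over dimensions of the elementwise product.

    With two vectors this is exactly the dot product.
    """
    return sum(math.prod(components) for components in zip(*vectors))


def info_nce(anchor, candidates, positive_index=0, logit_scale=1.0):
    """Standard InfoNCE: cross-entropy of the positive among scaled dot-product logits."""
    logits = [logit_scale * mip(anchor, c) for c in candidates]
    log_denominator = math.log(sum(math.exp(logit) for logit in logits))
    return -(logits[positive_index] - log_denominator)


a = [1.0, 0.0]
b_positive = [0.9, 0.1]
b_negative = [0.0, 1.0]

# With two inputs the MIP is just a dot product: 1*0.9 + 0*0.1 = 0.9 ...
pairwise_score = mip(a, b_positive)

# ... while a third vector joins the same product term-by-term: 1*0.9*1 + 0*0.1*1 = 0.9.
c = [1.0, 1.0]
three_way_score = mip(a, b_positive, c)

loss = info_nce(a, [b_positive, b_negative])
```

The point of the comparison is that InfoNCE only ever sees one pair at a time, while the MIP lets a single score capture higher-order agreement across all modalities at once.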