UME-2 with auxiliary tasks#190

Merged
karinazad merged 15 commits into main from k/ume2-auxiliary-tasks
Sep 10, 2025
Conversation

@karinazad
Collaborator

@karinazad karinazad commented Sep 5, 2025

Description

  • Add UME-2 class for sequence-only encoders
  • Add support for auxiliary training tasks

Type of Change

  • Bug fix
  • New feature
  • Documentation update
  • Performance improvement
  • Code refactoring

import importlib


def ensure_package(
Collaborator

Great idea!

from ._calm_tasks import CALM_TASK_SPECIES, CALM_TASKS, CALMSpecies, CALMTask, MAX_SEQUENCE_LENGTH
from ._codon_table import CODON_TABLE_PATH, CODON_TABLE_PATH_VENDOR
from ._descriptor_descs import RDKIT_DESCRIPTOR_DISTRIBUTIONS
from ._rdkit_descriptor_distributions import RDKIT_DESCRIPTOR_DISTRIBUTIONS
Collaborator

Are these for normalization or are these directly predicted?

Collaborator Author

David added these for normalization

@@ -61,6 +60,7 @@ def __init__(
seed: int = 0,
cache_dir: str | None = None,
transform_fn: Callable | None = None,
Collaborator

Why not have just one transform_fns instead of both transform_fn and extra_transform_fns?

Collaborator Author

transform_fn would be applied to the sequence before it goes to tokenization (e.g. replace `|` with `.` in protein complexes), and extra_transform_fns are applied alongside the tokenized result to give something like

{input_ids: ..., attention_mask: ..., extra_output_1: ..., extra_output_2: ...}

maybe there is a better name for the parameter though
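The division of labor described above can be sketched as follows; the function name and exact merging behavior are assumptions for illustration, not the actual implementation:

```python
# Hypothetical sketch of how transform_fn and extra_transform_fns could compose
# (parameter names from the discussion above; not the real code).
def tokenize_with_extras(sequence, tokenizer, transform_fn=None, extra_transform_fns=None):
    if transform_fn is not None:
        # Pre-tokenization edit, e.g. replacing "|" with "." in protein complexes
        sequence = transform_fn(sequence)
    batch = tokenizer(sequence)  # -> {"input_ids": ..., "attention_mask": ...}
    for name, fn in (extra_transform_fns or {}).items():
        # Extra outputs are merged alongside the tokenized result
        batch[name] = fn(sequence)
    return batch
```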

num_warmup_steps: 10_000
weight_decay: 0.01
mask_token_id: 8
pad_token_id: 6
Collaborator

Why are the mask and pad ids 8 and 6, respectively?

Collaborator Author

it's a pretty arbitrary order, but that's what our tokenizers are using: https://github.com/prescient-design/lobster/blob/main/src/lobster/tokenization/_ume_tokenizers.py#L105


masked_embeddings = output["last_hidden_state"] * mask

sum_embeddings = masked_embeddings.sum(dim=1)
Collaborator

In a separate MR, we should probably update this with a pooling function that could be passed to the model. Something like making aggregate (bool) into aggregator (fn)
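The suggested `aggregator` function could be a minimal masked mean pool like the snippet below; this is a sketch under assumed tensor shapes, not the model's actual pooling code:

```python
import torch


def masked_mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over non-padding positions only.

    Hypothetical aggregator sketch: could replace aggregate (bool)
    with aggregator (fn) as suggested above.
    """
    mask = attention_mask.unsqueeze(-1).float()  # (batch, seq, 1)
    masked_embeddings = last_hidden_state * mask
    sum_embeddings = masked_embeddings.sum(dim=1)
    # Clamp avoids division by zero for fully padded rows
    return sum_embeddings / mask.sum(dim=1).clamp(min=1e-9)
```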

raise ValueError(f"Unsupported task type: {self.task_type}")


class AuxiliaryRegressionTaskHead(nn.Module):
Collaborator

I wonder if this should live elsewhere, like lobster.finetune (lobster.post_train?). I know we're still pre-training here, but having a dedicated place for pooling, regression heads, etc. that could be used for auxiliary task and post-training might be better organizationally

Collaborator Author

yeah I think that's a good idea, we can move it once finetune is ready
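For context, a regression head of this kind is typically a small MLP over pooled embeddings; the layer sizes and activation below are illustrative assumptions, not the PR's actual implementation:

```python
import torch
from torch import nn


class AuxiliaryRegressionTaskHead(nn.Module):
    """Illustrative sketch of a regression head over pooled encoder embeddings.

    Hypothetical layer choices; see the PR diff for the real class.
    """

    def __init__(self, hidden_size: int, output_size: int = 1):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.GELU()
        self.out_proj = nn.Linear(hidden_size, output_size)

    def forward(self, pooled_embeddings: torch.Tensor) -> torch.Tensor:
        # (batch, hidden) -> (batch, output_size)
        x = self.activation(self.dense(pooled_embeddings))
        return self.out_proj(x)
```

Housing heads like this in a shared module (e.g. the proposed lobster.finetune) would let pre-training auxiliary tasks and post-training reuse the same code.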

)
def test_smiles_to_rdkit_descs(mock_calc, smiles, expected):
mock_calc.return_value = {"desc1": 1.0, "desc2": 2.0}
def test_smiles_to_rdkit_descs(mock_smiles_to_desc, smiles, expected):
Collaborator

Are there tests for ume2 outputting both masked token & auxiliary task preds?

Collaborator Author

will add!
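The requested test could check that a forward pass returns both masked-token logits and auxiliary predictions; the sketch below uses a toy stand-in model (FakeUME2 is invented for illustration) rather than the real UME-2 class:

```python
import torch
from torch import nn


class FakeUME2(nn.Module):
    """Toy stand-in (hypothetical, not the real UME-2) that returns both
    masked-token logits and an auxiliary regression prediction."""

    def __init__(self, hidden: int = 4, vocab: int = 10):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab)
        self.aux_head = nn.Linear(hidden, 1)

    def forward(self, hidden_states: torch.Tensor) -> dict:
        return {
            "logits": self.lm_head(hidden_states),            # per-token MLM logits
            "aux_pred": self.aux_head(hidden_states.mean(dim=1)),  # pooled aux output
        }


def test_outputs_masked_and_auxiliary():
    model = FakeUME2()
    out = model(torch.randn(2, 5, 4))
    assert out["logits"].shape == (2, 5, 10)
    assert out["aux_pred"].shape == (2, 1)
```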

@karinazad karinazad merged commit 40ed9dd into main Sep 10, 2025
2 of 4 checks passed
@karinazad karinazad deleted the k/ume2-auxiliary-tasks branch September 10, 2025 14:23
