T/callback update#183

Merged
taylormjs merged 19 commits into main from t/callback-update
Aug 12, 2025

Conversation

taylormjs (Collaborator) commented Aug 8, 2025

Updates to the DGEB, CaLM, and MoleculeACE (public) callbacks:

  • Add a DGEB callback that is consistent with the other linear probe callbacks and works with Hydra configs
  • LinearProbeCallback now handles mean pooling correctly (it previously included pad tokens in the average)
  • The MoleculeACE callback now matches lobster_internal
  • The CaLM callback now uses LinearProbeCallback's embed method instead of its own
  • DGEB callback implemented for ESM

Major Features:

  • Add comprehensive DGEBEvaluationCallback for UME and ESM models
  • Implement ESMAdapterDGEB for direct ESM model evaluation without checkpoints
  • Add shared pooling utilities for consistent embedding aggregation
  • Enhance MoleculeACE linear probe with better model compatibility

Core Components:

  • DGEBEvaluationCallback: Unified callback supporting both UME (checkpoint-based) and ESM (direct) evaluation workflows
  • ESMAdapterDGEB: DGEB-compatible adapter for ESM models with proper masked pooling
  • Shared pooling utilities: mean/max/cls/last pooling with attention masking
  • Enhanced error handling and graceful task failure recovery
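
The shared pooling utilities described above can be sketched roughly as follows. This is a hypothetical illustration (the function name `apply_pooling` and its exact signature are assumptions, not the PR's actual code); the key point is that every strategy respects the attention mask, so pad tokens never leak into the pooled embedding:

```python
import torch


def apply_pooling(token_embeddings, attention_mask, pool_type="mean"):
    """Masked pooling sketch: token_embeddings is (batch, seq, hidden),
    attention_mask is (batch, seq) with 1 for real tokens and 0 for padding."""
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)  # (B, L, 1)
    if pool_type == "mean":
        # Sum only non-pad positions, then divide by the true sequence lengths
        summed = (token_embeddings * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1)
        return summed / counts
    if pool_type == "max":
        # Push pad positions to -inf so they can never win the max
        masked = token_embeddings.masked_fill(mask == 0, float("-inf"))
        return masked.max(dim=1).values
    if pool_type == "cls":
        return token_embeddings[:, 0, :]  # first token
    if pool_type == "last":
        # Index of the last real (non-pad) token per sequence
        lengths = attention_mask.sum(dim=1) - 1
        batch_idx = torch.arange(token_embeddings.size(0))
        return token_embeddings[batch_idx, lengths, :]
    raise ValueError(f"Unsupported pool_type: {pool_type}")
```

Masked mean pooling is the fix called out in the summary: without the mask, pad embeddings are averaged in and short sequences get systematically diluted.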

Improvements:

  • Better embedding extraction across different model types
  • Improved linear probe callbacks with enhanced input processing
  • Updated DGEB runners with better error handling and reporting
  • Comprehensive test coverage for new ESM adapter functionality

Type of Change

  • Bug fix
  • New feature
  • Documentation update
  • Performance improvement
  • Code refactoring

Taylor Joren added 12 commits August 4, 2025 20:38
@taylormjs taylormjs marked this pull request as ready for review August 11, 2025 21:46
@taylormjs taylormjs requested a review from ncfrey August 11, 2025 21:46
import lightning as L
import numpy as np
import torch
from lobster.transforms import Transform
Contributor

we made these changes in #166 for consistency

def apply_dgeb_pooling(
token_embeddings: torch.Tensor,
attention_mask: torch.Tensor,
pool_type: Literal["mean", "max", "cls", "last"] = "mean",
Contributor

possibly define an Enum for pooling types across the library?
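
A minimal sketch of what the suggested shared Enum could look like (the name `PoolingType` is hypothetical, not anything in the PR). Using a `str` mixin keeps backward compatibility with the existing `"mean"`-style string values in configs:

```python
from enum import Enum


class PoolingType(str, Enum):
    """Hypothetical library-wide pooling type; str mixin keeps string configs working."""
    MEAN = "mean"
    MAX = "max"
    CLS = "cls"
    LAST = "last"


# A plain config string round-trips cleanly, and comparisons against
# existing string literals still hold.
assert PoolingType("mean") is PoolingType.MEAN
assert PoolingType.MEAN == "mean"
```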

pooled = torch.stack([layer_hidden[i, l, :] for i, l in enumerate(lengths)], dim=0)
else:
raise ValueError(f"Unsupported pool_type: {self.pool_type}")
pooled = apply_dgeb_pooling(layer_hidden, attention_mask, self.pool_type)
Contributor

nice! i like having the pooling logic self-contained

Collaborator Author

Yeah, it makes it easy to transfer to esm_dgeb_adapter, for example

}

# Extract key metrics from results
# Extract key metrics from results with error handling for individual tasks
Contributor

good call - are you seeing failed tasks frequently?

Collaborator Author

Not frequently, but often enough that I wanted a report of which tasks failed

def __init__(
self,
module: L.LightningModule,
modality: Literal["protein", "dna"] = "protein",
Contributor

can we harmonize this with UME's modality types?

"""

# Create a minimal tokenizer object with required attributes
class MinimalTokenizer:
Contributor

maybe DummyTokenizer then since this isn't used?

Collaborator Author

Yeah, good point. I had it as DummyTokenizer before but changed it because of Monday "vibes"
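
The stub tokenizer under discussion could look roughly like this (a hypothetical sketch; the exact attributes DGEB's adapter interface expects are assumptions). The point of the `DummyTokenizer` name is that the object only satisfies attribute checks and is never actually invoked:

```python
class DummyTokenizer:
    """Placeholder tokenizer: carries the attributes an adapter interface
    might check for, but fails loudly if anyone tries to tokenize with it."""

    def __init__(self, pad_token_id=0, model_max_length=1024):
        self.pad_token_id = pad_token_id
        self.model_max_length = model_max_length

    def __call__(self, *args, **kwargs):
        raise NotImplementedError("DummyTokenizer is a stub and should never be called")
```

Failing loudly on `__call__` makes the "this isn't used" assumption self-enforcing: if a code path ever does tokenize through it, the error surfaces immediately instead of silently producing garbage.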

@taylormjs taylormjs merged commit 8dbe3f2 into main Aug 12, 2025
4 checks passed
@taylormjs taylormjs deleted the t/callback-update branch August 12, 2025 18:54