Skip to content

dataloader checkpoint callback#60

Merged
karinazad merged 5 commits intomainfrom
dataloader-callback
Apr 11, 2025
Merged

dataloader checkpoint callback#60
karinazad merged 5 commits intomainfrom
dataloader-callback

Conversation

@karinazad
Copy link
Collaborator

No description provided.

if self._is_s3_uri:
with tempfile.NamedTemporaryFile() as tmp_file:
temp_path = tmp_file.name
torch.save(dataloader.state_dict(), temp_path)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, about how large are the dataloader state dicts?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have an example on hand but I think they are pretty small since it just stores the item indices and some metadata about input_dir etc

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


# Instantiate the model
if ckpt_path is not None:
self.model = FlexBERT.load_from_checkpoint(ckpt_path)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. So by default, always use Ume.load_from_checkpoint rather than specify a ckpt_path in model instantiation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes exactly, ckpt_path in the model parameters is only needed because of this line: https://github.com/prescient-design/lobster/blob/main/src/lobster/cmdline/_train.py#L54

Copy link
Collaborator

@taylormjs taylormjs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!

@karinazad karinazad merged commit de5caba into main Apr 11, 2025
5 checks passed
@karinazad karinazad deleted the dataloader-callback branch April 11, 2025 23:02
taylormjs pushed a commit that referenced this pull request Apr 29, 2025
* dataloader callback

* utils

* ume

* gitignore dev

* tests
karinazad added a commit that referenced this pull request May 14, 2025
* peer fixes, add evaluate method

* dataloader checkpoint callback (#60)

* dataloader callback

* utils

* ume

* gitignore dev

* tests

* update flash attention wheels (#61)

* lock

* torch 2.5

* torch 2.5

* part

* .env

* unpin flash attn (#62)

* fix scheduler params (#64)

* scheduler

* fix scheduler

* fix scheduler

* Add AtomicaDataset (#63)

Processed Atomica interactions dataset

* Ume conversion/interaction tokenizer + fix SMILES and nucleotide tokenizers (#65)

add two special tokens: <convert> and <interact> for later stages of Ume training:
will be used as this: (or something like that)
[CLS]  PROT_SEQ  [SEP] <convert> PROT_STRUCT(masked)  [SEP]
[CLS]  PROT_SEQ  [SEP] <interact> SMILES(masked)  [SEP] 
extend functionality of UmeTokenizerTransform to handle dual modalities
change the name of Ume embedding method and allow embedding from existing input_ids
fix existing tokenizers:

add lowercase normalized to nucleotide tokenizer (OG2 dataset contains a mix of upper and lowercase letters)
BPE handled SMILES tokenization incorrectly, switch to WordLevel

* Ume SMILES tokenizer fix (#66)

* tokenizer

* fix tests

* lowercase normalizer for nt

* tests

* remove mod conv dataset

* embed

* Test

* merge 2mod into UmeTokenizerTransform

* fix tests

* all

* type hints

* docstrings

* tests

* fix SMILES tokenizer

* switch all tokenizer to BPE

* Revert "switch all tokenizer to BPE"

This reverts commit 367e77d.

* tok

* fix SMILES tokenizer

* remove print statement

* Ume perplexity logging (#67)

* pplx

* tests

* src

* ignore torchmetrics warnings

* docstrings

* docstrings

* Update README.md (#69)

* Ume fix perplexity device (#68)

* pplx as attr

* pplx as attr

* pplx

* comments

* on step

* comment

* update tests, fix ruff

* ruff

* ruff ruff

* Add <cls_modality> to Ume tokenizers (#71)

* add <cls_modality> tokens

* add <cls_modality> tokens

* docstring

* RNS metric implementation  (#73)

* add <cls_modality> tokens

* add <cls_modality> tokens

* modality embeddings

* module dict

* embeddings

* tests

* modality and device

* rank zero only

* rank zero

* fix back modality mask

* sync dist

* RNS implementation

* restore from main

* restore

* docstrings

* docstrings

* review

* test

* Ume modality-specific embeddings (#72)

* add <cls_modality> tokens

* add <cls_modality> tokens

* modality embeddings

* module dict

* embeddings

* tests

* modality and device

* rank zero only

* rank zero

* fix back modality mask

* sync dist

* add conversion transforms (#74)

* add initial smiles to peptide and peptide to smiles transforms

* remove smiles -> * transforms and touch up conversion functions

* rename

* add option to randomize smiles and caps

---------

Co-authored-by: Colin Grambow <grambowc@gene.com>

* fix def pad token, replace process_and_embed w/ ume.embed

* update tests w -100 pad token

---------

Co-authored-by: Taylor Joren <joren.taylor@gene.com>
Co-authored-by: Karina Zadorozhny <karina.zadorozhny@gmail.com>
Co-authored-by: Nathan Frey <ncfrey@users.noreply.github.com>
Co-authored-by: Colin Grambow <17198155+cgrambow@users.noreply.github.com>
Co-authored-by: Colin Grambow <grambowc@gene.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants