```
python 3.8
torch >= 1.9.0
wandb
numpy
easydict
```
To use distributed training in `parallel/trainer.py`, the latest `accelerate` should be installed from source:
```shell
git clone https://github.com/huggingface/accelerate
cd accelerate
python setup.py install
```
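Once installed, `accelerate` also provides its own launcher CLI. A minimal sketch is shown below; the trainer itself takes additional arguments (see `parallel/trainer.py`), which are omitted here:

```shell
# One-time interactive setup: number of GPUs, mixed precision, etc.
accelerate config

# Launch the distributed trainer under the saved configuration.
accelerate launch ./parallel/trainer.py
```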
Prediction:
```shell
python ./scripts/predict.py \
    --model <model_ckpt_path> \
    --wt_pdb <path-to-wild-type-pdb> \
    --mut_pdb <path-to-mutant-pdb>
```
Test:
```shell
python ./script/test.py \
    --model <model_ckpt_path> \
    --neu_path <neutralization_ground-truth_path> \
    --mut_pdb_dir <mut_pdbs_directory> \
    --wt_pdb <wt_pdb_path>
```
End-to-end prediction:
```shell
python ./script/e2e.py \
    --model <model_ckpt_path> \
    --evoef2_path <evoef2_installation_path> \
    --wt_pdb <wt_pdb_path> \
    --mut_tags <mutation_string> \
    --clean_work_path <empty_working_directory>
```
Single GPU:
```shell
python ./script/train.py \
    --model ./data/model.pt \
    --save_ckpt_dir <model_ckpt_dir> \
    --input_data <serialized_training_data>
```
Distributed:
```shell
python -m torch.distributed.launch --nproc_per_node <num_gpu_to_use> --use_env --master_port 20654 ./script/train.py \
    --model ./data/model.pt \
    --save_ckpt_dir <model_ckpt_dir> \
    --input_data <serialized_training_data>
```
The argument `--input_data` expects a binary file serialized by pickle, containing a Python list of triples
`(data_wt, data_mut, ddG)`

where `ddG` is a float, and `data_wt`/`data_mut` are the return values of `utils.protein.parse_pdb`, each a dictionary of the form
```python
{
    'name': structure.get_id(),
    # Chain info
    'chain_id': ''.join(chain_id),             # concatenated chain ids, one character per chain
    'chain_seq': torch.LongTensor(chain_seq),  # chain id of each residue, (L,)
    # Sequence
    'aa': torch.LongTensor(aa),                # residue type id, (L,)
    'aa_seq': ''.join(aa_seq),                 # residue type sequence
    'resseq': torch.LongTensor(resseq),        # PDB residue sequence numbers, (L,)
    'icode': ''.join(icode),                   # PDB insertion codes
    'seq': torch.LongTensor(seq),              # (L,)
    # Atom positions
    'pos14': torch.stack(pos14),               # all-atom coordinates per residue, (L, 14, 3)
    'pos14_mask': torch.stack(pos14_mask),     # mask flags for empty atoms, (L, 14)
    # Physicochemical properties
    'phys': torch.stack(phys),                 # numerical property values, (L, 2)
    'crg': torch.LongTensor(crg),              # residue side-chain charge, (L,)
}   # L is the total number of residues in the protein
```
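The list-of-triples format can be produced with `pickle` directly. A minimal sketch follows, using toy placeholder dictionaries; in real use, `data_wt`/`data_mut` come from `utils.protein.parse_pdb` and contain the torch tensors described above:

```python
import pickle

def toy_structure(name):
    # Hypothetical stand-in for a utils.protein.parse_pdb return value;
    # only a few illustrative fields are shown.
    return {
        'name': name,
        'chain_id': 'A',
        'aa_seq': 'GAV',
        # ... remaining fields from the dictionary above
    }

# A training set is a Python list of (data_wt, data_mut, ddG) triples.
dataset = [
    (toy_structure('1abc_wt'), toy_structure('1abc_mut'), -0.73),
    (toy_structure('2xyz_wt'), toy_structure('2xyz_mut'), 1.25),
]

# Serialize in the form --input_data expects: one pickled list.
with open('train_data.pkl', 'wb') as f:
    pickle.dump(dataset, f)

# Sanity check: load the file back.
with open('train_data.pkl', 'rb') as f:
    loaded = pickle.load(f)
```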
| name | type | description |
|---|---|---|
| res_encoder | String | 'mlp' for an MLP atom encoder, 'egnn' for an EGNN atom encoder |
| mode | String | 'reg' for MSE loss, 'cla' for cross-entropy loss, 'gau' for Gaussian loss |
| k | Int | number of neighboring residues around the mutation site to use |
| num_egnn_layers | Int | number of EGNN layers (used when res_encoder is 'egnn') |
| ckpt_freq | Int | interval, in epochs, between model checkpoints |
All parameters and their default values are listed in `script/train.py`.
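The hyperparameters above can be set on the training command line. The values below are illustrative, not recommended settings, and `./train_data.pkl` is a hypothetical path to a pickled training list:

```shell
python ./script/train.py \
    --model ./data/model.pt \
    --save_ckpt_dir ./checkpoints \
    --input_data ./train_data.pkl \
    --res_encoder egnn \
    --mode reg \
    --k 16 \
    --num_egnn_layers 3 \
    --ckpt_freq 10
```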
```bibtex
@inproceedings{zhao2023geometric,
  title={Geometric Graph Learning for Protein Mutation Effect Prediction},
  author={Zhao, Kangfei and Rong, Yu and Jiang, Biaobin and Tang, Jianheng and Zhang, Hengtong and Yu, Jeffrey Xu and Zhao, Peilin},
  booktitle={Proceedings of the 32nd ACM International Conference on Information and Knowledge Management},
  pages={3412--3422},
  year={2023}
}
```