Identifying Subgroup-Specific Genetic Modules in Gene-Gene Correlation Networks Using Graph Neural Networks
This project leverages Graph Neural Networks (GNNs) to identify subgroup-specific clusters within gene-gene correlation networks using a supervised learning approach. Each graph in the dataset contains exactly one subgroup-specific cluster, and our objective is to classify nodes as either belonging (1) or not belonging (0) to this cluster.
For more detailed theoretical background, motivation and results, please refer to the Project Report.
- Graph Simulation & Preprocessing: Generate synthetic correlation networks under multiple conditions, then initialize and store edge features for downstream processing.
- GNN Architecture: A custom neural network built with GATv2Conv layers (or similar) that integrates edge attributes, employs attention mechanisms, and uses GraphNorm for stable training, followed by an MLP for per-node classification.
- Training & Hyperparameter Tuning: Comprehensive training routines for the model, alongside automated hyperparameter search via Optuna. All experiments track metrics and logs in real time using Weights & Biases (wandb).
- Inference Pipeline: A convenient workflow that converts raw NetworkX graphs into annotated graphs with node-level predictions. This allows easy integration of the trained model into existing or newly generated datasets.
- Experiment Tracking: Rapid iteration and visualization through Weights & Biases (wandb), enabling insights into model performance, architecture decisions, and hyperparameter choices.
Identifying-Subgroup-Specific-Genetic-Modules-Using-GNN
├── data/
│   ├── graphs/                       # Raw generated graphs
│   └── modified_graphs/              # Graphs with initialized edge features
├── data_pipeline/
│   ├── __init__.py                   # Module initializer for the data pipeline
│   ├── graph_generator.py            # Graph simulation and generation
│   ├── init_edge_features.py         # Edge feature initialization
│   └── simulation.py                 # Graph simulation framework
├── experiments/
│   ├── optuna_studies/
│   │   ├── example_study.db          # Example database for Optuna studies
│   │   └── study.db                  # Optuna study results, created during hyperparameter tuning
│   ├── trained_model.pt              # Default trained model weights
│   ├── current_trial_best.yaml       # Best trial configuration from the currently running study
│   ├── default_config.yaml           # Default hyperparameter configuration
│   └── trial_best.yaml               # Best trial configuration after hyperparameter tuning
├── models/
│   ├── __init__.py                   # Module initializer for models
│   ├── architecture.py               # GNN architecture and weight initialization
│   ├── hyperparameter_tuning.py      # Hyperparameter tuning routines (Optuna objective function)
│   └── training.py                   # Training routines for the model
├── notebooks/
│   └── Project.ipynb                 # Jupyter notebook for project exploration (not pushed to main repository)
├── scripts/
│   ├── __init__.py                   # Module initializer for scripts
│   ├── inference.py                  # Script for inference on graphs
│   ├── run_pipeline.sh               # Shell script to run the entire data pipeline with default values
│   ├── train_model.py                # Script to train the model using default or custom configuration
│   └── tune_hyperparams.py           # Script to tune model hyperparameters
├── tests/
│   ├── edge_init_test.py             # Test for edge initialization
│   ├── edge_init_test_load_parse.py  # Test for loading and parsing edge initialization
│   ├── graph_generator_test.py       # Tests for graph generator
│   ├── graph_generator_test2.py      # Additional tests for graph generator
│   └── test_graph_generator_cli.py   # CLI tests for graph generator
├── utils/
│   ├── __init__.py                   # Module initializer for utilities
│   ├── logging.py                    # Logging utilities for W&B integration
│   └── preprocessing.py              # Preprocessing utilities
├── .gitignore                        # Git ignore file
├── LICENSE                           # License information
├── Final Project Report.pdf          # Final project report, detailing the project's theoretical background and results
├── README.md                         # Project documentation
└── requirements.txt                  # Python dependencies
- Clone the Repository:
  git clone https://github.com/John-Isr/Identifying-Subgroup-Specific-Genetic-Modules-Using-GNN.git
  cd Identifying-Subgroup-Specific-Genetic-Modules-Using-GNN
- Create and Activate a Virtual Environment:
  python -m venv venv
  source venv/bin/activate   # On Windows: venv\Scripts\activate
- Install Dependencies:
  pip install -r requirements.txt
- Set Up Weights & Biases (wandb):
  wandb login
  (Obtain your API key by creating a free account at wandb.ai.)
Generate raw graphs and initialize their edge features:
./scripts/run_pipeline.sh

Note: If you want to customize the graphs you're generating, you can use the following commands instead:

python -m data_pipeline.graph_generator --conditions Optimal Suboptimal Default --output_dir data/graphs --num_graphs 1000
python -m data_pipeline.init_edge_features --input_dir data/graphs --output_dir data/modified_graphs

Train your GNN model using the training script and configuration file:

python -m scripts.train_model --config experiments/default_config.yaml --data_dir data/modified_graphs --epochs 250

Optimize the model's hyperparameters with Optuna:

python -m scripts.tune_hyperparams --n_trials 150 --study_name default --optuna_storage_path ./experiments/optuna_studies/study.db --min_resource 45 --max_resource 250 --reduction_factor 2

Import the inference script and use it to process a NetworkX graph and obtain node predictions, for example:
import networkx as nx
from scripts import run_inference
# Create or load your NetworkX graph
nx_graph = nx.erdos_renyi_graph(100, 0.15)
# Ensure each node has the features expected by your model
for node in nx_graph.nodes():
    nx_graph.nodes[node]['feature'] = [0.5]  # Example feature
# Run inference to add the 'classification' attribute to nodes
result_graph = run_inference(
    nx_graph=nx_graph,
    model_path="experiments/trained_model.pt",
    device='auto'
)
# Retrieve predictions
predictions = nx.get_node_attributes(result_graph, 'classification')
print(predictions)

Real-time experiment tracking and visualization are handled via Weights & Biases (wandb). Logging in to wandb is required to run the training and hyperparameter tuning scripts.
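For intuition about the dynamic attention mechanism the architecture relies on, here is a minimal NumPy sketch of GATv2-style attention over a single node's neighborhood (following Brody et al., cited below). The weight matrix `W`, attention vector `a`, feature sizes, and LeakyReLU slope are illustrative placeholders, not the project's actual parameters or code.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def gatv2_attention(h_i, neighbors, W, a, slope=0.2):
    """GATv2 scoring: e_ij = a . LeakyReLU(W [h_i || h_j]), softmaxed over neighbors."""
    # Nonlinearity is applied BEFORE the dot product with a -- the key GATv2 change
    scores = np.array([a @ leaky_relu(W @ np.concatenate([h_i, h_j]), slope)
                       for h_j in neighbors])
    scores -= scores.max()                       # numerical stability for softmax
    alpha = np.exp(scores) / np.exp(scores).sum()
    # Aggregated message: attention-weighted sum of transformed neighbor pairs
    Wh = np.stack([W @ np.concatenate([h_i, h_j]) for h_j in neighbors])
    return alpha, (alpha[:, None] * Wh).sum(axis=0)

rng = np.random.default_rng(0)
d_in, d_out = 4, 8
W = rng.normal(size=(d_out, 2 * d_in))           # shared linear transform
a = rng.normal(size=d_out)                       # attention vector
h_i = rng.normal(size=d_in)                      # center node features
neighbors = [rng.normal(size=d_in) for _ in range(3)]
alpha, msg = gatv2_attention(h_i, neighbors, W, a)
```

Real GATv2Conv layers add multiple attention heads, biases, and optional edge features; this strips the mechanism down to its scoring-and-softmax core.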
- Shiran Gerassy-Vainberg & Shai S. Shen-Orr (2024), A Personalized Network Framework Reveals Predictive Axis of Anti-TNF Response Across Diseases.
- Chung (1997), Spectral Graph Theory
- Gori et al. (2005), A new model for learning in graph domains.
- One of the first formulations of Graph Neural Networks (GNNs).
- Scarselli et al. (2008), The Graph Neural Network Model.
- Introduced a formal recursive framework for learning on graphs.
- Duvenaud et al. (2015), Convolutional Networks on Graphs for Learning Molecular Fingerprints.
- One of the first graph convolutional approaches, applied to molecular data.
- Vaswani et al. (2017), Attention Is All You Need.
- Introduced Transformers, which inspired attention-based mechanisms in GNNs like GAT and GATv2.
- Veličković et al. (2018), Graph Attention Networks.
- Proposed the original Graph Attention Network (GAT), using attention mechanisms for adaptive neighborhood aggregation.
- Brody, Alon & Yahav (2022), How Attentive are Graph Attention Networks?
- Introduced GATv2, an improved version of GAT with dynamic attention mechanisms.
- Wu et al. (2019), A Comprehensive Survey on Graph Neural Networks.
- Summarizes the evolution of GNN architectures, including GCN, GraphSAGE, and GAT.
- Yaniv Slor Futterman (yaniv.slor@campus.technion.ac.il)
- Jonathan Israel (jonathani@campus.technion.ac.il)
This project is licensed under the Apache 2.0 License; see the LICENSE file for details.