GNODEVAE is a computational framework that integrates Graph Attention Networks (GAT), Neural Ordinary Differential Equations (NODE), and Variational Autoencoders (VAE). It targets three common challenges in single-cell RNA sequencing data analysis:
- Capturing complex topological relationships between cells
- Modeling continuous dynamic processes of cell differentiation
- Handling high levels of technical noise and biological variation
The integration is intended to support identification of cell subpopulations, reconstruction of developmental trajectories, and analysis of cellular heterogeneity.
The GAT attention mechanism adaptively weights gene expression profiles, prioritizing biological relationships while reducing technical noise in heterogeneous cell populations.
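To make the attention idea concrete, here is a minimal single-head sketch of GAT-style neighbor weighting in NumPy. It is not taken from the GNODEVAE code base: the function names, the random weight matrix `W`, the attention vector `a`, and the toy neighbor list are all assumptions chosen for illustration. Each neighbor is scored with a LeakyReLU over the concatenated projected features, and the scores are normalized with a softmax so they sum to one.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    # LeakyReLU keeps a small gradient for negative inputs
    return np.where(x > 0, x, slope * x)

def gat_attention(h, neighbors, W, a, i):
    """Attention weights of cell i over its neighbors (single head).

    h: (n_cells, in_dim) features, W: (in_dim, out_dim) projection,
    a: (2 * out_dim,) attention vector. All values here are illustrative.
    """
    z = h @ W  # project every cell into the attention space
    scores = np.array(
        [leaky_relu(a @ np.concatenate([z[i], z[j]])) for j in neighbors]
    )
    scores = scores - scores.max()  # subtract the max for numerical stability
    weights = np.exp(scores)
    return weights / weights.sum()  # softmax over the neighborhood

rng = np.random.default_rng(0)
h = rng.normal(size=(5, 4))   # 5 hypothetical cells, 4 features each
W = rng.normal(size=(4, 3))
a = rng.normal(size=6)
alpha = gat_attention(h, neighbors=[1, 2, 4], W=W, a=a, i=0)
print(alpha, alpha.sum())
```

In a trained model `W` and `a` are learned, so neighbors that help reconstruction receive larger weights.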
Neural ordinary differential equations transform static representations into dynamic systems, with time variables providing a continuous parameterization of developmental processes.
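The idea of treating a latent state as a dynamic system can be sketched without any deep learning library. The example below integrates dz/dt = f(z, t) with a fixed-step Euler scheme. It is only an illustration: GNODEVAE lists `torchdiffeq` as a dependency, which provides proper differentiable solvers, and the derivative function here is a hand-written exponential decay standing in for a learned network. The names `euler_integrate` and `decay` are hypothetical.

```python
import math

def euler_integrate(f, z0, t0, t1, n_steps=1000):
    """Integrate dz/dt = f(z, t) from t0 to t1 with fixed-step Euler."""
    z = list(z0)
    t = t0
    dt = (t1 - t0) / n_steps
    for _ in range(n_steps):
        dz = f(z, t)
        z = [zi + dt * dzi for zi, dzi in zip(z, dz)]
        t += dt
    return z

def decay(z, t):
    # Stand-in for a learned derivative network: dz/dt = -z
    return [-zi for zi in z]

# Evolve a 2-dimensional latent state over one unit of (pseudo-)time
z1 = euler_integrate(decay, [1.0, 2.0], t0=0.0, t1=1.0)
print(z1, [math.exp(-1.0), 2 * math.exp(-1.0)])
```

For this toy system the exact solution is z0 * exp(-t), so the printed pairs should agree closely.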
The model's latent space is designed to capture variation in cell differentiation rates, while attention weights can be inspected alongside known developmental relationships between cell types.
The accompanying study compares GNODEVAE with several single-cell analysis methods, including scVI, DIP-VAE, TC-VAE, beta-VAE, Info-VAE, and scTour, across 13 datasets.
The accompanying study also reports stronger Calinski-Harabasz scores than selected baselines for gene trend analysis.
- Python 3.8 or higher
- PyTorch 1.12 or higher (with CUDA support recommended for GPU acceleration)
- PyTorch Geometric
# Clone the repository
git clone https://github.com/PeterPonyu/GNODEVAE.git
cd GNODEVAE
# Install dependencies
pip install torch torch-geometric scanpy anndata numpy pandas scikit-learn tqdm psutil torchdiffeq
The main dependencies include:
- torch: PyTorch deep learning framework
- torch-geometric: Geometric deep learning extension for PyTorch
- scanpy: Single-cell analysis toolkit
- anndata: Annotated data structures for single-cell data
- torchdiffeq: Differentiable ODE solvers for PyTorch
- numpy, pandas: Data manipulation
- scikit-learn: Machine learning utilities
- tqdm: Progress bars
- psutil: System resource monitoring
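After installation, a quick way to confirm that the dependencies are importable is to look each one up without actually importing it. The helper below is a small sketch using only the standard library. The name `missing_packages` is hypothetical, and the mapping pairs each pip name with its import name (for example `scikit-learn` imports as `sklearn`).

```python
from importlib.util import find_spec

# pip package name -> top-level import name
REQUIRED = {
    "torch": "torch",
    "torch-geometric": "torch_geometric",
    "scanpy": "scanpy",
    "anndata": "anndata",
    "numpy": "numpy",
    "pandas": "pandas",
    "scikit-learn": "sklearn",
    "tqdm": "tqdm",
    "psutil": "psutil",
    "torchdiffeq": "torchdiffeq",
}

def missing_packages(required=REQUIRED):
    """Return the pip names of packages that cannot be found."""
    return sorted(pip for pip, mod in required.items() if find_spec(mod) is None)

missing = missing_packages()
print("Missing:", ", ".join(missing) if missing else "none")
```

An empty result means every listed package can be located in the current environment. It does not check versions or CUDA availability.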
import scanpy as sc
from GNODEVAE import agent_r # GraphVAE with refined architecture
# OR
from GNODEVAE import agent # Standard GraphVAE
# For full GNODEVAE with ODE support, use:
# from GNODEVAE.GODEVAE_agent import GNODEVAE_agent_r
# Load your single-cell data
adata = sc.read_h5ad('your_data.h5ad')
# Initialize the GNODEVAE agent
model = agent_r(
adata=adata,
layer='counts', # Layer containing count data
n_var=2000, # Number of highly variable genes
tech='PCA', # Dimensionality reduction technique
n_neighbors=15, # Number of neighbors for graph construction
latent_dim=10, # Latent space dimension
hidden_dim=128, # Hidden layer dimension
encoder_type='graph', # Use graph encoder
graph_type='GAT', # Graph Attention Network
lr=1e-4, # Learning rate
device='cuda' # Use GPU if available
)
# Train the model
model.fit(epochs=300, update_steps=10, silent=False)
# Extract latent representations
latent = model.get_latent()
# Store latent representation in AnnData object
adata.obsm['X_gnodevae'] = latent
# Perform downstream analysis (e.g., clustering)
import scanpy as sc
sc.pp.neighbors(adata, use_rep='X_gnodevae')
sc.tl.leiden(adata)
sc.tl.umap(adata)
from GNODEVAE import agent
# Initialize standard GraphVAE agent
model = agent(
adata=adata,
layer='counts',
n_var=2000,
tech='PCA',
n_neighbors=15,
latent_dim=10,
hidden_dim=128,
encoder_type='GAT',
lr=1e-4
)
# Train and extract embeddings
model.fit(epochs=300)
latent = model.get_latent()
- layer (str): Layer of AnnData to use (default: 'counts')
- n_var (int): Number of highly variable genes to select (default: None, uses all)
- tech (str): Dimensionality reduction method: 'PCA', 'NMF', 'FastICA', 'TruncatedSVD', 'FactorAnalysis', or 'LatentDirichletAllocation' (default: 'PCA')
- n_neighbors (int): Number of neighbors for graph construction (default: 15)
- batch_tech (str): Batch correction method: 'harmony' or 'scvi' (default: None)
- all_feat (bool): Whether to use all features or only highly variable genes (default: True)
- hidden_dim (int): Hidden layer dimension (default: 128)
- latent_dim (int): Latent space dimension for embeddings (default: 10)
- encoder_type (str): Encoder type: 'graph' or 'linear' (default: 'graph')
- graph_type (str): Graph convolution type: 'GAT', 'GCN', 'SAGE', 'Transformer', etc. (default: 'GAT')
- structure_decoder_type (str): Structure decoder type: 'mlp', 'bilinear', or 'inner_product' (default: 'mlp')
- feature_decoder_type (str): Feature decoder type: 'linear' or 'graph' (default: 'linear')
- hidden_layers (int): Number of hidden layers (default: 2)
- dropout (float): Dropout rate (default: 0.05)
- use_residual (bool): Whether to use residual connections (default: True)
- lr (float): Learning rate for optimizer (default: 1e-4)
- beta (float): Weight for KL divergence loss term (default: 1.0)
- graph (float): Weight for graph reconstruction loss (default: 1.0)
- epochs (int): Number of training epochs (default: 300)
- device (str or torch.device): Computing device: 'cuda' or 'cpu' (default: auto-detect)
- num_parts (int): Number of graph partitions for mini-batch training (default: 10)
- n_ode_hidden (int): Number of hidden units in ODE function (default: varies)
- w_recon (float): Weight for reconstruction loss (default: 1.0)
- w_kl (float): Weight for KL divergence loss (default: 1.0)
- w_adj (float): Weight for adjacency matrix loss (default: 1.0)
- w_recon_ode (float): Weight for ODE reconstruction loss (default: 1.0)
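The loss weights above suggest a weighted sum of individual terms. The sketch below shows one plausible way such weights could combine, together with the standard closed-form KL divergence between a diagonal Gaussian and a unit Gaussian. How GNODEVAE actually assembles its objective is an assumption here; the function names `gaussian_kl` and `total_loss` are hypothetical, and the individual loss values are made-up numbers.

```python
import math

def gaussian_kl(mu, logvar):
    """KL( N(mu, exp(logvar)) || N(0, 1) ), summed over latent dimensions."""
    return -0.5 * sum(
        1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar)
    )

def total_loss(recon, kl, adj, recon_ode,
               w_recon=1.0, w_kl=1.0, w_adj=1.0, w_recon_ode=1.0):
    # Assumed form: a plain weighted sum of the four loss terms
    return w_recon * recon + w_kl * kl + w_adj * adj + w_recon_ode * recon_ode

kl = gaussian_kl(mu=[1.0, 0.0], logvar=[0.0, 0.0])
loss = total_loss(recon=2.0, kl=kl, adj=0.7, recon_ode=1.5, w_kl=0.5)
print(kl, loss)
```

Under this assumed form, lowering `w_kl` relaxes the pull toward the prior, while raising `w_adj` emphasizes graph reconstruction.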
GNODEVAE consists of three main components:
- Graph Encoder: Encodes cell-cell relationships and gene expression using Graph Attention Networks (GAT) or other graph convolution layers
- Neural ODE: Models continuous developmental trajectories in the latent space
- Decoder: Reconstructs both graph structure and gene expression from latent representations
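As an illustration of the structure-decoding step, the sketch below implements the textbook inner-product decoder from graph autoencoders: the probability of an edge between two cells is the sigmoid of the dot product of their latent vectors. Whether GNODEVAE's 'inner_product' option matches this formulation exactly is an assumption, and the latent vectors are toy values.

```python
import numpy as np

def inner_product_decoder(z):
    """Edge probabilities sigmoid(z_i . z_j) for all pairs of cells."""
    logits = z @ z.T
    return 1.0 / (1.0 + np.exp(-logits))

# Three hypothetical cells in a 2-dimensional latent space:
# cells 0 and 1 are close, cell 2 points the opposite way
z = np.array([[1.0, 0.0],
              [0.9, 0.1],
              [-1.0, 0.0]])
adj_prob = inner_product_decoder(z)
print(np.round(adj_prob, 3))
```

Cells with similar latent vectors receive a high edge probability, while cells pointing in opposite directions receive a low one.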
The model learns a low-dimensional latent representation that captures:
- Cell type identity
- Developmental state
- Cell-cell relationships
- Temporal dynamics (with ODE component)
After training, GNODEVAE produces:
- Latent representations: Low-dimensional embeddings for each cell
- Clustering metrics: ARI, NMI, Silhouette score, Calinski-Harabasz index, Davies-Bouldin index
- Pseudo-time: Developmental trajectory information (requires the ODE component, i.e. GNODEVAE_agent_r)
- Graph structure: Learned cell-cell similarity graph
# Use custom graph construction parameters
model = agent_r(
adata=adata,
n_neighbors=30, # Increase neighbors for denser graph
graph_type='Transformer', # Use Transformer convolution
alpha=0.5 # Set alpha for specific layers
)
# Use interpretable GraphVAE
model = agent(
adata=adata,
interpretable=True, # Enable interpretable mode
idim=2 # Interpretable dimension
)
# For GNODEVAE models with ODE component
# Note: Use GNODEVAE_agent_r from GODEVAE_agent module for pseudo-time functionality
from GNODEVAE.GODEVAE_agent import GNODEVAE_agent_r
model = GNODEVAE_agent_r(adata=adata, ...)
model.fit(epochs=300)
# Get pseudo-time for cells
pseudotime_df = model.partition_time()
GNODEVAE automatically computes several clustering quality metrics during training:
- ARI (Adjusted Rand Index): Measures clustering agreement with ground truth
- NMI (Normalized Mutual Information): Information-theoretic clustering metric
- ASW (Average Silhouette Width): Measures cluster separation
- C_H (Calinski-Harabasz Index): Ratio of between-cluster to within-cluster variance
- D_B (Davies-Bouldin Index): Average similarity between each cluster and its most similar cluster (lower is better)
- P_C (Pearson Correlation): Correlation between latent dimensions
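The first five metrics can also be computed outside the training loop with scikit-learn, which is already among the dependencies. The sketch below uses synthetic blobs in place of a real latent embedding and k-means in place of whatever clustering GNODEVAE applies internally; both substitutions are assumptions made so the example is self-contained.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import (
    adjusted_rand_score,
    calinski_harabasz_score,
    davies_bouldin_score,
    normalized_mutual_info_score,
    silhouette_score,
)

# Synthetic stand-in for a (n_cells, latent_dim) embedding with known labels
X, y_true = make_blobs(n_samples=300, centers=3, n_features=10,
                       cluster_std=0.5, random_state=0)
y_pred = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X)

metrics = {
    "ARI": adjusted_rand_score(y_true, y_pred),           # needs ground truth
    "NMI": normalized_mutual_info_score(y_true, y_pred),  # needs ground truth
    "ASW": silhouette_score(X, y_pred),                   # higher is better
    "C_H": calinski_harabasz_score(X, y_pred),            # higher is better
    "D_B": davies_bouldin_score(X, y_pred),               # lower is better
}
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")
```

With a real dataset, `X` would be the array stored in `adata.obsm['X_gnodevae']` and `y_true` a cell type annotation, if one is available.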
If you use GNODEVAE in your research, please cite:
@article{fu2025gnodevae,
title={GNODEVAE: a graph-based ODE-VAE enhances clustering for single-cell data},
author={Fu, Z. and Chen, C. and Wang, S. and others},
journal={BMC Genomics},
volume={26},
pages={767},
year={2025},
doi={10.1186/s12864-025-11946-7}
}
Full Citation: Fu, Z., Chen, C., Wang, S. et al. GNODEVAE: a graph-based ODE-VAE enhances clustering for single-cell data. BMC Genomics 26, 767 (2025). https://doi.org/10.1186/s12864-025-11946-7
See LICENSE file for details.
For questions and feedback, please open an issue on the GitHub repository.
