CCVGAE is a deep learning framework for stable and interpretable analysis of single-cell multi-omics data. It combines centroid-based latent summaries, coupling losses, and graph neural network components in a variational autoencoder workflow.
- Centroid Inference: Uses deterministic posterior means as stable cell embeddings.
- Coupling Mechanism: Regularizes latent space through intermediate representations.
- Graph Neural Networks: Leverages cell-cell similarity relationships with attention mechanisms.
- Scalable Architecture: Supports subgraph sampling for large-scale datasets.
- Flexible Design: Compatible with multiple encoder/decoder architectures and graph convolution types.
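To make the centroid idea concrete, here is a minimal NumPy sketch of the difference between a stochastic latent sample (`q_z`) and the deterministic posterior mean (`q_m`). The encoder outputs are made-up toy arrays and `sample_posterior` is a hypothetical helper, not part of the CCVGAE API. It only illustrates why the mean gives a stable embedding.

```python
import numpy as np

def sample_posterior(q_m, log_var, rng):
    """Draw one reparameterized sample z = mean + std * eps from a diagonal Gaussian."""
    eps = rng.standard_normal(q_m.shape)
    return q_m + np.exp(0.5 * log_var) * eps

rng = np.random.default_rng(0)

# Hypothetical encoder outputs for 5 cells and a 3-dimensional latent space.
q_m = rng.standard_normal((5, 3))
log_var = np.full((5, 3), -1.0)

# Stochastic embeddings (q_z) change every time they are drawn.
z1 = sample_posterior(q_m, log_var, rng)
z2 = sample_posterior(q_m, log_var, rng)

# The centroid embedding is simply the posterior mean, so it is identical on every read.
centroid_embedding = q_m
```

Downstream clustering or visualization on `centroid_embedding` therefore does not depend on sampling noise, whereas results computed on `z1` and `z2` can differ.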
- Python >= 3.7
- PyTorch >= 1.8.0
- PyTorch Geometric
- scanpy
- anndata
- scikit-learn
- numpy
- psutil
- tqdm
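Before installing, it can help to check which of these packages are already available. The snippet below is a small convenience sketch using only the standard library. `missing_modules` is a hypothetical helper, and the import names (`torch_geometric` for PyTorch Geometric, `sklearn` for scikit-learn) differ from the package names listed above.

```python
import importlib.util
import sys

# Import names for the packages listed above.
REQUIRED_MODULES = [
    "torch", "torch_geometric", "scanpy", "anndata",
    "sklearn", "numpy", "psutil", "tqdm",
]

def missing_modules(modules=REQUIRED_MODULES):
    """Return the names in `modules` that the import system cannot locate."""
    return [name for name in modules if importlib.util.find_spec(name) is None]

if __name__ == "__main__":
    if sys.version_info < (3, 7):
        print("Python >= 3.7 is required")
    missing = missing_modules()
    print("Missing packages:", ", ".join(missing) if missing else "none")
```

This only checks that the modules can be found. It does not verify version constraints such as PyTorch >= 1.8.0.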
git clone https://github.com/PeterPonyu/CCVGAE
cd CCVGAE
pip install -r requirements.txt

import scanpy as sc
from ccvgae import CCVGAE_agent
adata = sc.read_h5ad("your_data.h5ad")
agent = CCVGAE_agent(
    adata=adata,
    layer='counts',  # Use the 'counts' layer for training
    n_var=2000,  # Number of highly variable genes to use
    n_neighbors=15,  # Number of neighbors for graph construction
    latent_dim=10,  # Dimension of the latent space
    subgraph_size=512,  # Size of subgraphs for training
    num_subgraphs_per_epoch=10,  # Number of subgraphs per epoch
    lr=1e-4,
    w_recon=1.0,  # Weight for reconstruction loss
    w_irecon=1.0,  # Weight for coupling reconstruction loss
    w_kl=1.0,  # Weight for KL divergence loss
)
# Train the model
agent.fit(epochs=300, update_steps=10)
latent_representation = agent.get_latent()
from sklearn.cluster import KMeans
clusters = KMeans(n_clusters=10).fit_predict(latent_representation)
adata.obs['ccvgae_clusters'] = clusters

CCVGAE automatically handles preprocessing steps, including:
- Normalization: Total count normalization and log1p transformation
- Feature Selection: Highly variable genes selection
- Dimensionality Reduction: PCA or other methods (NMF, FastICA, TruncatedSVD, FactorAnalysis, LatentDirichletAllocation)
- Graph Construction: k-nearest neighbor graph based on reduced dimensions
- Batch Correction (optional): Harmony or scVI integration
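The steps above can be sketched with plain NumPy on a toy count matrix. This is an illustrative approximation, not CCVGAE's internal code: `preprocess` is a hypothetical function, variance ranking stands in for proper highly variable gene selection, PCA is done via SVD, the k-nearest-neighbor search is brute force, and batch correction is omitted.

```python
import numpy as np

def preprocess(counts, n_var, n_pcs, n_neighbors, target_sum=1e4):
    """Normalize, select variable genes, reduce with PCA, and build a kNN index."""
    # 1. Total-count normalization followed by log1p (guard against empty cells).
    totals = np.maximum(counts.sum(axis=1, keepdims=True), 1.0)
    x = np.log1p(counts / totals * target_sum)

    # 2. Keep the n_var genes with the highest variance (simplified HVG selection).
    top = np.sort(np.argsort(x.var(axis=0))[-n_var:])
    x = x[:, top]

    # 3. PCA via SVD of the centered matrix.
    xc = x - x.mean(axis=0)
    u, s, _ = np.linalg.svd(xc, full_matrices=False)
    pcs = u[:, :n_pcs] * s[:n_pcs]

    # 4. Brute-force k-nearest neighbors in PCA space, excluding each cell itself.
    dist = np.linalg.norm(pcs[:, None, :] - pcs[None, :, :], axis=-1)
    np.fill_diagonal(dist, np.inf)
    knn = np.argsort(dist, axis=1)[:, :n_neighbors]
    return pcs, knn

rng = np.random.default_rng(0)
counts = rng.poisson(5.0, size=(6, 5)).astype(float)  # 6 cells x 5 genes of toy counts
pcs, knn = preprocess(counts, n_var=3, n_pcs=2, n_neighbors=2)
```

Here `knn[i]` holds the indices of the two cells closest to cell `i`, which is the information a cell-cell graph is built from.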
# Example with custom preprocessing options
agent = CCVGAE_agent(
    adata=adata,
    layer='counts',  # Layer containing raw counts
    n_var=3000,  # Select top 3000 highly variable genes
    tech='PCA',  # Dimensionality reduction method
    n_neighbors=20,  # Number of neighbors for KNN graph
    batch_tech='harmony',  # Optional: use Harmony for batch correction
    all_feat=False,  # Use only highly variable genes
)

CCVGAE supports flexible architecture configuration:
agent = CCVGAE_agent(
    adata=adata,
    # Architecture parameters
    encoder_type='graph',  # 'graph' or 'linear'
    graph_type='GAT',  # 'GAT', 'GCN', 'SAGE', 'Cheb', etc.
    hidden_dim=128,  # Hidden layer dimension
    latent_dim=10,  # Latent space dimension
    i_dim=10,  # Intermediate coupling dimension
    hidden_layers=2,  # Number of hidden layers
    decoder_hidden_dim=128,  # Decoder hidden dimension
    dropout=0.05,  # Dropout rate
    use_residual=True,  # Use residual connections
    structure_decoder_type='mlp',  # 'mlp', 'bilinear', 'inner_product'
    feature_decoder_type='linear',  # Feature decoder type
)

Control training dynamics and loss weighting:
import torch

agent = CCVGAE_agent(
    adata=adata,
    # Training parameters
    lr=1e-4,  # Learning rate
    beta=1.0,  # KL divergence regularization strength (β-VAE)
    graph=1.0,  # Graph reconstruction loss weight
    w_recon=1.0,  # Feature reconstruction loss weight
    w_irecon=1.0,  # Coupling reconstruction loss weight
    w_kl=1.0,  # KL divergence loss weight
    w_adj=1.0,  # Adjacency reconstruction loss weight
    # Subgraph sampling for large datasets
    subgraph_size=512,  # Maximum nodes per subgraph
    num_subgraphs_per_epoch=10,  # Number of subgraphs per epoch
    sampling_method='random',  # Sampling strategy
    # Device configuration
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    # Latent output
    latent_type='q_m',  # 'q_m' (deterministic) or 'q_z' (stochastic)
)
# Train the model
agent.fit(epochs=300, update_steps=10, silent=False)

| Parameter | Type | Default | Description |
|---|---|---|---|
| `adata` | AnnData | Required | Annotated data object containing single-cell data |
| `layer` | str | 'counts' | Data layer to use for training |
| `n_var` | int | None | Number of highly variable genes (None for automatic) |
| `latent_dim` | int | 10 | Dimension of the main latent space |
| `hidden_dim` | int | 128 | Hidden layer dimension for networks |

| Parameter | Type | Default | Description |
|---|---|---|---|
| `encoder_type` | str | 'graph' | Encoder type: 'graph' or 'linear' |
| `graph_type` | str | 'GAT' | Graph convolution type: 'GAT', 'GCN', 'SAGE', 'Cheb', 'TAG', 'ARMA', 'Transformer', 'SG', 'SSG' |
| `structure_decoder_type` | str | 'mlp' | Structure decoder: 'mlp', 'bilinear', 'inner_product' |
| `feature_decoder_type` | str | 'linear' | Feature decoder type: 'linear' |
| `hidden_layers` | int | 2 | Number of hidden layers |
| `dropout` | float | 0.05 | Dropout probability for regularization |
| `use_residual` | bool | True | Whether to use residual connections |
| `i_dim` | int | 10 | Intermediate coupling space dimension |

| Parameter | Type | Default | Description |
|---|---|---|---|
| `lr` | float | 1e-4 | Learning rate for Adam optimizer |
| `w_recon` | float | 1.0 | Weight for feature reconstruction loss |
| `w_irecon` | float | 1.0 | Weight for coupling reconstruction loss |
| `w_kl` | float | 1.0 | Weight for KL divergence loss |
| `w_adj` | float | 1.0 | Weight for adjacency reconstruction loss |
| `beta` | float | 1.0 | KL divergence regularization strength (β-VAE) |
| `graph` | float | 1.0 | Graph reconstruction loss weight |

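As a rough mental model, the `w_*` weights can be read as coefficients of a weighted sum of loss terms. The sketch below assumes exactly that form. `total_loss` and the individual loss values are hypothetical, and how `beta` and `graph` enter the actual objective is not specified here, so they are left out.

```python
def total_loss(recon, irecon, kl, adj,
               w_recon=1.0, w_irecon=1.0, w_kl=1.0, w_adj=1.0):
    """Weighted sum of loss terms (assumed form, for illustration only)."""
    return w_recon * recon + w_irecon * irecon + w_kl * kl + w_adj * adj

# Made-up per-batch loss values.
losses = dict(recon=2.0, irecon=1.0, kl=0.5, adj=0.25)

default_total = total_loss(**losses)
# Down-weight coupling and KL, emphasize adjacency reconstruction.
tuned_total = total_loss(**losses, w_irecon=0.5, w_kl=0.8, w_adj=1.2)
```

Under this reading, lowering a weight reduces how strongly the optimizer trades that term off against the others. It does not change the value of the term itself.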
| Parameter | Type | Default | Description |
|---|---|---|---|
| `tech` | str | 'PCA' | Dimensionality reduction: 'PCA', 'NMF', 'FastICA', 'TruncatedSVD', 'FactorAnalysis', 'LatentDirichletAllocation' |
| `n_neighbors` | int | 15 | Number of neighbors for graph construction |
| `batch_tech` | str | None | Batch correction method: None, 'harmony', or 'scvi' |
| `all_feat` | bool | False | Use all features (True) or just highly variable genes (False) |

| Parameter | Type | Default | Description |
|---|---|---|---|
| `subgraph_size` | int | 512 | Maximum number of nodes per subgraph |
| `num_subgraphs_per_epoch` | int | 10 | Number of subgraphs sampled per training epoch |
| `sampling_method` | str | 'random' | Subgraph sampling strategy |

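To illustrate what random subgraph sampling does, the following NumPy sketch draws up to `subgraph_size` nodes uniformly at random and keeps only the edges between them. `sample_subgraph` is a hypothetical helper and the 2 x E `edge_index` layout is an assumption borrowed from PyTorch Geometric's convention. CCVGAE's own sampler may differ.

```python
import numpy as np

def sample_subgraph(edge_index, num_nodes, subgraph_size, rng):
    """Pick up to `subgraph_size` nodes at random and keep the edges among them."""
    size = min(subgraph_size, num_nodes)
    nodes = np.sort(rng.choice(num_nodes, size=size, replace=False))

    # Map original node ids to positions in the subgraph (-1 means not selected).
    new_id = np.full(num_nodes, -1)
    new_id[nodes] = np.arange(size)

    src, dst = edge_index
    keep = (new_id[src] >= 0) & (new_id[dst] >= 0)
    sub_edges = np.stack([new_id[src[keep]], new_id[dst[keep]]])
    return nodes, sub_edges

rng = np.random.default_rng(0)

# Toy ring graph with 10 nodes: edge i -> (i + 1) % 10.
src = np.arange(10)
edge_index = np.stack([src, (src + 1) % 10])

nodes, sub_edges = sample_subgraph(edge_index, num_nodes=10, subgraph_size=4, rng=rng)
```

Each training step then only needs the features of `nodes` and the relabeled `sub_edges`, which keeps memory use bounded by the subgraph size rather than by the full dataset.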
import scanpy as sc
from ccvgae import CCVGAE_agent
# Load data with multiple batches
adata = sc.read_h5ad("multi_batch_data.h5ad")
# Ensure batch information is in adata.obs
# adata.obs['batch'] should contain batch labels
# Initialize with batch correction
agent = CCVGAE_agent(
    adata=adata,
    layer='counts',
    n_var=2000,
    batch_tech='harmony',  # Use Harmony for batch correction
    n_neighbors=15,
    latent_dim=20,
)
agent.fit(epochs=500, update_steps=20)
latent = agent.get_latent()
# Add latent representation to adata for visualization
adata.obsm['X_ccvgae'] = latent
sc.pp.neighbors(adata, use_rep='X_ccvgae')
sc.tl.umap(adata)
sc.pl.umap(adata, color=['batch', 'cell_type'])

# For datasets with > 10,000 cells, use subgraph sampling
agent = CCVGAE_agent(
    adata=large_adata,
    layer='counts',
    n_var=2000,
    latent_dim=15,
    subgraph_size=1024,  # Larger subgraphs
    num_subgraphs_per_epoch=20,  # More subgraphs per epoch
)
agent.fit(epochs=300, update_steps=10)

# Fine-tune loss weights for specific analysis needs
agent = CCVGAE_agent(
    adata=adata,
    layer='counts',
    n_var=2000,
    latent_dim=10,
    w_recon=1.0,  # Standard reconstruction
    w_irecon=0.5,  # Reduce coupling weight
    w_kl=0.8,  # Reduce KL weight for more flexibility
    w_adj=1.2,  # Emphasize graph structure learning
    beta=0.5,  # Lighter regularization
)
agent.fit(epochs=300, update_steps=10)

# Use linear encoder when graph structure is less important
agent = CCVGAE_agent(
    adata=adata,
    layer='counts',
    n_var=2000,
    encoder_type='linear',  # Faster than graph encoder
    latent_dim=10,
    hidden_dim=256,
    hidden_layers=3,
)
agent.fit(epochs=200, update_steps=10)

# Get latent representations (deterministic posterior means)
latent = agent.get_latent() # Shape: (n_cells, latent_dim)
# Add to AnnData for downstream analysis
adata.obsm['X_ccvgae'] = latent

During training, CCVGAE tracks several metrics:
- Loss: Total loss value
- ARI: Adjusted Rand Index (clustering quality)
- NMI: Normalized Mutual Information
- ASW: Average Silhouette Width
- C_H: Calinski-Harabasz Index
- D_B: Davies-Bouldin Index
- P_C: Average pairwise correlation
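Most of these metrics have direct counterparts in scikit-learn. The sketch below shows one plausible way to compute them from a latent matrix and reference labels. It assumes KMeans is used to obtain predicted clusters, which may not match how CCVGAE computes its scores internally. `clustering_scores` is a hypothetical helper, the data is synthetic, and P_C is omitted because its exact definition is not given here.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import (
    adjusted_rand_score,
    calinski_harabasz_score,
    davies_bouldin_score,
    normalized_mutual_info_score,
    silhouette_score,
)

def clustering_scores(latent, true_labels, n_clusters, seed=0):
    """Cluster the latent space with KMeans and score the result."""
    pred = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed).fit_predict(latent)
    return {
        "ARI": adjusted_rand_score(true_labels, pred),       # needs reference labels
        "NMI": normalized_mutual_info_score(true_labels, pred),
        "ASW": silhouette_score(latent, pred),               # higher is better
        "C_H": calinski_harabasz_score(latent, pred),        # higher is better
        "D_B": davies_bouldin_score(latent, pred),           # lower is better
    }

# Two well-separated synthetic groups of 20 "cells" each in a 2-D latent space.
rng = np.random.default_rng(0)
latent = np.vstack([rng.normal(0.0, 0.1, (20, 2)), rng.normal(5.0, 0.1, (20, 2))])
labels = np.array([0] * 20 + [1] * 20)

scores = clustering_scores(latent, labels, n_clusters=2)
```

ARI and NMI compare predicted clusters against known labels, while ASW, C_H, and D_B depend only on the geometry of the latent space, so they can be used when no annotation is available.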
# Access final scores after training
agent.score_final()
print(f"Final scores: {agent.final_score}")

# Clustering
from sklearn.cluster import KMeans
clusters = KMeans(n_clusters=8).fit_predict(latent)
adata.obs['ccvgae_clusters'] = clusters
# Visualization with UMAP
adata.obsm['X_ccvgae'] = latent
sc.pp.neighbors(adata, use_rep='X_ccvgae')
sc.tl.umap(adata)
sc.pl.umap(adata, color=['ccvgae_clusters', 'cell_type'])
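# Trajectory inference (reuses the neighbor graph built on 'X_ccvgae' above)
sc.tl.diffmap(adata)
adata.uns['iroot'] = 0  # Index of the root cell; replace with a biologically meaningful choice
sc.tl.dpt(adata)

If you use CCVGAE in your research, please cite: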
@software{ccvgae,
  title = {CCVGAE: Centroid-based Coupled Variational Graph Autoencoder for Single-Cell Multi-omics Data Integration},
  author = {Fu, Zeyu},
  year = {2024},
  url = {https://github.com/PeterPonyu/CCVGAE}
}

This project is licensed under the terms specified in the LICENSE file.
For questions, issues, or feature requests, please open an issue on GitHub.