XAIR is a comprehensive system for generating, analyzing, and validating reasoning paths of large language models (LLMs). It provides powerful tools for understanding and evaluating how LLMs reason, with a focus on counterfactual analysis and knowledge graph validation.
XAIR consists of three main components:
- CGRT (Counterfactual Graph Reasoning Tree): Generates multiple reasoning paths and builds a tree structure to visualize the model's reasoning process.
- Counterfactual Analysis: Identifies critical tokens in the reasoning process and explores what happens when they're modified.
- Knowledge Graph Validation: Maps reasoning statements to external knowledge in Wikidata and validates factual accuracy.
- Generate multiple reasoning paths with different temperature settings
- Identify divergence points where reasoning paths differ
- Construct a graph representation of the reasoning process
- Analyze attention patterns to identify important tokens
- Generate counterfactual alternatives to explore causal relationships
- Calculate Counterfactual Flip Rate (CFR) to quantify decision sensitivity
- Validate reasoning against Wikidata knowledge graph
- Calculate trustworthiness scores for reasoning paths
- Export visualizations for understanding model reasoning
- Optimized for CPUs, GPUs, and Apple Silicon (MPS)
- Performance presets for balancing speed vs analysis depth
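As a rough sketch of how an "auto" device option typically resolves (the helper below is illustrative, not XAIR's actual code; the boolean arguments stand in for torch.cuda.is_available() and torch.backends.mps.is_available()):

```python
def resolve_device(requested: str = "auto", cuda_ok: bool = False, mps_ok: bool = False) -> str:
    """Resolve a --device setting: honor an explicit choice,
    otherwise prefer CUDA, then Apple Silicon (MPS), then CPU."""
    if requested != "auto":
        return requested  # explicit "cpu", "cuda", or "mps"
    if cuda_ok:
        return "cuda"
    if mps_ok:
        return "mps"
    return "cpu"

# With no accelerator available, "auto" falls back to the CPU:
print(resolve_device("auto"))  # -> cpu
```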
- Python 3.8+
- PyTorch 2.0+
- Hugging Face Transformers 4.35+
- NetworkX 3.1+
- Other dependencies listed in requirements.txt
- Clone the repository:
git clone https://github.com/veerdosi/xai.git
cd xai
- Install the required packages:
pip install -r requirements.txt
- For knowledge graph validation (optional):
python -m spacy download en_core_web_sm
Run the main script with default settings:
python main.py
You'll be prompted to enter a query. The system will generate multiple reasoning paths, analyze them, and show you the results.
You can save and load configurations to easily reuse settings:
- Create a configuration file:
python main.py --save-config my_config.json
- Load a configuration file:
python main.py --config my_config.json
- --model: Model name or path (default: "meta-llama/Llama-3.2-1B")
- --device: Device to use - "cpu", "cuda", "mps", or "auto" (default: "auto")
- --max-tokens: Maximum tokens to generate (default: 256)
- --verbose: Enable verbose logging (flag)
- --output-dir: Output directory (default: "output")
- --performance: Performance preset to use - "max_speed", "balanced", or "max_quality" (default: "balanced")
- --fast-mode: Skip hidden states and attention collection for faster generation (flag)
- --fast-init: Skip non-essential initialization steps for faster startup (flag)
- --temperatures: Comma-separated temperatures for generation (default: "0.2,0.7,1.0")
- --paths-per-temp: Paths to generate per temperature (default: 1)
- --counterfactual-tokens: Top-k tokens for counterfactual generation (default: 5)
- --attention-threshold: Minimum attention threshold for counterfactuals (default: 0.3)
- --max-counterfactuals: Maximum counterfactuals to generate (default: 20)
- --kg-use-local-model: Use local sentence transformer model (flag)
- --kg-similarity-threshold: Minimum similarity threshold for KG entity mapping (default: 0.6)
- --kg-skip: Skip Knowledge Graph processing (useful for slower machines) (flag)
- --generate-visualizations: Generate visualizations for the results (flag)
- --config: Path to configuration file
- --save-config: Save configuration to the specified file path
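A saved configuration file mirrors these flags. A hypothetical my_config.json is shown below; the field names model_name_or_path, fast_mode, and fast_init appear in XAIRConfig, while the others are assumed to map one-to-one from the CLI flags (the exact schema is defined in backend/utils/config.py):

```json
{
  "model_name_or_path": "meta-llama/Llama-3.2-1B",
  "device": "auto",
  "max_tokens": 256,
  "output_dir": "output",
  "fast_mode": false,
  "fast_init": false
}
```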
XAIR automatically detects and optimizes for your available hardware:
python main.py --device cuda
python main.py --device mps --max-tokens 256
python main.py --device cpu --max-tokens 256
The Counterfactual Graph Reasoning Tree component:
- Generates multiple reasoning paths with different temperature settings
- Identifies points where reasoning paths diverge
- Analyzes token-level probabilities and attention patterns
- Constructs a directed graph representing all reasoning paths
- Calculates importance scores for nodes based on multiple factors
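Divergence-point detection can be sketched as a token-by-token comparison of two paths (a hypothetical helper; the real implementation in backend/cgrt also weighs token probabilities and attention patterns):

```python
def find_divergence_points(path_a, path_b):
    """Return indices where two tokenized reasoning paths disagree
    over their shared length. Illustrative only."""
    points = []
    for i, (a, b) in enumerate(zip(path_a, path_b)):
        if a != b:
            points.append(i)
    return points

# Two paths that share a prefix, then diverge at the answer token:
low_temp = ["The", "capital", "of", "France", "is", "Paris"]
high_temp = ["The", "capital", "of", "France", "is", "the", "city"]
print(find_divergence_points(low_temp, high_temp))  # -> [5]
```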
The Counterfactual component:
- Identifies tokens with high importance/attention scores
- Generates alternative versions by substituting these tokens
- Evaluates the impact of substitutions on the output
- Identifies "flip points" where small changes cause different conclusions
- Calculates Counterfactual Flip Rate (CFR) to quantify reasoning stability
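The flip-rate idea can be sketched as the fraction of counterfactual substitutions that change the model's conclusion (a simplified illustration; the 'flipped' field and the unweighted average are assumptions, not XAIR's exact formula):

```python
def counterfactual_flip_rate(evaluations):
    """CFR = flips / total counterfactuals evaluated.
    `evaluations` is a hypothetical list of dicts with a boolean
    'flipped' field marking whether the conclusion changed."""
    if not evaluations:
        return 0.0
    flips = sum(1 for e in evaluations if e["flipped"])
    return flips / len(evaluations)

evals = [{"flipped": True}, {"flipped": False},
         {"flipped": False}, {"flipped": True}]
print(counterfactual_flip_rate(evals))  # -> 0.5
```

A higher CFR means the reasoning is more sensitive to small token changes, i.e. less stable.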
The Knowledge Graph component:
- Maps tokens and statements to Wikidata entities
- Validates factual statements against external knowledge
- Identifies supported statements, contradicted statements, and unverified claims
- Calculates trustworthiness scores for reasoning paths
- Provides detailed validation reports
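One plausible way to turn validation counts into a trustworthiness score is shown below (illustrative only; XAIR's actual scoring may weight supported, contradicted, and unverified statements differently):

```python
def trustworthiness_score(supported, contradicted, unverified):
    """Hypothetical scoring: supported statements add credit,
    contradicted statements subtract it, and unverified claims
    count toward the total without adding credit."""
    total = supported + contradicted + unverified
    if total == 0:
        return 0.0
    return max(0.0, (supported - contradicted) / total)

# 6 supported, 1 contradicted, 3 unverified -> (6 - 1) / 10
print(trustworthiness_score(supported=6, contradicted=1, unverified=3))  # -> 0.5
```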
When you run XAIR with the --generate-visualizations flag, it creates several visualizations:
- Reasoning Tree: Shows the structural relationships between reasoning steps
- Token Importance: Highlights tokens with high importance and attention scores
- Counterfactual Impact: Visualizes the impact of different token substitutions
- Knowledge Graph Validation: Shows trustworthiness scores across reasoning paths
- Divergence Points: Highlights where reasoning paths diverge
Visualizations are saved in the output/visualizations directory and can be viewed through the generated index.html file.
The system generates several outputs in the specified output directory:
- generation_results.json: Raw generation results from the model
- divergence_points.json: Detected divergence points between paths
- reasoning_tree.json: The constructed reasoning tree in JSON format
- path_comparison.txt: Detailed comparison of different reasoning paths
- counterfactuals.json: Generated counterfactual candidates
- counterfactual_evaluation.json: Evaluation metrics for counterfactuals
- counterfactual_comparison.txt: Detailed comparison of counterfactuals
- counterfactual_state.json: Complete state of counterfactual analysis
- entity_mapping.json: Mapping of tokens to knowledge graph entities
- validation_results.json: Results of knowledge graph validation
- validation_report.txt: Detailed report of validation findings
- kg_cache/: Cache directory for knowledge graph requests
- index.html: Entry point for viewing all visualizations
- reasoning_tree.png: Visualization of the reasoning tree
- token_importance.png: Chart of token importance scores
- counterfactual_impact.png: Visualization of counterfactual impact
- kg_validation.png: Knowledge graph validation results
- divergence_points.png: Visualization of divergence points
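The JSON artifacts can be read back for post-hoc analysis; load_run_outputs below is a hypothetical convenience helper, assuming the default output directory layout described above:

```python
import json
from pathlib import Path

def load_run_outputs(output_dir="output"):
    """Collect whichever JSON result files a run produced.
    Hypothetical helper, not part of the XAIR API."""
    results = {}
    for name in ("generation_results.json", "divergence_points.json",
                 "reasoning_tree.json", "validation_results.json"):
        path = Path(output_dir) / name
        if path.exists():
            # Key by the file stem, e.g. "reasoning_tree"
            results[name.rsplit(".", 1)[0]] = json.loads(path.read_text())
    return results

outputs = load_run_outputs()
print(sorted(outputs))  # empty list if no run has been executed yet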
You can import XAIR components into your own Python scripts:
from backend.models.llm_interface import LlamaInterface
from backend.cgrt.cgrt_main import CGRT, get_performance_preset
from backend.counterfactual.counterfactual_main import Counterfactual
from backend.knowledge_graph.kg_main import KnowledgeGraph
from backend.utils.config import XAIRConfig
# Load configuration
config = XAIRConfig()
config.model_name_or_path = "meta-llama/Llama-3.2-1B"
# Apply performance optimizations (optional)
preset = get_performance_preset("max_speed")
config.fast_mode = preset["fast_mode"]
config.fast_init = preset["fast_init"]
config.cgrt.temperatures = preset["temperatures"]
# Initialize components
llm = LlamaInterface(model_name_or_path=config.model_name_or_path, fast_init=config.fast_init)
cgrt = CGRT(
    model_name_or_path=config.model_name_or_path,
    temperatures=config.cgrt.temperatures,
    fast_mode=config.fast_mode,
    fast_init=config.fast_init,
)
counterfactual = Counterfactual()
# Process input
tree = cgrt.process_input("What is the capital of France?")
paths = cgrt.get_paths_text()
counterfactuals = counterfactual.generate_counterfactuals(cgrt.tree_builder, llm, "What is the capital of France?", cgrt.paths)
# Print results
print(f"Generated {len(paths)} reasoning paths")
for i, path_text in enumerate(paths):
    print(f"Path {i+1}: {path_text[:100]}...")
You can customize visualizations using the functions in backend/utils/viz_utils.py:
from backend.utils.viz_utils import plot_reasoning_tree, setup_visualization_style
# Set visualization style
setup_visualization_style(style="whitegrid", context="paper", font_scale=1.5, palette="viridis")
# Create custom tree visualization
plot_reasoning_tree(
    cgrt.tree_builder.graph,
    output_path="custom_tree.png",
    title="My Custom Reasoning Tree",
    highlight_nodes=["node_1", "node_5"],
    show_edge_labels=True,
)
For systems with limited memory:
python main.py --model meta-llama/Llama-3.2-1B --max-tokens 128 --paths-per-temp 1 --kg-skip
For detailed analysis on powerful hardware:
python main.py --model meta-llama/Llama-3.2-70B-Instruct --max-tokens 512 --paths-per-temp 3 --temperatures 0.1,0.5,0.9,1.3 --generate-visualizations
XAIR includes several performance optimization features to make generation and analysis faster.
Use performance presets to easily balance speed vs analysis depth:
# Maximum speed (fastest, simplified analysis)
python main.py --performance max_speed
# Balanced performance (good balance of speed and analysis)
python main.py --performance balanced
# Maximum quality (most thorough analysis)
python main.py --performance max_quality
For more fine-grained control, you can use individual fast mode options:
# Skip collecting hidden states and attention (faster generation)
python main.py --fast-mode
# Skip non-essential initialization steps (faster startup)
python main.py --fast-init
# Combine options for maximum speed
python main.py --fast-mode --fast-init --temperatures 0.7
The performance optimizations can significantly improve response times:
| Configuration | Response Time | Startup Time | Analysis Depth |
|---|---|---|---|
| Default | Baseline | Baseline | Full |
| --fast-mode | 2-4x faster | Baseline | Basic |
| --fast-init | Baseline | 2x faster | Full |
| max_speed | 3-5x faster | 2x faster | Basic |
| max_quality | Baseline | Baseline | Enhanced |
- max_speed: Use when you need quick responses and don't need detailed analysis
- balanced: Good for most use cases with reasonable performance
- max_quality: Use when you need the most thorough analysis and have time
- --fast-mode: Good when you need to process many prompts quickly
- --fast-init: Useful when starting up the system frequently
To contribute to XAIR:
- Fork the repository
- Create a feature branch:
git checkout -b feature-name
- Commit your changes:
git commit -am 'Add feature'
- Push to the branch:
git push origin feature-name
- Submit a pull request
We use pytest for testing:
pytest tests/
For specific test files:
pytest tests/test_config.py
This system builds on research in explainable AI, counterfactual analysis, and knowledge graph integration for language models. It incorporates techniques from:
- Counterfactual explanations
- Attention flow analysis
- Knowledge graph entity linking
- Semantic similarity measurement
- Token-level probability analysis
If you use XAIR in your research, please cite:
@software{xair2023,
  author = {Veerdosi},
  title = {XAIR: Explainable AI Reasoning System},
  year = {2023},
  url = {https://github.com/veerdosi/xair}
}