Advanced Marketing Mix Modeling with Causal Inference and Deep Learning
- Config-Driven: Every setting configurable via config.py
- GRU-Based Temporal Modeling: Captures complex time-varying effects
- DAG Learning: Discovers causal relationships between channels
- Learnable Coefficient Bounds: Channel-specific, data-driven constraints
- Data-Driven Seasonality: Automatic seasonal decomposition per region
- Huber Loss: Robust to outliers and extreme values
- Multiple Metrics: RMSE, R², MAE, Trimmed RMSE, Log-space metrics
- Advanced Regularization: L1/L2, sparsity, coefficient-specific penalties
- Gradient Clipping: Parameter-specific clipping for stability
- 14+ Interactive Visualizations: Complete dashboard with insights
- Response Curves: Non-linear saturation analysis with Hill equations
- Budget Optimization: Constrained optimization for optimal channel allocation
- DMA-Level Contributions: True economic impact calculation
- Channel Effectiveness: Detailed performance analysis
- DAG Visualization: Interactive causal network graphs
# Install from PyPI
pip install deepcausalmmm

# Or install the latest version from GitHub
pip install git+https://github.com/adityapt/deepcausalmmm.git

# Or clone the repository and install in editable mode
git clone https://github.com/adityapt/deepcausalmmm.git
cd deepcausalmmm
pip install -e .

# Core dependencies
pip install torch pandas numpy plotly networkx statsmodels scikit-learn tqdm

import pandas as pd
from deepcausalmmm import DeepCausalMMM, get_device
from deepcausalmmm.core import get_default_config
from deepcausalmmm.core.trainer import ModelTrainer
from deepcausalmmm.core.data import UnifiedDataPipeline
# Load your data
data = pd.read_csv('your_mmm_data.csv')
# Get optimized configuration
config = get_default_config()
# Check device availability
device = get_device()
print(f"Using device: {device}")
# Process data with unified pipeline
pipeline = UnifiedDataPipeline(config)
processed_data = pipeline.fit_transform(data)
# Train with ModelTrainer (recommended approach)
trainer = ModelTrainer(config)
model, results = trainer.train(processed_data)
# Generate the comprehensive dashboard: run the main dashboard script
# from the project root directory
python dashboard_rmse_optimized.py

# Verify the installation works
from deepcausalmmm import DeepCausalMMM, get_device
from deepcausalmmm.core import get_default_config
print("DeepCausalMMM package imported successfully!")
print(f"Device: {get_device()}")deepcausalmmm/ # Project root
├── pyproject.toml # Package configuration and dependencies
├── README.md # This documentation
├── LICENSE # MIT License
├── CHANGELOG.md # Version history and changes
├── CONTRIBUTING.md # Development guidelines
├── CODE_OF_CONDUCT.md # Code of conduct
├── CITATION.cff # Citation metadata for Zenodo/GitHub
├── Makefile # Build and development tasks
├── MANIFEST.in # Package manifest for distribution
│
├── deepcausalmmm/ # Main package directory
│ ├── __init__.py # Package initialization and exports
│ ├── cli.py # Command-line interface
│ ├── exceptions.py # Custom exception classes
│ │
│ ├── core/ # Core model components
│ │ ├── __init__.py # Core module initialization
│ │ ├── config.py # Optimized configuration parameters
│ │ ├── unified_model.py # Main DeepCausalMMM model architecture
│ │ ├── trainer.py # ModelTrainer class for training
│ │ ├── data.py # UnifiedDataPipeline for data processing
│ │ ├── scaling.py # SimpleGlobalScaler for data normalization
│ │ ├── seasonality.py # Seasonal decomposition utilities
│ │ ├── dag_model.py # DAG learning and causal inference
│ │ ├── inference.py # Model inference and prediction
│ │ ├── train_model.py # Training functions and utilities
│ │ └── visualization.py # Core visualization components
│ │
│ ├── postprocess/ # Analysis and post-processing
│ │ ├── __init__.py # Postprocess module initialization
│ │ ├── analysis.py # Statistical analysis utilities
│ │ ├── comprehensive_analysis.py # Comprehensive analyzer
│ │ ├── response_curves.py # Non-linear response curve fitting (Hill equations)
│ │ ├── optimization.py # Budget optimization with response curves
│ │ ├── optimization_utils.py # Optimization utility functions
│ │ └── dag_postprocess.py # DAG post-processing and analysis
│ │
│ └── utils/ # Utility functions
│ ├── __init__.py # Utils module initialization
│ ├── device.py # GPU/CPU device detection
│ └── data_generator.py # Synthetic data generation (ConfigurableDataGenerator)
│
├── examples/ # Example scripts and notebooks
│ ├── quickstart.ipynb # Interactive Jupyter notebook for Google Colab
│ ├── dashboard_rmse_optimized.py # Comprehensive dashboard with 14+ visualizations
│ ├── example_response_curves.py # Response curve fitting examples
│ └── example_budget_optimization.py # Budget optimization workflow
│
├── tests/ # Test suite
│ ├── __init__.py # Test package initialization
│ ├── unit/ # Unit tests
│ │ ├── __init__.py
│ │ ├── test_config.py # Configuration tests
│ │ ├── test_model.py # Model architecture tests
│ │ ├── test_scaling.py # Data scaling tests
│ │ └── test_response_curves.py # Response curve fitting tests
│ └── integration/ # Integration tests
│ ├── __init__.py
│ └── test_end_to_end.py # End-to-end integration tests
│
├── docs/ # Documentation
│ ├── Makefile # Documentation build tasks
│ ├── make.bat # Windows documentation build
│ ├── requirements.txt # Documentation dependencies
│ └── source/ # Sphinx documentation source
│ ├── conf.py # Sphinx configuration
│ ├── index.rst # Documentation index
│ ├── installation.rst # Installation guide
│ ├── quickstart.rst # Quick start guide
│ ├── contributing.rst # Contributing guide
│ ├── api/ # API documentation
│ │ ├── index.rst
│ │ ├── core.rst
│ │ ├── data.rst
│ │ ├── trainer.rst
│ │ ├── inference.rst
│ │ ├── analysis.rst
│ │ ├── response_curves.rst # Response curves API
│ │ ├── optimization.rst # Budget optimization API
│ │ ├── utils.rst
│ │ └── exceptions.rst
│ ├── examples/ # Example documentation
│ │ └── index.rst
│ └── tutorials/ # Tutorial documentation
│ └── index.rst
│
└── JOSS/ # Journal of Open Source Software submission
├── paper.md # JOSS paper manuscript
├── paper.bib # Bibliography
├── figure_dag_professional.png # DAG visualization figure
└── figure_response_curve_simple.png # Response curve figure
The comprehensive dashboard includes:
- Performance Metrics: Training vs Holdout comparison
- Actual vs Predicted: Time series visualization
- Holdout Scatter: Generalization assessment
- Economic Contributions: Total KPI per channel
- Contribution Breakdown: Donut chart with percentages
- Waterfall Analysis: Decomposed contribution flow
- Channel Effectiveness: Coefficient distributions
- DAG Network: Interactive causal relationships
- DAG Heatmap: Adjacency matrix visualization
- Stacked Contributions: Time-based channel impact
- Individual Channels: Detailed channel analysis
- Scaled Data: Normalized time series
- Control Variables: External factor analysis
- Response Curves: Non-linear saturation curves (diminishing returns analysis) with Hill equations
Key configuration parameters:
{
# Model Architecture
'hidden_dim': 320, # Optimal hidden dimension
'dropout': 0.08, # Proven stable dropout
'gru_layers': 1, # Single layer for stability
# Training Parameters
'n_epochs': 6500, # Optimal convergence epochs
'learning_rate': 0.009, # Fine-tuned learning rate
'temporal_regularization': 0.04, # Proven regularization
# Loss Function
'use_huber_loss': True, # Robust to outliers
'huber_delta': 0.3, # Optimal delta value
# Data Processing
'holdout_ratio': 0.08, # Optimal train/test split
'burn_in_weeks': 6, # Stabilization period
}- Media Coefficient Bounds:
- Media Coefficient Bounds: F.softplus(coeff_max_raw) * torch.sigmoid(media_coeffs_raw) (illustrated in the sketch after this list)
- Control Coefficients: Unbounded with gradient clipping
- Trend Damping: torch.exp(trend_damping_raw)
- Baseline Components: Non-negative via F.softplus
- Seasonal Coefficient: Learnable seasonal contribution
- SOV Scaling: Share-of-voice normalization for media channels
- Z-Score Normalization: For control variables (weather, events, etc.)
- Min-Max Seasonality: Regional seasonal scaling (0-1) using seasonal_decompose
- Consistent Transforms: Same scaling applied to train/holdout splits
- DMA-Level Processing: True economic contributions calculated per region
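A minimal sketch of how these reparameterizations keep coefficients in range during training; the tensor names and shapes are illustrative, not the model's internals:

```python
import torch
import torch.nn.functional as F

# Illustrative only: 5 media channels, raw (unconstrained) parameters.
media_coeffs_raw = torch.randn(5)
coeff_max_raw = torch.randn(5)
trend_damping_raw = torch.randn(1)
baseline_raw = torch.randn(1)

# Media coefficients: sigmoid squashes each raw value into (0, 1), and a
# learnable softplus cap gives each channel its own positive upper bound.
media_coeffs = F.softplus(coeff_max_raw) * torch.sigmoid(media_coeffs_raw)

# Trend damping: exp keeps the damping factor strictly positive.
trend_damping = torch.exp(trend_damping_raw)

# Baseline components: softplus keeps them non-negative.
baseline = F.softplus(baseline_raw)

print(media_coeffs.min() >= 0, trend_damping > 0, baseline >= 0)
```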
- Coefficient L2: Channel-specific regularization
- Sparsity Control: GRU parameter sparsity
- DAG Regularization: Acyclicity constraints
- Gradient Clipping: Parameter-specific clipping (see the sketch after this list)
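A minimal sketch of the last two items. The acyclicity term shown is the common NOTEARS-style penalty h(A) = tr(e^(A∘A)) − d, which may differ from the package's exact formulation, and the adjacency matrix here is made up:

```python
import torch

def acyclicity_penalty(adj: torch.Tensor) -> torch.Tensor:
    # NOTEARS-style constraint: h(A) = tr(exp(A * A)) - d is zero
    # if and only if the weighted graph encoded by A is acyclic.
    d = adj.shape[0]
    return torch.trace(torch.matrix_exp(adj * adj)) - d

adj = torch.rand(4, 4, requires_grad=True)               # illustrative 4-channel adjacency
loss = acyclicity_penalty(adj) + 1e-3 * adj.pow(2).sum() # plus an L2 term
loss.backward()

# Parameter-specific gradient clipping: each parameter group
# can get its own max norm before the optimizer step.
torch.nn.utils.clip_grad_norm_([adj], max_norm=0.5)
```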
- Hill Saturation Modeling: Non-linear response curves with Hill equations (see the sketch after this list)
- Automatic Curve Fitting: Fits S-shaped saturation curves to channel data
- National-Level Aggregation: Aggregates DMA-week data to national weekly level
- Proportional Allocation: Correctly scales log-space contributions to original scale
- Interactive Visualizations: Plotly-based interactive response curve plots
- Performance Metrics: R², slope, and saturation point for each channel
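For reference, a four-parameter Hill curve has the standard form below; the parameter names match the top/bottom/saturation/slope columns used by the budget optimizer later in this README, and in this common parameterization `saturation` is the half-saturation point. A minimal sketch:

```python
import numpy as np

def hill_curve(x, bottom, top, saturation, slope):
    """S-shaped saturation curve with diminishing returns: the response
    rises from `bottom` toward `top`, reaching the halfway point at
    x = saturation; `slope` controls how sharp the S-shape is."""
    return bottom + (top - bottom) * x**slope / (saturation**slope + x**slope)

spend = np.linspace(0, 10, 101)
response = hill_curve(spend, bottom=0.0, top=1.0, saturation=3.0, slope=2.0)
```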
from deepcausalmmm.postprocess import ResponseCurveFit
# Fit response curves to channel data
fitter = ResponseCurveFit(
data=channel_data,
x_col='impressions',
y_col='contributions',
model_level='national',
date_col='week'
)
# Get fitted parameters
slope, saturation = fitter.fit_curve()
r2_score = fitter.calculate_r2_and_plot(save_path='response_curve.html')
print(f"Slope: {slope:.3f}, Saturation: {saturation:.3f}, R²: {r2_score:.3f}")- Constrained Optimization: Find optimal budget allocation across channels
- Multiple Methods: SLSQP (default), trust-constr, differential evolution, hybrid
- Hill Equation Integration: Uses fitted response curves for saturation modeling
- Channel Constraints: Set min/max spend limits based on business requirements
- Scenario Comparison: Compare current vs optimal allocations
- ROI Maximization: Maximize predicted response subject to budget and constraints
from deepcausalmmm import optimize_budget_from_curves
# After training your model and fitting response curves...
# Use optimize_budget_from_curves() with your fitted curve parameters
result = optimize_budget_from_curves(
budget=1_000_000,
curve_params=fitted_curves_df, # DataFrame with: channel, top, bottom, saturation, slope
num_weeks=52,
constraints={
'TV': {'lower': 100000, 'upper': 600000},
'Search': {'lower': 150000, 'upper': 500000},
'Social': {'lower': 50000, 'upper': 300000}
},
method='SLSQP'
)
# View results
if result.success:
print(f"Optimal Allocation: {result.allocation}")
print(f"Predicted Response: {result.predicted_response:,.0f}")
print(result.by_channel)

Example Output:
Optimal Allocation: {'TV': 100000, 'Search': 420000, 'Social': 300000, ...}
Predicted Response: 627,788
Detailed Metrics:
channel total_spend weekly_spend roi spend_pct response_pct saturation_pct
Search 420,000 8,076.92 0.56 42.0% 37.8% 323%
Social 300,000 5,769.23 0.73 30.0% 34.8% 288%
TV 100,000 1,923.08 0.13 10.0% 2.1% 64%
See examples/example_budget_optimization.py for complete workflow and tips.
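Under the hood this is a constrained nonlinear program. The sketch below shows the underlying idea with plain scipy and made-up Hill parameters; it is not the package's implementation, whose entry point is optimize_budget_from_curves as shown above:

```python
import numpy as np
from scipy.optimize import minimize

# Illustrative stand-in for fitted Hill curves (channel -> parameters);
# in practice these come from the response-curve fitting step.
curves = {
    'TV':     dict(bottom=0.0, top=200.0, saturation=150_000.0, slope=1.2),
    'Search': dict(bottom=0.0, top=300.0, saturation=120_000.0, slope=1.5),
}
channels = list(curves)
budget = 500_000.0

def hill(x, bottom, top, saturation, slope):
    return bottom + (top - bottom) * x**slope / (saturation**slope + x**slope)

def neg_total_response(spend):
    # Negate because scipy minimizes; we want to maximize total response.
    return -sum(hill(s, **curves[c]) for s, c in zip(spend, channels))

x0 = np.full(len(channels), budget / len(channels))  # start from an even split
result = minimize(
    neg_total_response,
    x0,
    method='SLSQP',
    bounds=[(0.0, budget)] * len(channels),                      # per-channel limits
    constraints=[{'type': 'eq', 'fun': lambda s: s.sum() - budget}],  # spend it all
)
print(dict(zip(channels, result.x.round())))
```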
Performance benchmarks will be added with masked/anonymized data to demonstrate model capabilities while protecting proprietary information.
- Python 3.8+
- PyTorch 1.13+
- pandas 1.5+
- numpy 1.21+
- plotly 5.11+
- statsmodels 0.13+
- scikit-learn 1.1+
python -m pytest tests/

See CONTRIBUTING.md for development guidelines.
MIT License - see LICENSE file.
"Achieved 93% holdout R² with only 3.6% performance gap - exceptional generalization!"
"Zero hardcoding approach makes it work perfectly on our different datasets without any modifications"
"The comprehensive dashboard with 14+ interactive visualizations including response curves provides insights we never had before"
"DMA-level contributions and DAG learning revealed true causal relationships between our marketing channels"
- Documentation: Comprehensive README with examples
- Issues: Use GitHub issues for bug reports and feature requests
- Performance: All configurations battle-tested and production-ready
- Zero Hardcoding: Fully generalizable across different datasets and industries
- Full Documentation: deepcausalmmm.readthedocs.io
- Quick Start Guide: Installation & Usage
- API Reference: Complete API Documentation
- Tutorials: Step-by-step Guides
- Examples: Practical Use Cases
If you use DeepCausalMMM in your research, please cite:
@article{tirumala2025deepcausalmmm,
title={DeepCausalMMM: A Deep Learning Framework for Marketing Mix Modeling with Causal Inference},
author={Puttaparthi Tirumala, Aditya},
journal={arXiv preprint arXiv:2510.13087},
year={2025}
}

Or click the "Cite this repository" button on GitHub for other citation formats (APA, Chicago, MLA).
- Main Dashboard: dashboard_rmse_optimized.py - Complete analysis pipeline
- Budget Optimization: examples/example_budget_optimization.py - End-to-end optimization workflow
- Core Model: deepcausalmmm/core/unified_model.py - DeepCausalMMM architecture
- Configuration: deepcausalmmm/core/config.py - All tunable parameters
- Data Pipeline: deepcausalmmm/core/data.py - Data processing and scaling
DeepCausalMMM - Where Deep Learning meets Causal Inference for Superior Marketing Mix Modeling
arXiv preprint - https://www.arxiv.org/abs/2510.13087