#!/usr/bin/env python3
"""
Reproducible example demonstrating PostHog MongoDB connector performance issues.

This script replicates the exact logic from:
  posthog/temporal/data_imports/sources/mongodb/mongo.py

Usage:
    export MONGODB_CONNECTION_STRING="mongodb+srv://user:password@cluster.mongodb.net/database?authSource=admin"
    python test.py

Requirements:
    pip install pymongo certifi
"""

import contextlib
import os
import sys
import time
from collections import defaultdict
from collections.abc import Iterator
from typing import Any
from urllib.parse import parse_qs, urlparse

import certifi
from pymongo import MongoClient
from pymongo.collection import Collection


# =============================================================================
# ADAPTED FROM: posthog/temporal/data_imports/sources/mongodb/mongo.py
# (password redacted for safe printing; ISSUE/FIX notes and a fixed variant added)
# =============================================================================


def _parse_connection_string(connection_string: str) -> dict[str, Any]:
    """Parse a MongoDB connection string into display-safe connection parameters.

    Args:
        connection_string: A ``mongodb://`` or ``mongodb+srv://`` URI.

    Returns:
        A dict with the keys ``host``, ``port``, ``database``, ``user``,
        ``password``, ``auth_source``, ``tls`` and ``is_srv``. ``password`` is
        always the literal ``"***REDACTED***"`` so the result can be printed
        or logged without leaking the credential.

    Raises:
        ValueError: If the scheme is neither ``mongodb`` nor ``mongodb+srv``,
            or if ``urlparse`` cannot interpret the port as an integer.
    """
    parsed = urlparse(connection_string)

    if parsed.scheme not in ("mongodb", "mongodb+srv"):
        raise ValueError(
            "Connection string must start with mongodb:// or mongodb+srv://"
        )

    is_srv = parsed.scheme == "mongodb+srv"
    query_params = parse_qs(parsed.query)

    # Collect only the TLS switches the URI states explicitly ("tls" and its
    # legacy alias "ssl"); parse_qs never yields an empty value list.
    explicit_tls_flags = [
        query_params[name][0].lower() in ("true", "1")
        for name in ("tls", "ssl")
        if name in query_params
    ]
    # mongodb+srv:// turns TLS on by default; an explicit tls/ssl option in the
    # query string overrides that default. Plain mongodb:// defaults to off.
    use_tls = any(explicit_tls_flags) if explicit_tls_flags else is_srv

    return {
        "host": parsed.hostname or "localhost",
        # SRV URIs carry no port; plain URIs fall back to MongoDB's 27017.
        "port": parsed.port or (None if is_srv else 27017),
        "database": parsed.path.lstrip("/") if parsed.path else None,
        "user": parsed.username,
        "password": "***REDACTED***",  # Don't log actual password
        "auth_source": query_params.get("authSource", ["admin"])[0],
        "tls": use_tls,
        "is_srv": is_srv,
    }


@contextlib.contextmanager
def mongo_client(connection_string: str) -> Iterator[MongoClient]:
    """Yield a ``MongoClient`` for *connection_string* and close it on exit.

    The client is created with a 10 second server-selection timeout, with TLS
    requested via keyword argument, and with certifi's CA bundle as the
    trusted certificate file. The client is closed whether the ``with`` body
    finishes normally or raises.
    """
    client_options = {
        "serverSelectionTimeoutMS": 10000,
        "tls": True,
        "tlsCAFile": certifi.where(),
    }
    connection: MongoClient = MongoClient(connection_string, **client_options)
    # closing() calls connection.close() on exit, including on exceptions.
    with contextlib.closing(connection) as handle:
        yield handle


def _get_schema_from_query(collection: Collection) -> list[tuple[str, str]]:
    """
    Infer ``(field_name, type_name)`` pairs for *collection* via aggregation.

    Each document is turned into an array of key/value pairs, unwound to one
    row per pair, and grouped by key while collecting the set of ``$type``
    values seen for that key. Only the first collected type is reported for
    each field.

    ISSUE: The pipeline has no ``$sample`` (or other limiting) stage, so every
    document in the collection is fed through it. For large collections this
    is extremely slow.

    Returns ``[("_id", "string")]`` when the aggregation produces no rows or
    when anything in the process raises.
    """
    # NOTE: Prepending {"$sample": {"size": 1000}} would fix the performance issue
    to_key_value_pairs = {"$project": {"arrayofkeyvalue": {"$objectToArray": "$$ROOT"}}}
    one_row_per_pair = {"$unwind": "$arrayofkeyvalue"}
    types_per_key = {
        "$group": {
            "_id": "$arrayofkeyvalue.k",
            "types": {"$addToSet": {"$type": "$arrayofkeyvalue.v"}},
        }
    }
    stages: list[dict[str, Any]] = [to_key_value_pairs, one_row_per_pair, types_per_key]

    try:
        rows = list(collection.aggregate(stages))
        if rows:
            return [(row["_id"], row["types"][0]) for row in rows]
    except Exception:
        # Best-effort inference: any failure degrades to the fallback below.
        pass

    return [("_id", "string")]


def get_schemas_original(connection_string: str, database: str) -> dict[str, list[tuple[str, str]]]:
    """
    PostHog's get_schemas() - ORIGINAL VERSION (slow).

    Opens a client, lists the collections of *database*, and runs the
    full-collection schema aggregation (``_get_schema_from_query``) on each
    one. Returns a mapping of collection name to ``(field, type)`` pairs.

    ISSUE (as claimed by this script): ``authorizedCollections=True`` is
    passed without ``nameOnly=True``, which is said to prevent MongoDB from
    filtering by collection-level permissions.

    NOTE(review): PyMongo's ``list_collection_names()`` appears to set
    ``nameOnly=True`` on its own when no ``filter`` is given - confirm against
    the PyMongo documentation before relying on this premise.
    """
    with mongo_client(connection_string) as client:
        target_db = client[database]
        schema_list = defaultdict(list)

        # ISSUE (claimed): nameOnly=True is not passed explicitly here
        for name in target_db.list_collection_names(authorizedCollections=True):
            columns = _get_schema_from_query(target_db[name])
            schema_list[name].extend(columns)

    return schema_list


def get_schemas_fixed(
    connection_string: str,
    database: str,
    sample_size: int = 1000,
) -> dict[str, list[tuple[str, str]]]:
    """
    FIXED VERSION of schema discovery: infer columns from a random sample.

    Differences from ``get_schemas_original``:
    1. ``nameOnly=True`` is passed explicitly to ``list_collection_names()``.
       NOTE(review): PyMongo appears to set this itself when no ``filter`` is
       given, so this may be a no-op - confirm against the PyMongo docs.
    2. A leading ``$sample`` stage limits how many documents per collection
       are fed into the schema aggregation.

    Args:
        connection_string: MongoDB URI handed to ``mongo_client``.
        database: Name of the database whose collections are inspected.
        sample_size: Number of documents ``$sample`` draws per collection.
            Defaults to 1000, the value previously hard-coded here.

    Returns:
        Mapping of collection name to ``(field_name, type_name)`` pairs. A
        collection whose aggregation returns nothing or raises is reported as
        ``[("_id", "string")]``.

    Raises:
        ValueError: If ``sample_size`` is less than 1. This is checked before
            any connection is opened.
    """
    if sample_size < 1:
        raise ValueError(f"sample_size must be a positive integer, got {sample_size!r}")

    # The pipeline is identical for every collection, so build it once.
    pipeline: list[dict[str, Any]] = [
        {"$sample": {"size": sample_size}},  # Only sample `sample_size` docs
        {"$project": {"arrayofkeyvalue": {"$objectToArray": "$$ROOT"}}},
        {"$unwind": "$arrayofkeyvalue"},
        {
            "$group": {
                "_id": "$arrayofkeyvalue.k",
                "types": {"$addToSet": {"$type": "$arrayofkeyvalue.v"}},
            }
        },
    ]

    with mongo_client(connection_string) as client:
        db = client[database]
        schema_list = defaultdict(list)

        # FIX (claimed): Add nameOnly=True for proper permission filtering
        collection_names = db.list_collection_names(authorizedCollections=True, nameOnly=True)

        for collection_name in collection_names:
            try:
                result = list(db[collection_name].aggregate(pipeline))
                # Only the first collected type is kept for each field.
                schema_info = (
                    [(field["_id"], field["types"][0]) for field in result]
                    if result
                    else [("_id", "string")]
                )
            except Exception:
                # Best-effort: a failing collection degrades to the fallback
                # schema instead of aborting the whole discovery run.
                schema_info = [("_id", "string")]

            schema_list[collection_name].extend(schema_info)

    return schema_list


# =============================================================================
# TEST SCRIPT
# =============================================================================

def _print_schema_summary(schemas: dict[str, list[tuple[str, str]]], preview: int = 5) -> None:
    """Print the collection count plus column counts for the first *preview* collections."""
    print(f"Found {len(schemas)} collections")
    for name, columns in list(schemas.items())[:preview]:
        print(f"  - {name}: {len(columns)} columns")
    if len(schemas) > preview:
        print(f"  ... and {len(schemas) - preview} more collections")


def _run_timed_discovery(discover, connection_string: str, database: str, *, report_interrupt: bool = False) -> None:
    """Call ``discover(connection_string, database)``, timing it and printing the outcome.

    With ``report_interrupt=True`` a Ctrl-C during discovery is reported with
    the elapsed time and swallowed so the script can continue; otherwise
    ``KeyboardInterrupt`` propagates.
    """
    # perf_counter() is monotonic, so elapsed times cannot be skewed by
    # wall-clock adjustments while a long discovery run is in progress.
    started = time.perf_counter()
    try:
        schemas = discover(connection_string, database)
    except KeyboardInterrupt:
        if not report_interrupt:
            raise
        elapsed = time.perf_counter() - started
        print(f"\nInterrupted after {elapsed:.2f} seconds (proving the timeout issue)")
    except Exception as e:
        elapsed = time.perf_counter() - started
        print(f"\nFailed after {elapsed:.2f} seconds: {e}")
    else:
        elapsed = time.perf_counter() - started
        print(f"\nCompleted in {elapsed:.2f} seconds")
        _print_schema_summary(schemas)


def main() -> None:
    """Run the reproduction: ping the server, then time both schema-discovery variants.

    Reads the connection string from the ``MONGODB_CONNECTION_STRING``
    environment variable. Exits with status 1 when the variable is unset, the
    connection string is invalid or names no database, or the ping fails.
    """
    # Get connection string from environment variable
    connection_string = os.environ.get("MONGODB_CONNECTION_STRING")

    if not connection_string:
        print("ERROR: Please set MONGODB_CONNECTION_STRING environment variable")
        print("")
        print("Usage:")
        print('  export MONGODB_CONNECTION_STRING="mongodb+srv://user:password@cluster.mongodb.net/database?authSource=admin"')
        print("  python test.py")
        sys.exit(1)

    print("=" * 70)
    print("PostHog MongoDB Connector - Performance Issue Reproduction")
    print("=" * 70)

    # Parse and display connection info (with redacted password). A malformed
    # URI is reported as a normal error instead of an uncaught traceback.
    try:
        parsed = _parse_connection_string(connection_string)
    except ValueError as e:
        print(f"\nERROR: {e}")
        sys.exit(1)

    print("\nConnection Info:")
    print(f"  Host:        {parsed['host']}")
    print(f"  Database:    {parsed['database']}")
    print(f"  User:        {parsed['user']}")
    print(f"  Auth Source: {parsed['auth_source']}")

    if not parsed["database"]:
        print("\nERROR: Database name is required in connection string")
        sys.exit(1)

    # Test connection
    print("\n" + "-" * 70)
    print("Step 1: Testing connection...")
    print("-" * 70)

    try:
        with mongo_client(connection_string) as client:
            result = client.admin.command("ping")
            print(f"Connection successful: {result.get('ok') == 1.0}")
    except Exception as e:
        print(f"Connection failed: {e}")
        sys.exit(1)

    # Test original (slow) schema discovery. Ctrl-C is reported rather than
    # fatal here so an impatient user still reaches the fixed run below.
    print("\n" + "-" * 70)
    print("Step 2: Running ORIGINAL get_schemas() [PostHog's current implementation]")
    print("-" * 70)
    print("This may take a long time for databases with many collections...")

    _run_timed_discovery(
        get_schemas_original,
        connection_string,
        parsed["database"],
        report_interrupt=True,
    )

    # Test fixed (fast) schema discovery
    print("\n" + "-" * 70)
    print("Step 3: Running FIXED get_schemas() [with $sample optimization]")
    print("-" * 70)

    _run_timed_discovery(get_schemas_fixed, connection_string, parsed["database"])

    print("\n" + "=" * 70)
    print("Summary")
    print("=" * 70)
    # NOTE(review): item 2 below assumes PyMongo does not already send
    # nameOnly=True from list_collection_names() - confirm before publishing.
    print("""
Issues identified in PostHog's MongoDB connector:

1. PERFORMANCE: Schema discovery scans ALL documents in every collection
   - Location: _get_schema_from_query() in mongo.py
   - Fix: Add {"$sample": {"size": 1000}} to the aggregation pipeline

2. PERMISSIONS: Cannot filter collections via MongoDB user permissions
   - Location: get_schemas() in mongo.py
   - Fix: Add nameOnly=True to list_collection_names() call
   - Current:  db.list_collection_names(authorizedCollections=True)
   - Fixed:    db.list_collection_names(authorizedCollections=True, nameOnly=True)
""")


# Run the reproduction only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
