"""
Fix validation: compare the per-tick autoscaling cost of three approaches across replica counts.

A) Baseline  — no snapshot (= Branch/revert)
B) Original  — Pydantic DeploymentSnapshot every tick (= Master/#56225)
C) Fix       — tuple cache + lazy Pydantic (= current fix)

Exercises the real Ray Serve autoscaling code in-process; no Ray cluster is needed.
"""

import math
import statistics
import sys
import time
from typing import Dict, List, Optional

from ray.serve._private.common import (
    AutoscalingSnapshotError,
    AutoscalingStatus,
    DeploymentID,
    DeploymentSnapshot,
    ReplicaID,
)
from ray.serve._private.autoscaling_state import DeploymentAutoscalingState
from ray.serve._private.deployment_info import DeploymentInfo
from ray.serve._private.config import DeploymentConfig, ReplicaConfig
from ray.serve.config import AutoscalingConfig

# Untimed get_decision_num_replicas() ticks run on each freshly created state
# before any measurement starts.
WARMUP = 200
# Timed samples collected per approach (and per replica count) in main().
ITERATIONS = 2000


def _dummy():
    pass


def create_state(n_replicas):
    """Build a registered ``DeploymentAutoscalingState`` for the benchmark.

    The deployment is ``default``/``worker`` with an autoscaling range of
    1..4096 replicas. It is registered with ``n_replicas`` as the current
    target, and ``n_replicas`` replica IDs (``r-0`` .. ``r-{n-1}``) are
    reported as running. The state's latest-metrics timestamp is then set
    to two seconds before now.
    """
    deployment_id = DeploymentID(name="worker", app_name="default")
    autoscaling_state = DeploymentAutoscalingState(deployment_id)

    autoscaling_cfg = AutoscalingConfig(
        min_replicas=1,
        max_replicas=4096,
        target_ongoing_requests=5,
        metrics_interval_s=10.0,
        look_back_period_s=30.0,
    )
    deployment_info = DeploymentInfo(
        deployment_config=DeploymentConfig(autoscaling_config=autoscaling_cfg),
        replica_config=ReplicaConfig.create(_dummy),
        start_time_ms=int(time.time() * 1000),
        deployer_job_id="bench",
    )
    autoscaling_state.register(deployment_info, curr_target_num_replicas=n_replicas)

    running_ids = [
        ReplicaID(unique_id=f"r-{idx}", deployment_id=deployment_id)
        for idx in range(n_replicas)
    ]
    autoscaling_state.update_running_replica_ids(running_ids)

    # Directly backdate the private metrics timestamp by 2 s relative to now.
    autoscaling_state._latest_metrics_timestamp = time.time() - 2.0
    return autoscaling_state


def bench_baseline(state, n, iters):
    """A) Baseline with no snapshot — the Branch/revert behavior.

    Each timed tick fetches the autoscaling context, runs the policy, rounds a
    float decision up with ``math.ceil``, and applies the replica bounds (the
    bounded result is discarded). Returns per-tick durations in nanoseconds.
    """
    durations_ns = []
    for _ in range(iters):
        start = time.perf_counter_ns()
        context = state.get_autoscaling_context(n)
        target, _unused = state._policy(context)
        if isinstance(target, float):
            target = math.ceil(target)
        state.apply_bounds(target)
        durations_ns.append(time.perf_counter_ns() - start)
    return durations_ns


def bench_original_pr(state, n, iters):
    """B) Original PR #56225 — time a tick that also builds a snapshot.

    Runs the same decision path as the baseline (context, policy, ceil of a
    float decision, bounds), then constructs a Pydantic ``DeploymentSnapshot``
    on every tick. Returns per-tick durations in nanoseconds.
    """
    deployment_id = state._deployment_id
    durations = []
    for _ in range(iters):
        start_ns = time.perf_counter_ns()

        context = state.get_autoscaling_context(n)
        target, _ = state._policy(context)
        if isinstance(target, float):
            target = math.ceil(target)
        target = state.apply_bounds(target)

        # Original PR: a Pydantic DeploymentSnapshot is built on every tick.
        running = context.current_num_replicas
        if target > running:
            direction = AutoscalingStatus.UPSCALE
        elif target < running:
            direction = AutoscalingStatus.DOWNSCALE
        else:
            direction = AutoscalingStatus.STABLE
        status_text = AutoscalingStatus.format_scaling_status(direction)

        # A falsy timestamp means "no metrics seen yet".
        last_ts = state._latest_metrics_timestamp
        if last_ts:
            since_metrics = time.time() - last_ts
        else:
            since_metrics = None
        window_s = state._config.look_back_period_s
        problems = (
            []
            if since_metrics is not None
            else [AutoscalingSnapshotError.METRICS_UNAVAILABLE]
        )
        policy_fn = context.config.policy.get_policy()

        DeploymentSnapshot(
            timestamp_str=time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            app=deployment_id.app_name,
            deployment=deployment_id.name,
            current_replicas=running,
            target_replicas=target,
            min_replicas=context.capacity_adjusted_min_replicas,
            max_replicas=context.capacity_adjusted_max_replicas,
            scaling_status=status_text,
            policy_name=f"{policy_fn.__module__}.{policy_fn.__name__}",
            look_back_period_s=window_s,
            queued_requests=float(context.total_queued_requests),
            ongoing_requests=float(context.total_num_requests),
            metrics_health=DeploymentSnapshot.format_metrics_health_text(
                time_since_last_collected_metrics_s=since_metrics,
                look_back_period_s=window_s,
            ),
            errors=problems,
        )
        durations.append(time.perf_counter_ns() - start_ns)
    return durations


def bench_fix(state, n, iters):
    """C) Current fix — time ``get_decision_num_replicas`` per tick.

    Returns a list of ``iters`` per-call durations in nanoseconds.
    """
    samples = []
    for _ in range(iters):
        began = time.perf_counter_ns()
        state.get_decision_num_replicas(curr_target_num_replicas=n)
        samples.append(time.perf_counter_ns() - began)
    return samples


def bench_emit_original(state, iters):
    """Emit cost as in the original PR: compare full Pydantic snapshots.

    Each iteration fetches a ``DeploymentSnapshot``. When it is not None, the
    snapshot's target/min/max replicas and scaling status are compared with
    the last emitted snapshot. ``.dict(exclude_none=True)`` is called, and the
    snapshot remembered, only if nothing was emitted yet or a field differs.
    Returns per-iteration durations in nanoseconds.
    """
    deployment_id = state._deployment_id
    emitted: Dict[DeploymentID, DeploymentSnapshot] = {}
    durations = []
    for _ in range(iters):
        start_ns = time.perf_counter_ns()
        current = state.get_deployment_snapshot()
        if current is not None:
            previous = emitted.get(deployment_id)
            unchanged = previous is not None and (
                previous.target_replicas == current.target_replicas
                and previous.min_replicas == current.min_replicas
                and previous.max_replicas == current.max_replicas
                and previous.scaling_status == current.scaling_status
            )
            if not unchanged:
                current.dict(exclude_none=True)
                emitted[deployment_id] = current
        durations.append(time.perf_counter_ns() - start_ns)
    return durations


def bench_emit_fix(state, iters):
    """Emit cost with the fix: dedup on the cached tuple snapshot key.

    Each iteration reads the cached snapshot key. Only when it is not None and
    differs from the last emitted key is a ``DeploymentSnapshot`` fetched and
    serialized with ``.dict(exclude_none=True)``. The key is recorded only once
    a snapshot was actually emitted. Returns per-iteration durations in ns.
    """
    deployment_id = state._deployment_id
    emitted_keys: Dict[DeploymentID, tuple] = {}
    durations = []
    for _ in range(iters):
        start_ns = time.perf_counter_ns()
        snapshot_key = state.get_cached_snapshot_key()
        if snapshot_key is not None and snapshot_key != emitted_keys.get(deployment_id):
            snapshot = state.get_deployment_snapshot()
            if snapshot is not None:
                snapshot.dict(exclude_none=True)
                emitted_keys[deployment_id] = snapshot_key
        durations.append(time.perf_counter_ns() - start_ns)
    return durations


def us(ns_list):
    """Convert nanosecond durations to a new list of microsecond floats."""
    ns_per_us = 1000.0
    return [duration / ns_per_us for duration in ns_list]


def _warm_up(state, n):
    """Run ``WARMUP`` untimed ``get_decision_num_replicas`` ticks on ``state``."""
    for _ in range(WARMUP):
        state.get_decision_num_replicas(curr_target_num_replicas=n)


def _pct_change(value, base):
    """Return the percent change of ``value`` vs ``base``; 0 when ``base`` <= 0."""
    return (value - base) / base * 100 if base > 0 else 0


def _ratio(numerator, denominator):
    """Return ``numerator / denominator``; ``inf`` when the denominator is <= 0."""
    return numerator / denominator if denominator > 0 else float("inf")


def _bench_full_loop_original(state, n, iters):
    """B) Time autoscale + emit per tick as in the original PR #56225.

    Every tick rebuilds a Pydantic ``DeploymentSnapshot``. ``.dict()`` is
    called only when target/min/max replicas or scaling status differ from the
    last emitted snapshot. Returns per-tick durations in nanoseconds.
    """
    dep_id = state._deployment_id
    last: Dict[DeploymentID, DeploymentSnapshot] = {}
    times = []
    for _ in range(iters):
        t0 = time.perf_counter_ns()
        # autoscale (original: Pydantic every tick)
        ctx = state.get_autoscaling_context(n)
        decision, _ = state._policy(ctx)
        if isinstance(decision, float):
            decision = math.ceil(decision)
        decision = state.apply_bounds(decision)
        current = ctx.current_num_replicas
        scaling_status = AutoscalingStatus.format_scaling_status(
            AutoscalingStatus.UPSCALE if decision > current
            else AutoscalingStatus.DOWNSCALE if decision < current
            else AutoscalingStatus.STABLE
        )
        elapsed = (
            time.time() - state._latest_metrics_timestamp
            if state._latest_metrics_timestamp
            else None
        )
        look_back = state._config.look_back_period_s
        errors = []
        if elapsed is None:
            errors.append(AutoscalingSnapshotError.METRICS_UNAVAILABLE)
        policy = ctx.config.policy.get_policy()
        snap = DeploymentSnapshot(
            timestamp_str=time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            app=dep_id.app_name,
            deployment=dep_id.name,
            current_replicas=current,
            target_replicas=decision,
            min_replicas=ctx.capacity_adjusted_min_replicas,
            max_replicas=ctx.capacity_adjusted_max_replicas,
            scaling_status=scaling_status,
            policy_name=f"{policy.__module__}.{policy.__name__}",
            look_back_period_s=look_back,
            queued_requests=float(ctx.total_queued_requests),
            ongoing_requests=float(ctx.total_num_requests),
            metrics_health=DeploymentSnapshot.format_metrics_health_text(
                time_since_last_collected_metrics_s=elapsed,
                look_back_period_s=look_back,
            ),
            errors=errors,
        )
        # emit (original: Pydantic comparison)
        prev = last.get(dep_id)
        if prev is None or not (
            prev.target_replicas == snap.target_replicas
            and prev.min_replicas == snap.min_replicas
            and prev.max_replicas == snap.max_replicas
            and prev.scaling_status == snap.scaling_status
        ):
            snap.dict(exclude_none=True)
            last[dep_id] = snap
        times.append(time.perf_counter_ns() - t0)
    return times


def _bench_full_loop_fix(state, n, iters):
    """C) Time autoscale + emit per tick with the fix (tuple-key dedup).

    Each tick runs ``get_decision_num_replicas``, then builds and serializes a
    ``DeploymentSnapshot`` only when the cached key is non-None and changed.
    Returns per-tick durations in nanoseconds.
    """
    dep_id = state._deployment_id
    last: Dict[DeploymentID, tuple] = {}
    times = []
    for _ in range(iters):
        t0 = time.perf_counter_ns()
        state.get_decision_num_replicas(curr_target_num_replicas=n)
        key = state.get_cached_snapshot_key()
        if key is not None and key != last.get(dep_id):
            snap = state.get_deployment_snapshot()
            if snap is not None:
                snap.dict(exclude_none=True)
                last[dep_id] = key
        times.append(time.perf_counter_ns() - t0)
    return times


def main():
    """Print three benchmark tables comparing approaches A, B and C.

    1. Per-tick autoscaling cost for each replica count.
    2. Steady-state emit cost (dedup active) for one replica count.
    3. Combined autoscale + emit cost per tick for each replica count.
    """
    print("=" * 80)
    print("  Fix Validation: Per-tick cost comparison")
    print("  A) No snapshot (Branch/revert)")
    print("  B) Original PR #56225 (Pydantic every tick)")
    print("  C) Current fix (tuple cache + lazy Pydantic)")
    print("=" * 80)

    replica_counts = [1, 4, 16, 64, 256, 1024, 2048]

    # --- Per-tick autoscaling cost ---
    print(f"\n  {'Replicas':>8s}  {'A) No snap':>12s}  {'B) Original':>12s}  "
          f"{'C) Fix':>12s}  {'B vs A':>10s}  {'C vs A':>10s}  {'B vs C':>8s}")
    print(f"  {'-'*8}  {'-'*12}  {'-'*12}  {'-'*12}  {'-'*10}  {'-'*10}  {'-'*8}")

    for n in replica_counts:
        state = create_state(n)
        _warm_up(state, n)

        # Run all three benchmarks first (A, B, C order), then summarize.
        t_a = us(bench_baseline(state, n, ITERATIONS))
        t_b = us(bench_original_pr(state, n, ITERATIONS))
        t_c = us(bench_fix(state, n, ITERATIONS))

        a = statistics.mean(t_a)
        b = statistics.mean(t_b)
        c = statistics.mean(t_c)

        ba_pct = _pct_change(b, a)
        ca_pct = _pct_change(c, a)
        bc_ratio = _ratio(b, c)

        print(f"  {n:>8d}  {a:>10.1f}us  {b:>10.1f}us  {c:>10.1f}us  "
              f"{ba_pct:>+8.0f}%  {ca_pct:>+8.0f}%  {bc_ratio:>7.2f}x")

    # --- Emit comparison (steady state: no changes, dedup active) ---
    emit_replicas = 64
    print(f"\n  Emit cost (steady state, dedup active, {emit_replicas} replicas):")
    state = create_state(emit_replicas)
    _warm_up(state, emit_replicas)

    t_emit_orig = us(bench_emit_original(state, ITERATIONS))
    t_emit_fix = us(bench_emit_fix(state, ITERATIONS))

    eo = statistics.mean(t_emit_orig)
    ef = statistics.mean(t_emit_fix)

    print(f"    B) Original (Pydantic build + compare): {eo:>8.1f}us")
    print(f"    C) Fix (tuple key compare):             {ef:>8.1f}us")
    # Same zero-guard as the per-tick ratio; a direct eo/ef could raise
    # ZeroDivisionError if every fix sample measured 0 ns.
    print(f"    Speedup: {_ratio(eo, ef):.1f}x")

    # --- Full loop: autoscale + emit (steady state) ---
    print("\n  Full loop: autoscale + emit (steady state)")
    print(f"  {'Replicas':>8s}  {'A) No snap':>12s}  {'B) Original':>12s}  "
          f"{'C) Fix':>12s}  {'C vs A':>10s}")
    print(f"  {'-'*8}  {'-'*12}  {'-'*12}  {'-'*12}  {'-'*10}")

    for n in replica_counts:
        state = create_state(n)
        _warm_up(state, n)

        a_mean = statistics.mean(us(bench_baseline(state, n, ITERATIONS)))
        b_mean = statistics.mean(us(_bench_full_loop_original(state, n, ITERATIONS)))
        c_mean = statistics.mean(us(_bench_full_loop_fix(state, n, ITERATIONS)))

        ca_pct = _pct_change(c_mean, a_mean)
        print(f"  {n:>8d}  {a_mean:>10.1f}us  {b_mean:>10.1f}us  "
              f"{c_mean:>10.1f}us  {ca_pct:>+8.0f}%")

    print()

# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
