import torch
torch._dynamo.disable()   # HARD disable Dynamo

import os
from torch.profiler import profile, ProfilerActivity
from torch.amp import autocast
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

# ------------------------------------------------------------
# Benchmark configuration
# ------------------------------------------------------------
MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
DEVICE = "cpu"
DTYPE = torch.bfloat16

# One long prompt: the recipe question repeated 512 times (large prefill).
PROMPT = 512 * "I have tomatoes, basil and cheese at home. What can I cook for dinner?\n"

MAX_NEW_TOKENS = 128  # tokens generated per timed run
N_WARMUP = 2          # untimed warm-up generations
N_RUNS = 5            # timed generations under the profiler

# ------------------------------------------------------------
# Load tokenizer + model (BF16 weights, CPU, eval mode)
# ------------------------------------------------------------
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Llama tokenizers ship without a pad token; reuse EOS so padding is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Loading model in BF16...")
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=DTYPE)
model = model.to(DEVICE)
model.eval()

# Sanity check: weights should really be bf16.
print("Model parameter dtype:", next(model.parameters()).dtype)

# Greedy decoding is used below; null out sampling knobs so generate()
# does not warn about unused temperature/top_p (optional).
model.generation_config.temperature = None
model.generation_config.top_p = None

# ------------------------------------------------------------
# Tokenize the prompt — batch size must be exactly 1
# ------------------------------------------------------------
inputs = tokenizer(PROMPT, return_tensors="pt").to(DEVICE)
batch_size = inputs["input_ids"].shape[0]
assert batch_size == 1, "Batch size is not 1!"

# Shared generation arguments: deterministic greedy decoding.
gen_kwargs = {
    "do_sample": False,
    "pad_token_id": tokenizer.pad_token_id,
}

# ------------------------------------------------------------
# Warm-up: short untimed generations so lazy init / kernel
# selection does not pollute the profiled runs
# ------------------------------------------------------------
print("Running warm-up...")
set_seed(0)
with torch.no_grad(), autocast("cpu", dtype=DTYPE):
    for _ in range(N_WARMUP):
        model.generate(**inputs, **gen_kwargs, max_new_tokens=16)

# ------------------------------------------------------------
# Profiling + Benchmark (NO SCHEDULE)
# ------------------------------------------------------------
print("Running benchmark + profiler...")

with profile(
    activities=[ProfilerActivity.CPU],
    record_shapes=True,
    with_stack=False,
) as prof:
    set_seed(1)
    with torch.no_grad(), autocast("cpu", dtype=DTYPE):
        for _ in range(N_RUNS):
            # min_new_tokens == max_new_tokens pins every run to exactly
            # MAX_NEW_TOKENS generated tokens, so runs are comparable.
            model.generate(
                **inputs,
                **gen_kwargs,
                max_new_tokens=MAX_NEW_TOKENS,
                min_new_tokens=MAX_NEW_TOKENS,
            )

# ------------------------------------------------------------
# Results: TOP 20 OPERATORS
# ------------------------------------------------------------
print("\n===== TOP 20 OPERATORS (by self CPU time) =====\n")
averages = prof.key_averages()  # aggregate once, reuse below
print(averages.table(sort_by="self_cpu_time_total", row_limit=20))

# Optional sanity check
print("\nNumber of profiler events:", len(averages))

# ------------------------------------------------------------
# Chrome trace (open via chrome://tracing or Perfetto)
# ------------------------------------------------------------
trace_path = "llama31_8b_bf16_cpu_bs1_top6.json"
prof.export_chrome_trace(trace_path)
print(f"\nChrome trace written to: {trace_path}")
