import argparse
import csv
import hashlib
import itertools
import multiprocessing as mp
import os
import pathlib
import time

import torch
import torch.distributed as dist

from flashinfer import comm
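# Benchmark for flashinfer's TensorRT-LLM-style fused allreduce kernels.
# For every combination of world size, token count, hidden dimension,
# one-shot/two-shot strategy, FP32 accumulation, and fusion pattern, the script
# spawns one process per GPU, replays the fused kernel inside a CUDA graph, and
# appends the measured latency to a CSV keyed by a hash of the configuration so
# interrupted sweeps can be resumed.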


def process_fn(
    rank: int,
    world_size: int,
    token_num: int,
    hidden_dim: int,
    use_oneshot: bool,
    fp32_acc: bool,
    pattern_name: str,
    output_file: str,
    master_addr: str,
    master_port: str,
):
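    """Per-rank worker: initialize NCCL and the IPC workspace, time one fused
    allreduce configuration, and (on rank 0) append the result to the CSV."""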
    try:
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = master_port

        dist.init_process_group(
            backend="nccl",
            rank=rank,
            world_size=world_size,
        )

        torch.cuda.set_device(rank)
        group = dist.group.WORLD

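        # Resolve the pattern name passed from the parent process into the
        # flashinfer enum value.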
        pattern_map = {
            "kAllReduce": comm.AllReduceFusionPattern.kAllReduce,
            "kARResidualRMSNorm": comm.AllReduceFusionPattern.kARResidualRMSNorm,
            "kARResidualRMSNormFP8Quant": comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
            "kARResidualRMSNormFP4Quant": comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
            "kARResidualRMSNormOutFP8Quant": comm.AllReduceFusionPattern.kARResidualRMSNormOutFP8Quant,
            "kARResidualRMSNormOutFP4Quant": comm.AllReduceFusionPattern.kARResidualRMSNormOutFP4Quant,
        }
        pattern_code = pattern_map[pattern_name]

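        # Allocate the IPC-shared workspace buffers that the fused allreduce
        # kernel uses to exchange data between ranks.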
        print(f"rank {rank} initializing workspace")
        ipc_handles, workspace_tensor = (
            comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                rank,
                world_size,
                token_num,
                hidden_dim,
                group,
            )
        )

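        # Allocate every output the fusion can produce; each pattern only
        # writes the subset it needs, so the same run() call works for all
        # patterns.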
        dtype = torch.bfloat16
        input_tensor = torch.randn(
            token_num, hidden_dim, device=torch.device("cuda"), dtype=dtype
        )
        residual = torch.randn_like(input_tensor)
        weight = torch.ones(hidden_dim, device=torch.device("cuda"), dtype=dtype)
        allreduce_out = torch.empty_like(input_tensor)
        residual_out = torch.empty_like(residual)
        norm_out = torch.empty_like(input_tensor)
        quant_out = torch.empty_like(input_tensor)
        scale_out = torch.empty_like(input_tensor)
        scale_factor = torch.tensor(
            1.0, device=torch.device("cuda"), dtype=torch.float32
        )

        def run():

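            # One fused kernel launch: allreduce plus, depending on
            # pattern_code, residual add, RMSNorm, and FP8/FP4 quantization.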
            comm.trtllm_allreduce_fusion(
                allreduce_in=input_tensor,
                world_size=world_size,
                world_rank=rank,
                token_num=token_num,
                hidden_dim=hidden_dim,
                workspace_ptrs=workspace_tensor,
                launch_with_pdl=True,
                use_oneshot=use_oneshot,
                trigger_completion_at_end=False,
                fp32_acc=fp32_acc,
                pattern_code=pattern_code,
                allreduce_out=allreduce_out,
                residual_in=residual,
                residual_out=residual_out,
                norm_out=norm_out,
                quant_out=quant_out,
                scale_out=scale_out,
                rms_gamma=weight,
                rms_eps=1e-3,
                scale_factor=scale_factor,
                layout_code=None,
            )

            return residual_out, norm_out

        print(f"rank {rank} running")

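        # Warm up on a side stream, as recommended before CUDA graph capture,
        # so lazy initialization does not end up inside the captured graph.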
        stream = torch.cuda.Stream()
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            for _ in range(10):
                run()  # warmup
        torch.cuda.synchronize()

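        # Capture num_runs iterations in a single graph so that one replay
        # measures steady-state kernel time without per-launch CPU overhead.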
        num_runs = 1000
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            for _ in range(num_runs):
                run()

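        # Time a single replay of the whole graph and report the average
        # per-iteration latency in microseconds.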
        start_time = time.time()

        graph.replay()
        torch.cuda.synchronize()

        end_time = time.time()
        time_in_us = (end_time - start_time) * 1000 * 1000 / num_runs

        print(f"rank {rank} time: {time_in_us:10.5f}us")
        comm_size_in_mb = (
            token_num * hidden_dim * world_size * 2 * dtype.itemsize / 1024 / 1024
        )
        if rank == 0:
            config_hash = hash_config(
                world_size, token_num, hidden_dim, use_oneshot, fp32_acc, pattern_name
            )
            with open(output_file, "a") as f:
                print(
                    f"{world_size},{token_num},{hidden_dim},{comm_size_in_mb:.4f},{use_oneshot},{fp32_acc},{pattern_name},{time_in_us:.2f},{config_hash}",
                    file=f,
                )

    finally:
        # Guard teardown so that a failure during setup is not masked by a
        # NameError from referencing objects that were never created.
        if dist.is_initialized():
            dist.barrier(group=dist.group.WORLD)
            if "ipc_handles" in locals():
                comm.trtllm_destroy_ipc_workspace_for_all_reduce_fusion(
                    ipc_handles, group=dist.group.WORLD
                )
            dist.destroy_process_group()


def mp_launch(
    world_size: int,
    token_num: int,
    hidden_dim: int,
    use_oneshot: bool,
    fp32_acc: bool,
    pattern_name: str,
    output_file: str,
):
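    """Spawn one benchmark process per rank and wait for all of them to exit."""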
    master_addr = "127.0.0.1"
    master_port = "29500"
    processes = []
    try:
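        # CUDA cannot be re-initialized in a forked subprocess, so force the
        # "spawn" start method before creating worker processes.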
        try:
            mp.set_start_method("spawn", force=True)
        except RuntimeError:
            pass  # Already set

        for rank in range(world_size):
            p = mp.Process(
                target=process_fn,
                args=(
                    rank,
                    world_size,
                    token_num,
                    hidden_dim,
                    use_oneshot,
                    fp32_acc,
                    pattern_name,
                    output_file,
                    master_addr,
                    master_port,
                ),
            )
            p.start()
            processes.append(p)

    finally:
        for i, p in enumerate(processes):
            p.join()
            if p.exitcode != 0:
                print(f"Process {i} failed with exit code {p.exitcode}")


def run_test(
    world_size: int,
    token_num: int,
    hidden_dim: int,
    use_oneshot: bool,
    fp32_acc: bool,
    pattern_name: str,
    output_file: str,
):
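    """Benchmark one configuration, skipping it when it is infeasible,
    known-bad, or already recorded in the output CSV."""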
    if world_size > torch.cuda.device_count():
        print(
            f"Skipping test: world_size ({world_size}) > available GPUs ({torch.cuda.device_count()})"
        )
        return

    if token_num <= world_size:
        print(f"Skipping test: token_num ({token_num}) <= world_size ({world_size})")
        return

    if world_size * token_num * hidden_dim >= 2**31:
        print(f"Skipping test: {world_size=} * {token_num=} * {hidden_dim=} >= 2**31")
        return

    if (token_num == 8192 and hidden_dim == 8192 and world_size == 4) or (
        token_num == 4096 and hidden_dim == 8192 and world_size == 8
    ):
        print(
            f"Skipping test: possible bug with {world_size=} {token_num=} {hidden_dim=}"
        )
        return

    config_hash = hash_config(
        world_size, token_num, hidden_dim, use_oneshot, fp32_acc, pattern_name
    )

    with open(output_file, "r") as f:
        reader = csv.DictReader(f)
        for row in reader:
            if row["config_hash"] == config_hash:
                print(f"Skipping test: {config_hash=} already exists")
                return

    print(" " * 80)
    print(" " * 80)
    print("-" * 80)
    print(
        f"Running allreduce test with: {world_size=} {token_num=} {hidden_dim=} {use_oneshot=} {fp32_acc=} {pattern_name=}"
    )

    mp_launch(
        world_size,
        token_num,
        hidden_dim,
        use_oneshot,
        fp32_acc,
        pattern_name,
        output_file,
    )

    print(f"Test completed. Results saved to {output_file}")
    print("-" * 80)
    print(" " * 80)
    print(" " * 80)


def hash_config(*args) -> str:
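    """Return a short, stable fingerprint of a benchmark configuration."""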
    return hashlib.sha256(str(args).encode("utf-8")).hexdigest()[:16]


def main():
    parser = argparse.ArgumentParser(description="AllReduce performance test")
    parser.add_argument("--output-file", type=str, required=True, help="Output file")

    args = parser.parse_args()

    world_sizes = [2, 4, 8]
    token_nums = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
    hidden_dims = [2880, 5120, 8192]
    use_oneshots = [True, False]
    fp32_accs = [True, False]
    pattern_names = [
        "kAllReduce",
        "kARResidualRMSNorm",
        "kARResidualRMSNormFP8Quant",
        "kARResidualRMSNormFP4Quant",
        "kARResidualRMSNormOutFP8Quant",
        "kARResidualRMSNormOutFP4Quant",
    ]

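    # Write the CSV header only when starting a fresh output file; existing
    # files are appended to, and finished configurations are skipped by hash.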
    if not pathlib.Path(args.output_file).exists():
        with open(args.output_file, "w") as f:
            print(
                "world_size,token_num,hidden_dim,comm_size (MiB),use_oneshot,fp32_acc,pattern_name,time (us),config_hash",
                file=f,
            )

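    # Sweep the full Cartesian product of benchmark parameters.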
    for ar_args in itertools.product(
        world_sizes, token_nums, hidden_dims, use_oneshots, fp32_accs, pattern_names
    ):
        run_test(*ar_args, args.output_file)


if __name__ == "__main__":
    main()
