Skip to content

Vectorize masks_to_boxes for performance#9358

Merged
zy1git merged 5 commits into pytorch:main from
raimbekovm:fix/masks-to-boxes-vectorize
Feb 13, 2026
Merged

Vectorize masks_to_boxes for performance#9358
zy1git merged 5 commits into pytorch:main from
raimbekovm:fix/masks-to-boxes-vectorize

Conversation

@raimbekovm
Copy link
Copy Markdown
Contributor

Vectorizes the masks_to_boxes implementation by removing the Python loop.

Depends on #9357.

Split from #9347 as requested.

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Jan 27, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/9358

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zy1git
Copy link
Copy Markdown
Contributor

zy1git commented Feb 12, 2026

Benchmark Results

PERFORMANCE BENCHMARKS (CPU)

Batch Size Original (ms) Vectorized (ms) Speedup
1 0.039 ± 0.00ms 0.054 ± 0.00ms 0.71x
2 0.073 ± 0.00ms 0.067 ± 0.01ms 1.08x
4 0.143 ± 0.00ms 0.093 ± 0.00ms 1.54x
8 0.282 ± 0.01ms 0.121 ± 0.02ms 2.32x
16 0.824 ± 0.10ms 0.172 ± 0.02ms 4.78x
32 1.615 ± 0.05ms 0.215 ± 0.01ms 7.52x
64 3.224 ± 0.07ms 0.370 ± 0.01ms 8.71x
128 6.349 ± 0.09ms 0.791 ± 1.56ms 8.02x
256 13.225 ± 0.93ms 1.159 ± 0.13ms 11.41x
512 27.304 ± 2.69ms 1.963 ± 0.03ms 13.91x
1,024 51.184 ± 0.19ms 3.773 ± 0.10ms 13.57x
2,048 102.540 ± 0.42ms 7.540 ± 0.40ms 13.60x
4,096 207.354 ± 4.79ms 14.796 ± 0.60ms 14.01x

PERFORMANCE BENCHMARKS (CUDA)

Batch Size Original (ms) Vectorized (ms) Speedup
1 0.115 ± 0.01ms 0.133 ± 0.02ms 0.86x
2 0.210 ± 0.02ms 0.133 ± 0.01ms 1.58x
4 0.381 ± 0.02ms 0.129 ± 0.01ms 2.96x
8 0.732 ± 0.02ms 0.131 ± 0.01ms 5.60x
16 1.830 ± 0.49ms 0.134 ± 0.01ms 13.69x
32 2.771 ± 0.15ms 0.134 ± 0.02ms 20.75x
64 5.927 ± 1.06ms 0.194 ± 0.04ms 30.61x
128 10.747 ± 0.10ms 0.135 ± 0.01ms 79.73x
256 22.194 ± 1.72ms 0.148 ± 0.03ms 150.02x
512 43.501 ± 0.49ms 0.137 ± 0.02ms 317.71x
1,024 87.550 ± 3.36ms 0.132 ± 0.01ms 661.66x
2,048 174.580 ± 2.76ms 0.142 ± 0.01ms 1231.88x
4,096 351.819 ± 6.77ms 0.217 ± 0.00ms 1618.83x
8,192 702.779 ± 26.87ms 0.359 ± 0.00ms 1958.45x
16,384 1405.378 ± 45.75ms 0.656 ± 0.00ms 2143.69x

MEMORY BENCHMARKS (CUDA)

Batch Size Original (MB) Vectorized (MB) Difference
1 0.04MB 0.02MB -0.02MB
16 0.29MB 0.33MB +0.03MB
64 1.05MB 1.31MB +0.26MB
256 4.05MB 5.23MB +1.18MB
1,024 16.06MB 20.91MB +4.85MB
4,096 64.11MB 83.63MB +19.52MB
8,192 128.17MB 167.26MB +39.09MB

The vectorized implementation is faster, especially for large batch sizes; the slightly higher memory usage is an acceptable trade-off for this speedup.

Note: To obtain consistently reliable CPU measurements, run the benchmark script pinned to a fixed set of cores: taskset -c 0-3 python3 benchmark_script.py

Benchmark Script
"""
Benchmark script for masks_to_boxes optimization.

Compares the original sequential implementation vs the new vectorized implementation
from PR #9358.
"""

import gc
import time

import numpy as np
import torch


def masks_to_boxes_original(masks: torch.Tensor) -> torch.Tensor:
    """Original sequential implementation: one Python-loop pass per mask.

    Returns an (N, 4) float tensor of boxes in (xmin, ymin, xmax, ymax)
    order; a mask with no nonzero pixel keeps its all-zero box.
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)

    num_masks = masks.shape[0]
    boxes = torch.zeros((num_masks, 4), device=masks.device, dtype=torch.float)

    for i in range(num_masks):
        ys, xs = torch.where(masks[i] != 0)
        if xs.numel() == 0:
            # Empty mask: leave the pre-filled all-zero box in place.
            continue
        boxes[i, 0] = torch.min(xs)
        boxes[i, 1] = torch.min(ys)
        boxes[i, 2] = torch.max(xs)
        boxes[i, 3] = torch.max(ys)

    return boxes


def masks_to_boxes_vectorized(masks: torch.Tensor) -> torch.Tensor:
    """New vectorized implementation from PR #9358.

    Computes [x1, y1, x2, y2] boxes for a whole batch of masks without a
    Python loop. Behavior matches the sequential version: a mask with no
    nonzero pixel yields an all-zero box.

    Args:
        masks: Tensor of shape (N, H, W); nonzero pixels mark the object.

    Returns:
        Float tensor of shape (N, 4) in (xmin, ymin, xmax, ymax) order.
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)

    # Only the spatial extents are needed (the batch size was previously
    # unpacked into an unused local).
    h, w = masks.shape[-2:]

    masks_bool = masks.bool()

    # Per-mask flags: which rows / columns contain at least one set pixel.
    non_zero_rows = torch.any(masks_bool, dim=2)  # (N, H)
    non_zero_cols = torch.any(masks_bool, dim=1)  # (N, W)

    # Masks with no nonzero pixel anywhere must map to an all-zero box.
    empty_masks = ~torch.any(non_zero_rows, dim=1)

    # argmax is not defined for bool tensors, so cast to float first.
    non_zero_rows_f = non_zero_rows.float()
    non_zero_cols_f = non_zero_cols.float()

    # argmax returns the FIRST maximal index -> first nonzero row/column.
    y1 = non_zero_rows_f.argmax(dim=1)
    x1 = non_zero_cols_f.argmax(dim=1)
    # Flip, take argmax, and mirror the index to find the LAST nonzero one.
    y2 = (h - 1) - non_zero_rows_f.flip(dims=[1]).argmax(dim=1)
    x2 = (w - 1) - non_zero_cols_f.flip(dims=[1]).argmax(dim=1)

    bounding_boxes = torch.stack([x1, y1, x2, y2], dim=1).float()

    # For empty masks argmax gives 0 for x1/y1 but h-1/w-1 for x2/y2,
    # which would be a bogus full-height/width box; zero them out.
    bounding_boxes[empty_masks] = 0

    return bounding_boxes


def generate_random_masks(
    batch_size: int,
    height: int = 64,
    width: int = 64,
    density: float = 0.3,
    device: str = "cpu",
) -> torch.Tensor:
    """Create a (batch_size, height, width) float tensor of random 0/1 masks.

    Each pixel is set independently with probability `density`.
    """
    noise = torch.rand(batch_size, height, width, device=device)
    binary = noise < density
    return binary.float()


def benchmark_function(
    func, masks, num_warmup: int = 10, num_runs: int = 50, device: str = "cpu"
):
    """Time `func(masks)` and return (mean_ms, std_ms) over `num_runs` runs."""
    # Collect garbage once up front so GC pauses don't land in the timed loop.
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    # Warmup iterations stabilize caches / clock frequency before timing.
    for _ in range(num_warmup):
        func(masks)
        if device == "cuda":
            torch.cuda.synchronize()

    # Timed runs — deliberately no gc.collect() inside this loop.
    samples = []
    for _ in range(num_runs):
        if device == "cuda":
            torch.cuda.synchronize()

        t0 = time.perf_counter()
        func(masks)

        if device == "cuda":
            torch.cuda.synchronize()

        samples.append(time.perf_counter() - t0)

    # Convert seconds to milliseconds.
    return np.mean(samples) * 1000, np.std(samples) * 1000


def measure_peak_memory(func, masks, device: str = "cpu"):
    """Return peak CUDA memory (MB) used by one call of `func`; None on CPU."""
    gc.collect()

    if device != "cuda":
        # CPU peak-memory tracking is out of scope; run the function anyway.
        func(masks)
        return None

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    func(masks)
    torch.cuda.synchronize()
    # Bytes -> MB.
    return torch.cuda.max_memory_allocated() / (1024 * 1024)


def verify_correctness(batch_size: int = 100, device: str = "cpu"):
    """Verify that both implementations produce the same results.

    Checks three cases — random masks, fully-empty masks, and a batch
    mixing empty and non-empty masks — printing a PASSED/FAILED line each.
    """
    print(f"\n{'='*60}")
    print("CORRECTNESS VERIFICATION")
    print(f"{'='*60}")

    # Test with random masks.
    masks = generate_random_masks(batch_size, device=device)
    result_orig = masks_to_boxes_original(masks)
    result_vec = masks_to_boxes_vectorized(masks)

    if torch.allclose(result_orig, result_vec):
        print(f"✓ Random masks (batch={batch_size}): PASSED")
    else:
        print(f"✗ Random masks (batch={batch_size}): FAILED")
        print(f"  Max difference: {(result_orig - result_vec).abs().max().item()}")

    # Test with fully-empty masks (every box should be all zeros).
    empty_masks = torch.zeros((5, 64, 64), device=device)
    result_orig_empty = masks_to_boxes_original(empty_masks)
    result_vec_empty = masks_to_boxes_vectorized(empty_masks)

    # Fixed: these messages were needless f-strings with no placeholders.
    if torch.allclose(result_orig_empty, result_vec_empty):
        print("✓ Empty masks: PASSED")
    else:
        print("✗ Empty masks: FAILED")

    # Test with a batch mixing empty and non-empty masks.
    mixed_masks = torch.zeros((3, 10, 10), device=device)
    mixed_masks[1, 2:5, 3:7] = 1
    result_orig_mixed = masks_to_boxes_original(mixed_masks)
    result_vec_mixed = masks_to_boxes_vectorized(mixed_masks)

    if torch.allclose(result_orig_mixed, result_vec_mixed):
        print("✓ Mixed empty/non-empty masks: PASSED")
    else:
        print("✗ Mixed empty/non-empty masks: FAILED")


def run_benchmarks(device: str = "cpu"):
    """Run timing benchmarks for a range of batch sizes.

    Prints a comparison table of the original vs vectorized implementation
    and returns a list of dicts with keys batch_size, time_orig, time_vec,
    and speedup. The sweep aborts early on an out-of-memory RuntimeError.
    """
    print(f"\n{'='*60}")
    print(f"PERFORMANCE BENCHMARKS ({device.upper()})")
    print(f"{'='*60}")

    # Batch sizes to test (powers of 2).
    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]

    if device == "cuda":
        # GPUs can handle (and benefit from) larger batch sizes.
        batch_sizes.extend([8192, 16384])

    print(
        f"\n{'Batch Size':>12} | {'Original (ms)':>15} | {'Vectorized (ms)':>15} | {'Speedup':>10}"
    )
    print("-" * 62)

    results = []

    for batch_size in batch_sizes:
        try:
            masks = generate_random_masks(batch_size, device=device)

            time_orig, std_orig = benchmark_function(
                masks_to_boxes_original, masks, device=device
            )
            time_vec, std_vec = benchmark_function(
                masks_to_boxes_vectorized, masks, device=device
            )

            # Guard against division by zero for pathologically fast runs.
            speedup = time_orig / time_vec if time_vec > 0 else float("inf")

            print(
                f"{batch_size:>12} | {time_orig:>12.3f} ± {std_orig:.2f} | {time_vec:>12.3f} ± {std_vec:.2f} | {speedup:>9.2f}x"
            )

            results.append(
                {
                    "batch_size": batch_size,
                    "time_orig": time_orig,
                    "time_vec": time_vec,
                    "speedup": speedup,
                }
            )

            # Free the masks before the next (larger) allocation.
            del masks
            gc.collect()
            if device == "cuda":
                torch.cuda.empty_cache()

        # Fixed: the exception variable was bound but never used.
        except RuntimeError:
            # Most likely out of memory; report and stop the sweep.
            print(f"{batch_size:>12} | {'OOM':>15} | {'OOM':>15} | {'N/A':>10}")
            break

    return results


def run_memory_benchmarks(device: str = "cuda"):
    """Run peak-memory benchmarks (GPU only).

    Prints a table comparing peak CUDA memory of the original vs the
    vectorized implementation for several batch sizes. On non-CUDA devices
    it prints a notice and returns immediately.
    """
    if device != "cuda":
        print("\nMemory benchmarks are only available for CUDA devices.")
        return

    print(f"\n{'='*60}")
    print("MEMORY BENCHMARKS (CUDA)")
    print(f"{'='*60}")

    batch_sizes = [1, 16, 64, 256, 1024, 4096, 8192]

    print(
        f"\n{'Batch Size':>12} | {'Original (MB)':>15} | {'Vectorized (MB)':>15} | {'Difference':>12}"
    )
    print("-" * 62)

    for batch_size in batch_sizes:
        try:
            masks = generate_random_masks(batch_size, device=device)

            mem_orig = measure_peak_memory(
                masks_to_boxes_original, masks, device=device
            )

            gc.collect()
            torch.cuda.empty_cache()

            mem_vec = measure_peak_memory(
                masks_to_boxes_vectorized, masks, device=device
            )

            # Fixed: the old truthiness test (`if mem_orig and mem_vec`)
            # wrongly treated a legitimate 0.0 MB reading as "missing".
            if mem_orig is not None and mem_vec is not None:
                diff = mem_vec - mem_orig
            else:
                diff = 0

            print(
                f"{batch_size:>12} | {mem_orig:>15.2f} | {mem_vec:>15.2f} | {diff:>+11.2f}"
            )

            del masks
            gc.collect()
            torch.cuda.empty_cache()

        # Fixed: the exception variable was bound but never used.
        except RuntimeError:
            # Most likely out of memory; report and stop the sweep.
            print(f"{batch_size:>12} | {'OOM':>15} | {'OOM':>15} | {'N/A':>12}")
            break


def main():
    """Entry point: correctness checks, then CPU (and CUDA, if present) benchmarks."""
    banner = "=" * 60
    print(banner)
    print("masks_to_boxes Benchmark: Original vs Vectorized (PR #9358)")
    print(banner)

    # Report the environment up front.
    has_cuda = torch.cuda.is_available()
    print(f"\nPyTorch version: {torch.__version__}")
    print(f"CUDA available: {has_cuda}")
    if has_cuda:
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")

    # Correctness before performance.
    verify_correctness(device="cpu")

    # CPU timing sweep.
    run_benchmarks(device="cpu")

    # GPU correctness, timing, and memory sweeps only when CUDA exists.
    if has_cuda:
        verify_correctness(device="cuda")
        run_benchmarks(device="cuda")
        run_memory_benchmarks(device="cuda")



if __name__ == "__main__":
    main()

@zy1git zy1git merged commit 0f6d91d into pytorch:main Feb 13, 2026
8 of 22 checks passed
@github-actions
Copy link
Copy Markdown

Hey @zy1git!

You merged this PR, but no labels were added.
The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants