Vectorize masks_to_boxes for performance#9358
Merged
zy1git merged 5 commits into pytorch:main from Feb 13, 2026
Merged
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/9358
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
NicolasHug
approved these changes
Feb 11, 2026
Contributor
Benchmark ResultsPERFORMANCE BENCHMARKS (CPU)
PERFORMANCE BENCHMARKS (CUDA)
MEMORY BENCHMARKS (CUDA)
The vectorized implementation is faster, especially for large batch sizes. It is acceptable that the memory usage is slightly higher. Note: To have a consistently reliable CPU performance, please use this command to run the benchmark script: Benchmark Script"""
Benchmark script for masks_to_boxes optimization.
Compares the original sequential implementation vs the new vectorized implementation
from PR #9358.
"""
import gc
import time
import numpy as np
import torch
def masks_to_boxes_original(masks: torch.Tensor) -> torch.Tensor:
    """Original sequential implementation.

    For each (H, W) mask, returns an (x1, y1, x2, y2) box in pixel
    coordinates; all-zero masks yield an all-zero box.
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
    num_masks = masks.shape[0]
    boxes = torch.zeros((num_masks, 4), device=masks.device, dtype=torch.float)
    for i in range(num_masks):
        ys, xs = torch.where(masks[i] != 0)
        if xs.numel() == 0:  # all-zero mask: leave the box at zeros
            continue
        boxes[i, 0] = xs.min()
        boxes[i, 1] = ys.min()
        boxes[i, 2] = xs.max()
        boxes[i, 3] = ys.max()
    return boxes
def masks_to_boxes_vectorized(masks: torch.Tensor) -> torch.Tensor:
    """New vectorized implementation from PR #9358.

    Computes all boxes at once by reducing each mask to per-row and
    per-column occupancy vectors; empty masks get an all-zero box.
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
    _, height, width = masks.shape
    as_bool = masks.bool()
    row_hits = torch.any(as_bool, dim=2)  # (n, h): row contains a set pixel
    col_hits = torch.any(as_bool, dim=1)  # (n, w): column contains a set pixel
    is_empty = ~row_hits.any(dim=1)
    row_f = row_hits.float()
    col_f = col_hits.float()
    # argmax returns the FIRST maximal index, so it locates the first hit;
    # flipping and re-indexing locates the last hit.
    top = row_f.argmax(dim=1)
    left = col_f.argmax(dim=1)
    bottom = (height - 1) - row_f.flip(dims=[1]).argmax(dim=1)
    right = (width - 1) - col_f.flip(dims=[1]).argmax(dim=1)
    boxes = torch.stack([left, top, right, bottom], dim=1).float()
    boxes[is_empty] = 0
    return boxes
def generate_random_masks(
    batch_size: int,
    height: int = 64,
    width: int = 64,
    density: float = 0.3,
    device: str = "cpu",
) -> torch.Tensor:
    """Generate random binary masks with approximately `density` fraction of pixels set."""
    # Threshold uniform noise: each pixel is set with probability `density`.
    noise = torch.rand(batch_size, height, width, device=device)
    binary = noise < density
    return binary.float()
def benchmark_function(
    func, masks, num_warmup: int = 10, num_runs: int = 50, device: str = "cpu"
):
    """Benchmark a function with warmup runs.

    Returns (mean_ms, std_ms) over `num_runs` timed invocations.
    """
    # Collect garbage once up front so it cannot fire mid-measurement.
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
    # Warmup passes stabilize CPU frequency scaling and warm the caches.
    for _ in range(num_warmup):
        func(masks)
    if device == "cuda":
        torch.cuda.synchronize()
    # Timed passes; deliberately no gc.collect() inside this loop.
    samples = []
    for _ in range(num_runs):
        if device == "cuda":
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        func(masks)
        if device == "cuda":
            torch.cuda.synchronize()
        samples.append(time.perf_counter() - t0)
    # Report mean and standard deviation, converted to milliseconds.
    return np.mean(samples) * 1000, np.std(samples) * 1000
def measure_peak_memory(func, masks, device: str = "cpu"):
    """Measure peak memory usage during function execution.

    Returns peak CUDA memory in MB, or None on CPU (not measured there).
    """
    gc.collect()
    if device != "cuda":
        # CPU peak-memory tracking is nontrivial; just run the function.
        func(masks)
        return None
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    func(masks)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / (1024 * 1024)  # bytes -> MB
def verify_correctness(batch_size: int = 100, device: str = "cpu"):
    """Verify that both implementations produce the same results.

    Checks random masks, fully-empty masks, and a mix of empty and
    non-empty masks, printing a PASSED/FAILED line for each case.
    """
    print(f"\n{'='*60}")
    print("CORRECTNESS VERIFICATION")
    print(f"{'='*60}")
    # Test with random masks
    masks = generate_random_masks(batch_size, device=device)
    result_orig = masks_to_boxes_original(masks)
    result_vec = masks_to_boxes_vectorized(masks)
    if torch.allclose(result_orig, result_vec):
        print(f"✓ Random masks (batch={batch_size}): PASSED")
    else:
        print(f"✗ Random masks (batch={batch_size}): FAILED")
        print(f" Max difference: {(result_orig - result_vec).abs().max().item()}")
    # Test with empty masks (all-zero boxes expected from both)
    empty_masks = torch.zeros((5, 64, 64), device=device)
    result_orig_empty = masks_to_boxes_original(empty_masks)
    result_vec_empty = masks_to_boxes_vectorized(empty_masks)
    # NOTE: f-prefixes removed below — these strings have no placeholders (F541).
    if torch.allclose(result_orig_empty, result_vec_empty):
        print("✓ Empty masks: PASSED")
    else:
        print("✗ Empty masks: FAILED")
    # Test with mixed empty and non-empty masks in one batch
    mixed_masks = torch.zeros((3, 10, 10), device=device)
    mixed_masks[1, 2:5, 3:7] = 1
    result_orig_mixed = masks_to_boxes_original(mixed_masks)
    result_vec_mixed = masks_to_boxes_vectorized(mixed_masks)
    if torch.allclose(result_orig_mixed, result_vec_mixed):
        print("✓ Mixed empty/non-empty masks: PASSED")
    else:
        print("✗ Mixed empty/non-empty masks: FAILED")
def run_benchmarks(device: str = "cpu"):
    """Run benchmarks for various batch sizes.

    Prints a comparison table of mean runtimes (original vs vectorized)
    and returns a list of per-batch-size result dicts.
    """
    print(f"\n{'='*60}")
    print(f"PERFORMANCE BENCHMARKS ({device.upper()})")
    print(f"{'='*60}")
    # Batch sizes to test (powers of 2)
    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
    if device == "cuda":
        # Add larger batch sizes for GPU
        batch_sizes.extend([8192, 16384])
    print(
        f"\n{'Batch Size':>12} | {'Original (ms)':>15} | {'Vectorized (ms)':>15} | {'Speedup':>10}"
    )
    print("-" * 62)
    results = []
    for batch_size in batch_sizes:
        try:
            masks = generate_random_masks(batch_size, device=device)
            time_orig, std_orig = benchmark_function(
                masks_to_boxes_original, masks, device=device
            )
            time_vec, std_vec = benchmark_function(
                masks_to_boxes_vectorized, masks, device=device
            )
            # Guard against a zero denominator on extremely fast runs.
            speedup = time_orig / time_vec if time_vec > 0 else float("inf")
            print(
                f"{batch_size:>12} | {time_orig:>12.3f} ± {std_orig:.2f} | {time_vec:>12.3f} ± {std_vec:.2f} | {speedup:>9.2f}x"
            )
            results.append(
                {
                    "batch_size": batch_size,
                    "time_orig": time_orig,
                    "time_vec": time_vec,
                    "speedup": speedup,
                }
            )
            # Free the masks before the next (larger) allocation.
            del masks
            gc.collect()
            if device == "cuda":
                torch.cuda.empty_cache()
        except RuntimeError as e:
            # Most likely CUDA OOM; surface the error instead of discarding
            # it (the bound exception was previously unused), then stop.
            print(f"{batch_size:>12} | {'OOM':>15} | {'OOM':>15} | {'N/A':>10}")
            print(f"  ({e})")
            break
    return results
def run_memory_benchmarks(device: str = "cuda"):
    """Run memory benchmarks (GPU only).

    Prints peak CUDA memory for both implementations across batch sizes;
    no-op with a notice when `device` is not "cuda".
    """
    if device != "cuda":
        print("\nMemory benchmarks are only available for CUDA devices.")
        return
    print(f"\n{'='*60}")
    print("MEMORY BENCHMARKS (CUDA)")
    print(f"{'='*60}")
    batch_sizes = [1, 16, 64, 256, 1024, 4096, 8192]
    print(
        f"\n{'Batch Size':>12} | {'Original (MB)':>15} | {'Vectorized (MB)':>15} | {'Difference':>12}"
    )
    print("-" * 62)
    for batch_size in batch_sizes:
        try:
            masks = generate_random_masks(batch_size, device=device)
            mem_orig = measure_peak_memory(
                masks_to_boxes_original, masks, device=device
            )
            # Reset allocator state so the second measurement is independent.
            gc.collect()
            torch.cuda.empty_cache()
            mem_vec = measure_peak_memory(
                masks_to_boxes_vectorized, masks, device=device
            )
            diff = mem_vec - mem_orig if mem_orig and mem_vec else 0
            print(
                f"{batch_size:>12} | {mem_orig:>15.2f} | {mem_vec:>15.2f} | {diff:>+11.2f}"
            )
            del masks
            gc.collect()
            torch.cuda.empty_cache()
        except RuntimeError as e:
            # Most likely CUDA OOM; surface the error instead of discarding
            # it (the bound exception was previously unused), then stop.
            print(f"{batch_size:>12} | {'OOM':>15} | {'OOM':>15} | {'N/A':>12}")
            print(f"  ({e})")
            break
def main():
    """Entry point: report environment, verify correctness, then benchmark."""
    separator = "=" * 60
    print(separator)
    print("masks_to_boxes Benchmark: Original vs Vectorized (PR #9358)")
    print(separator)
    # Report the environment so printed results can be interpreted later.
    has_cuda = torch.cuda.is_available()
    print(f"\nPyTorch version: {torch.__version__}")
    print(f"CUDA available: {has_cuda}")
    if has_cuda:
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    # Correctness must hold before any timing numbers mean anything.
    verify_correctness(device="cpu")
    run_benchmarks(device="cpu")
    # Repeat the full suite on the GPU when one is present.
    if has_cuda:
        verify_correctness(device="cuda")
        run_benchmarks(device="cuda")
        run_memory_benchmarks(device="cuda")
# Scrape artifact removed: a stray " |" after main() made this line a
# syntax error.
if __name__ == "__main__":
    main()
|
Hey @zy1git! You merged this PR, but no labels were added. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Vectorizes the `masks_to_boxes` implementation by removing the Python loop. Depends on #9357.
Split from #9347 as requested.