Skip to content

[MPS] Improve runtime complexity of roi_align#9100

Merged
malfet merged 4 commits intopytorch:mainfrom
hvaara:fix-roi-align-mps
Feb 2, 2026
Merged

[MPS] Improve runtime complexity of roi_align#9100
malfet merged 4 commits intopytorch:mainfrom
hvaara:fix-roi-align-mps

Conversation

@hvaara
Copy link
Copy Markdown
Contributor

@hvaara hvaara commented Jun 8, 2025

roi_align on MPS has significantly inflated runtime complexity due to a bug in the looping behavior of the kernel. I've not found any other correctness issues with the current implementation, which closely follows the CUDA implementation. This PR fixes the runtime complexity, otherwise the kernel is semantically identical to before.

Note that this PR switches the dispatching to dispatchThreads, which has a tighter build target set than dispatchThreadgroups. Ref Nonuniform threadgroup size in Metal feature set tables.

Some other MPS kernels in vision is also likely affected.

Running the example code from pytorch/pytorch#124850 (comment) before:

--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls                                                                      Input Shapes
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------
                 model_inference         0.02%       6.412ms       100.00%       41.913s       41.913s             1                                                                                []
                     aten::where         0.00%       4.373us        80.19%       33.611s        8.403s             4                                                                          [[1000]]
             aten::nonzero_numpy         0.00%      15.335us        80.19%       33.611s        8.403s             4                                                                          [[1000]]
                   aten::nonzero        80.18%       33.605s        80.19%       33.611s        8.403s             4                                                                          [[1000]]
                     aten::where         0.00%       7.375us         2.55%        1.067s     533.698ms             2                                                                          [[4507]]
             aten::nonzero_numpy         0.00%      11.042us         2.55%        1.067s     533.695ms             2                                                                          [[4507]]
                   aten::nonzero         2.31%     969.133ms         2.55%        1.067s     533.679ms             2                                                                          [[4507]]
                      aten::topk         2.53%        1.062s         2.53%        1.062s        1.062s             1                                                     [[1, 120000], [], [], [], []]
                torchvision::nms         0.00%      52.208us         2.39%        1.004s        1.004s             1                                                               [[21, 4], [21], []]
                      aten::sort         2.39%     999.630ms         2.39%     999.635ms     999.635ms             1                                                                [[21], [], [], []]
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------
Self CPU time total: 41.913s

and after

--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls                                                                      Input Shapes
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------
                 model_inference         0.88%       4.364ms       100.00%     493.862ms     493.862ms             1                                                                                []
                torchvision::nms        15.95%      78.782ms        17.20%      84.925ms      84.925ms             1                                                           [[4507, 4], [4507], []]
                     aten::where         0.00%       2.957us        11.38%      56.185ms      14.046ms             4                                                                          [[1000]]
             aten::nonzero_numpy         0.00%       7.379us        11.38%      56.182ms      14.045ms             4                                                                          [[1000]]
                   aten::nonzero        10.26%      50.684ms        11.37%      56.146ms      14.036ms             4                                                                          [[1000]]
                    aten::conv2d         0.00%       5.417us         6.39%      31.548ms      31.548ms             1                             [[1, 3, 800, 800], [64, 3, 7, 7], [], [], [], [], []]
               aten::convolution         0.00%       9.041us         6.39%      31.543ms      31.543ms             1                     [[1, 3, 800, 800], [64, 3, 7, 7], [], [], [], [], [], [], []]
              aten::_convolution         0.00%      12.542us         6.39%      31.534ms      31.534ms             1     [[1, 3, 800, 800], [64, 3, 7, 7], [], [], [], [], [], [], [], [], [], [], []]
          aten::_mps_convolution         6.38%      31.520ms         6.38%      31.521ms      31.521ms             1                             [[1, 3, 800, 800], [64, 3, 7, 7], [], [], [], [], []]
          torchvision::roi_align         5.88%      29.036ms         5.88%      29.047ms      29.047ms             1                                [[1, 256, 200, 200], [960, 5], [], [], [], [], []]
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------
Self CPU time total: 493.862ms

One concern I have with the approach I'm proposing here is numeric overflow of the index with large input sizes.

Fixes pytorch/pytorch#124850

cc @malfet @kulinseth @qqaatw

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Jun 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/9100

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 2 Pending

As of commit f72587e with merge base b32ce3d (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@hvaara
Copy link
Copy Markdown
Contributor Author

hvaara commented Jun 16, 2025

@Isalia20 @qqaatw Do you have time to review this?

Copy link
Copy Markdown
Contributor

@Isalia20 Isalia20 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a small script which will measure the time difference directly between the old roi pool and the new one? The one in the main thread is a bit confusing to me since the first section has no roi_pool and the 2nd one does.

Also about:
"One concern I have with the approach I'm proposing here is numeric overflow of the index with large input sizes."

Have you tested it out on larger input sizes and tested against CPU that this implementation produces equivalent results?

uint2 tid2 [[thread_position_in_threadgroup]]);

template<typename T, typename integer_t>
template <typename T, typename integer_t>
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need templating here for integer_t? From what I see it just registers two of such op:
REGISTER_ROI_ALIGN_OP(float, int64_t);
REGISTER_ROI_ALIGN_OP(half, int64_t);

Both of which are int64_t so maybe we can remove it? I know it wasn't added in this PR but would be a nice thing to add to it

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM. Fixed.

@hvaara
Copy link
Copy Markdown
Contributor Author

hvaara commented Jun 16, 2025

Thanks a lot for the review!

Can you add a small script which will measure the time difference directly between the old roi pool and the new one? The one in the main thread is a bit confusing to me since the first section has no roi_pool and the 2nd one does.

I agree that the perf outputs from the first comment is a bit confusing. The culprit looks like it's nonzero, but that's a quirk from the profiler. The time is actually spent in roi_align, and the total execution time is 41.913s. In the second output you can see that the timings has improved significantly and the total execution time is 493.862ms.

I added a regression test test_performance_mps, that checks the execution time against a threshold of 1000 ms. You can run this unit test on this branch and main to see the difference in execution time, just set execution_time_ms_threshold = 0 and you'll get the timings on this branch too.

Have you tested it out on larger input sizes and tested against CPU that this implementation produces equivalent results?

output_size is defined as

int64_t output_size = num_rois * pooled_height * pooled_width * channels;

I've tested it with values generating output_size up to a size of 2^31 and it outputs the same results on CPU and MPS (tested with torch.testing.assert_close).

Above 2^31 I get a crash on CPU with the error Fatal Python error: Segmentation fault. If I try to print the whole tensor on MPS I get

/AppleInternal/Library/BuildRoots/1c8f7852-1ca9-11f0-b28b-226177e5bb69/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:829: failed assertion `[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: NDArray dimension length > INT_MAX'
Fatal Python error: Aborted

Indexing into the tensor I get valid output eg. for out[0][0]:

tensor([[73.9383, 61.6012, 74.1146, 72.1870, 71.5774, 81.3736, 68.8621],
        [73.6598, 65.7005, 76.3044, 68.5069, 72.9770, 75.2113, 74.2729],
        [68.8734, 75.3870, 69.6267, 79.9169, 74.0059, 81.7421, 79.3910],
        [73.7394, 72.1691, 64.8541, 68.3909, 78.4569, 75.4807, 76.2083],
        [82.0290, 70.3133, 69.1630, 70.7505, 80.5654, 65.7685, 79.4339],
        [70.2205, 76.4919, 68.9302, 66.3778, 74.3694, 77.7530, 66.5249],
        [88.4454, 65.3945, 83.0347, 66.1287, 63.6279, 66.8136, 84.1742]],
       device='mps:0')

but I don't trust the results to be numerically correct - especially considering index likely overflows here. And indexing with out[-1][-1] will again yield

/AppleInternal/Library/BuildRoots/1c8f7852-1ca9-11f0-b28b-226177e5bb69/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:829: failed assertion `[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: NDArray dimension length > INT_MAX'
Fatal Python error: Aborted

These errors can be triggered by setting

num_rois = 171196 # < 2**31 -> good
num_rois = 171206 # > 2**31 -> errors
rois = self._make_rois(img_size, num_imgs, dtype, num_rois=num_rois)

Should we add a check on output_size against INT_MAX for MPS? We should probably add a check in CPU as well to prevent a crash, but I consider it out of scope for this PR.

cc @Isalia20

@hvaara
Copy link
Copy Markdown
Contributor Author

hvaara commented Dec 10, 2025

Someone reached out to me on LinkedIn regarding this bugfix. I should fix the merge conflicts and try to get it merged again

@Zemke
Copy link
Copy Markdown

Zemke commented Jan 31, 2026

I noticed the merge conflicts were mainly due to whitespace changes in your commits. I have removed the whitespace changes and were then able to merge the main branch easily. I made a pull request to your branch and I hope the solution for this ticket can be continued from there. hvaara#1

@hvaara hvaara force-pushed the fix-roi-align-mps branch from e70ceda to 83361f4 Compare February 1, 2026 17:07
@hvaara
Copy link
Copy Markdown
Contributor Author

hvaara commented Feb 1, 2026

I rebased to main. Thanks for the nudge @Zemke.

@malfet @NicolasHug can we merge this? The issue with output_size raised in #9100 (comment) is present for CPU as well so I'm inclined to consider it out of scope for this PR, though happy to make changes in a follow-up if necessary.

@malfet
Copy link
Copy Markdown
Contributor

malfet commented Feb 1, 2026

@hvaara if CI is green, then LGTM

@hvaara
Copy link
Copy Markdown
Contributor Author

hvaara commented Feb 1, 2026

@malfet I fixed the lint errors, but that seems to have aborted the tests. Do you mind clicking the button again? Thanks!

@malfet malfet merged commit 75c6fbe into pytorch:main Feb 2, 2026
62 checks passed
@github-actions
Copy link
Copy Markdown

github-actions bot commented Feb 2, 2026

Hey @malfet!

You merged this PR, but no labels were added.
The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

@hvaara hvaara deleted the fix-roi-align-mps branch February 8, 2026 21:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

aten::nonzero calls taking a huge amount of time when using MPS backend vs CPU

5 participants