
Migrate masked_scatter_ CUDA to ATen #50039

Closed

kshitij12345 wants to merge 15 commits into pytorch:master from kshitij12345:migrate/masked_scatter_cuda

Conversation

@kshitij12345
Collaborator

Fixes #49542

@kshitij12345
Collaborator Author

kshitij12345 commented Jan 4, 2021

Benchmark

[---------------------- masked_scatter_ mdtype torch.uint8 input dtype torch.int32 ----------------------]
                            |  After PR torch.masked_scatter_ CUDA  |  Before PR torch.masked_scatter_ CUDA
1 threads: ------------------------------------------------------------------------------------------------
      (128,)                |                  80.2                 |                  90.6                
      (128, 128)            |                  84.2                 |                  94.4                
      (2, 512, 128)         |                  88.6                 |                  98.5                
      (2, 16, 256, 128)     |                 124.3                 |                 156.1                
      (4, 2, 16, 256, 128)  |                 387.6                 |                 516.8                

Times are in microseconds (us).

[--------------------- masked_scatter_ mdtype torch.uint8 input dtype torch.float32 ---------------------]
                            |  After PR torch.masked_scatter_ CUDA  |  Before PR torch.masked_scatter_ CUDA
1 threads: ------------------------------------------------------------------------------------------------
      (128,)                |                  72.9                 |                  81.0                
      (128, 128)            |                  83.7                 |                  89.9                
      (2, 512, 128)         |                  87.0                 |                  92.5                
      (2, 16, 256, 128)     |                 124.6                 |                 155.4                
      (4, 2, 16, 256, 128)  |                 388.0                 |                 516.2                

Times are in microseconds (us).

[--------------------- masked_scatter_ mdtype torch.uint8 input dtype torch.float64 ---------------------]
                            |  After PR torch.masked_scatter_ CUDA  |  Before PR torch.masked_scatter_ CUDA
1 threads: ------------------------------------------------------------------------------------------------
      (128,)                |                  73.9                 |                  86.7                
      (128, 128)            |                  83.8                 |                  89.7                
      (2, 512, 128)         |                  88.3                 |                  94.6                
      (2, 16, 256, 128)     |                 143.3                 |                 171.0                
      (4, 2, 16, 256, 128)  |                 461.1                 |                 581.5                

Times are in microseconds (us).

[---------------------- masked_scatter_ mdtype torch.bool input dtype torch.int32 -----------------------]
                            |  After PR torch.masked_scatter_ CUDA  |  Before PR torch.masked_scatter_ CUDA
1 threads: ------------------------------------------------------------------------------------------------
      (128,)                |                  49.4                 |                  58.4                
      (128, 128)            |                  58.8                 |                  68.3                
      (2, 512, 128)         |                  64.6                 |                  66.2                
      (2, 16, 256, 128)     |                 128.7                 |                 158.5                
      (4, 2, 16, 256, 128)  |                 395.6                 |                 531.0                

Times are in microseconds (us).

[--------------------- masked_scatter_ mdtype torch.bool input dtype torch.float32 ----------------------]
                            |  After PR torch.masked_scatter_ CUDA  |  Before PR torch.masked_scatter_ CUDA
1 threads: ------------------------------------------------------------------------------------------------
      (128,)                |                  48.6                 |                  58.9                
      (128, 128)            |                  58.3                 |                  69.0                
      (2, 512, 128)         |                  63.8                 |                  66.7                
      (2, 16, 256, 128)     |                 126.0                 |                 157.2                
      (4, 2, 16, 256, 128)  |                 395.2                 |                 531.5                

Times are in microseconds (us).

[--------------------- masked_scatter_ mdtype torch.bool input dtype torch.float64 ----------------------]
                            |  After PR torch.masked_scatter_ CUDA  |  Before PR torch.masked_scatter_ CUDA
1 threads: ------------------------------------------------------------------------------------------------
      (128,)                |                  48.6                 |                  61.9                
      (128, 128)            |                  58.1                 |                  70.7                
      (2, 512, 128)         |                  54.5                 |                  70.9                
      (2, 16, 256, 128)     |                 143.6                 |                 176.3                
      (4, 2, 16, 256, 128)  |                 467.8                 |                 594.5                

Times are in microseconds (us).
Code
import torch
import itertools
import time
from torch.utils.benchmark import Timer
from torch.utils.benchmark import Compare
import sys
import pickle

print('Using pytorch %s' % (torch.__version__))

shapes = [(128,), (128, 128), (2, 512, 128), (2, 16, 256, 128), (4, 2, 16, 256, 128)]
results = []
repeats = 10

for m_dtype in [torch.uint8, torch.bool]:
    for dtype in [torch.int32, torch.float, torch.double]:
        for mat1_shape in shapes:
            mat1 = torch.zeros(*mat1_shape, dtype=dtype, device='cuda')
            mask = (torch.randn(mat1.shape, device='cuda') > 0).to(m_dtype)
            source = torch.ones(mask.sum(), dtype=dtype, device='cuda')

            tasks = [("mat1.masked_scatter_(mask, source)", "After PR torch.masked_scatter_ CUDA")]
            timers = [Timer(stmt=stmt, label=f"masked_scatter_ mdtype {m_dtype} input dtype {dtype}", sub_label=f"{mat1_shape}", description=label, globals=globals()) for stmt, label in tasks]

            for i, timer in enumerate(timers * repeats):
                results.append(
                    timer.blocked_autorange()
                )
                print(f"\r{i + 1} / {len(timers) * repeats}", end="")
                sys.stdout.flush()

with open('after_pr.pkl', 'wb') as f:
    pickle.dump(results, f)

comparison = Compare(results)
comparison.print()

Comparison Script

import torch
import itertools
import time
from torch.utils.benchmark import Timer
from torch.utils.benchmark import Compare
import sys
import pickle

with open('after_pr.pkl', 'rb') as f:
    after_results = pickle.load(f)

with open('before_pr.pkl', 'rb') as f:
    before_results = pickle.load(f)

comparison = Compare(after_results + before_results)
comparison.print()

@facebook-github-bot
Contributor

facebook-github-bot commented Jan 4, 2021

💊 CI failures summary and remediations

As of commit 4852dcf (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.

@codecov

codecov Bot commented Jan 4, 2021

Codecov Report

Merging #50039 (4852dcf) into master (789f6f1) will decrease coverage by 0.00%.
The diff coverage is n/a.

@@            Coverage Diff             @@
##           master   #50039      +/-   ##
==========================================
- Coverage   80.97%   80.97%   -0.01%     
==========================================
  Files        1919     1919              
  Lines      209785   209785              
==========================================
- Hits       169875   169867       -8     
- Misses      39910    39918       +8     

@kshitij12345 kshitij12345 marked this pull request as ready for review January 14, 2021 04:39
@mrshenli mrshenli requested review from anjali411 and ngimel January 15, 2021 03:11
@mrshenli mrshenli added the complex_autograd, module: complex, module: cuda, and triaged labels Jan 15, 2021
Collaborator

@ngimel ngimel left a comment

Please add benchmarks for bigger sizes; the current benchmarking seems to be mostly independent of size, which means it's mostly measuring overheads.


template <typename mask_t>
void masked_scatter_cuda_impl(Tensor& self, const Tensor& mask, const Tensor& source){
ptrdiff_t srcSize = source.numel();
Collaborator

Use auto srcSize; source.numel() is int64_t, and casting it to ptrdiff_t is confusing.

ptrdiff_t srcSize = source.numel();

// Determine our output size
ptrdiff_t totalElements = mask.sum().item<ptrdiff_t>();
Collaborator

Use auto here as well, and item<int64_t>() instead of item<ptrdiff_t>().
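A minimal sketch of both type suggestions, keeping everything in int64_t (variable names follow the quoted hunks):

// Let the compiler deduce the types instead of casting through ptrdiff_t.
auto srcSize = source.numel();                    // already int64_t
auto totalElements = mask.sum().item<int64_t>();  // read the mask count as int64_t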

// FIXME: there appears to be a bug in Thrust (CUDA 7.0) for mixed
// iterator prefix sums? Convert `mask` to the same datatype as what
// we're accumulating the prefix sum in (int64_t) to get around it
auto maskLong = mask.to(at::kLong);
Collaborator

This bug must be fixed by now; can you try getting rid of this workaround?
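A rough sketch of what dropping the workaround could look like, assuming it sits inside masked_scatter_cuda_impl and reuses the allocator and maskPrefixSum from the surrounding code; whether current Thrust handles the mixed-type scan correctly is exactly what would need verifying:

// Scan the bool/uint8 mask directly into the int64_t prefix-sum buffer,
// skipping the intermediate maskLong copy.
auto maskContig = mask.contiguous();
thrust::device_ptr<mask_t> maskData(maskContig.data_ptr<mask_t>());
thrust::device_ptr<int64_t> maskPrefixSumData(maskPrefixSum.data_ptr<int64_t>());
thrust::exclusive_scan(
    thrust::cuda::par(allocator).on(c10::cuda::getCurrentCUDAStream()),
    maskData,
    maskData + maskContig.numel(),
    maskPrefixSumData,
    static_cast<int64_t>(0));  // int64_t init keeps the accumulation in 64 bits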

maskPrefixSumData(maskPrefixSum.data_ptr<int64_t>());

thrust::exclusive_scan(
#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
Collaborator

We don't support CUDA_VERSION less than 9, so this #if is not necessary.
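With the guard removed, the call would simply pass the CUDA execution policy unconditionally; a sketch, with names following the surrounding function:

thrust::exclusive_scan(
    thrust::cuda::par(allocator).on(c10::cuda::getCurrentCUDAStream()),
    maskData,
    maskData + maskLong.numel(),
    maskPrefixSumData);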


auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());

thrust::device_ptr<int64_t>
Collaborator

This assumes maskLong is contiguous, but now that to() preserves memory format that's not guaranteed to be true. If you get rid of maskLong, there are no guarantees on mask at all. Please explicitly call contiguous on mask or maskLong.
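A minimal sketch of the explicit-contiguous fix for the existing code path (the maskLongContig name is illustrative):

// Force a contiguous layout before handing the raw pointer to Thrust,
// since to() now preserves the input's memory format.
auto maskLongContig = maskLong.contiguous();
thrust::device_ptr<int64_t> maskData(maskLongContig.data_ptr<int64_t>());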

}

// `mask` and `self` must have the same number of elements
TORCH_CHECK(self.numel() == b_mask.numel(), "Number of elements of self != Number of elements in mask");
Collaborator

Similar to the CPU PR, this check does not seem to be necessary. Consider unifying the common checks in a follow-up PR.
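If the checks do get unified later, one hypothetical shape for it (helper name and message wording are illustrative, not taken from the codebase):

// Shared validation that both the CPU and CUDA paths could call before
// dispatching, so the numel and mask-dtype checks live in one place.
static void masked_scatter_check(const Tensor& self, const Tensor& mask) {
  TORCH_CHECK(self.numel() == mask.numel(),
              "masked_scatter_: self and mask must have the same number of elements, but got ",
              self.numel(), " and ", mask.numel());
  TORCH_CHECK(mask.scalar_type() == ScalarType::Bool || mask.scalar_type() == ScalarType::Byte,
              "masked_scatter_: expected BoolTensor or ByteTensor for mask");
}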

@kshitij12345
Collaborator Author

@ngimel Gentle Ping

maskPrefixSum.data_ptr<int64_t>());

thrust::exclusive_scan(
thrust::cuda::par(allocator).on(c10::cuda::getCurrentCUDAStream()),
Collaborator

@ngimel ngimel Jan 25, 2021

Can you please check if it works for numel > 2**31? There used to be some bugs for large cumsums.
Edit: actually, never mind; a LongTensor with this many elements will be >16GB, so it's not easy to test.

Collaborator

@ngimel ngimel left a comment

Minor comments.

Comment thread: test/test_torch.py (Outdated)
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    for maskType in [torch.uint8, torch.bool]:
        for dt in torch.testing.get_all_dtypes():
Collaborator

Please use the @dtypes decorator instead of a loop.

Comment thread: test/test_torch.py (Outdated)
self.assertEqual(dest, dest2, atol=0, rtol=0)

# make source bigger than number of 1s in mask
src = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=dt, device=device)
Collaborator

Your src was bigger in the original test case; what's the point of this one?

Contributor

@facebook-github-bot facebook-github-bot left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@ngimel merged this pull request in eaf5ca0.

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
Fixes pytorch#49542

Pull Request resolved: pytorch#50039

Reviewed By: heitorschueroff

Differential Revision: D26096247

Pulled By: ngimel

fbshipit-source-id: ec1810d3412e0d7ab6b950265a3123519ad886c1

Labels

cla signed, complex_autograd, Merged, module: complex, module: cuda, open source, triaged

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Migrate masked_scatter from TH to ATen (CUDA)

5 participants