Migrate masked_scatter_ CUDA to ATen #50039

kshitij12345 wants to merge 15 commits into pytorch:master
Conversation
Benchmark Code

```python
import torch
import itertools
import time
from torch.utils.benchmark import Timer
from torch.utils.benchmark import Compare
import sys
import pickle

print('Using pytorch %s' % (torch.__version__))

shapes = [(128,), (128, 128), (2, 512, 128), (2, 16, 256, 128), (4, 2, 16, 256, 128)]
results = []
repeats = 10

for m_dtype in [torch.uint8, torch.bool]:
    for dtype in [torch.int32, torch.float, torch.double]:
        for mat1_shape in shapes:
            mat1 = torch.zeros(*mat1_shape, dtype=dtype, device='cuda')
            mask = (torch.randn(mat1.shape, device='cuda') > 0).to(m_dtype)
            source = torch.ones(mask.sum(), dtype=dtype, device='cuda')
            tasks = [("mat1.masked_scatter_(mask, source)", "After PR torch.masked_scatter_ CUDA")]
            timers = [Timer(stmt=stmt, label=f"masked_scatter_ mdtype {m_dtype} input dtype {dtype}",
                            sub_label=f"{mat1_shape}", description=label, globals=globals())
                      for stmt, label in tasks]
            for i, timer in enumerate(timers * repeats):
                results.append(timer.blocked_autorange())
                print(f"\r{i + 1} / {len(timers) * repeats}", end="")
                sys.stdout.flush()

with open('after_pr.pkl', 'wb') as f:
    pickle.dump(results, f)

comparison = Compare(results)
comparison.print()
```

Comparison Script

```python
import torch
import itertools
import time
from torch.utils.benchmark import Timer
from torch.utils.benchmark import Compare
import sys
import pickle

with open('after_pr.pkl', 'rb') as f:
    after_results = pickle.load(f)
with open('before_pr.pkl', 'rb') as f:
    before_results = pickle.load(f)

comparison = Compare(after_results + before_results)
comparison.print()
```
💊 CI failures summary and remediations: as of commit 4852dcf (more details on the Dr. CI page).
Codecov Report

```diff
@@            Coverage Diff             @@
##           master   #50039      +/-   ##
==========================================
- Coverage   80.97%   80.97%   -0.01%
==========================================
  Files        1919     1919
  Lines      209785   209785
==========================================
- Hits       169875   169867       -8
- Misses      39910    39918       +8
```
ngimel left a comment:

Please add benchmarks for bigger sizes; the current timings are mostly independent of size, which means the benchmark is mostly measuring overhead.
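For instance, the `shapes` list in the benchmark script above could be extended with larger tensors so that the copy itself, rather than launch overhead, dominates the measurement (sizes below are illustrative, not from the PR):

```python
# Larger shapes so kernel time dominates dispatch/launch overhead.
shapes += [(32, 2, 16, 256, 128), (128, 2, 16, 256, 128)]
```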
```cpp
template <typename mask_t>
void masked_scatter_cuda_impl(Tensor& self, const Tensor& mask, const Tensor& source) {
  ptrdiff_t srcSize = source.numel();
```
Use `auto srcSize`; `source.numel()` returns `int64_t`, and casting it to `ptrdiff_t` is confusing.
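A minimal sketch of the suggested change:

```cpp
// numel() already returns int64_t; keep that type instead of ptrdiff_t.
auto srcSize = source.numel();
```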
```cpp
  // Determine our output size
  ptrdiff_t totalElements = mask.sum().item<ptrdiff_t>();
  // FIXME: there appears to be a bug in Thrust (CUDA 7.0) for mixed
  // iterator prefix sums? Convert `mask` to the same datatype as what
  // we're accumulating the prefix sum in (int64_t) to get around it
  auto maskLong = mask.to(at::kLong);
```
This bug must be fixed by now; can you try getting rid of this workaround?
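If the Thrust bug is indeed gone, removing the workaround could look roughly like this. This is a sketch only, assuming a Thrust version where the explicit init value fixes the accumulator type; variable names follow the hunks below:

```cpp
// Scan the mask in its original dtype (no int64 copy); the int64_t init
// keeps the running sum from overflowing the narrow uint8/bool element type.
auto maskContig = mask.contiguous();
thrust::device_ptr<mask_t> maskData(maskContig.data_ptr<mask_t>());
thrust::exclusive_scan(
    thrust::cuda::par(allocator).on(c10::cuda::getCurrentCUDAStream()),
    maskData, maskData + mask.numel(),
    maskPrefixSum.data_ptr<int64_t>(),
    static_cast<int64_t>(0));
```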
```cpp
  maskPrefixSumData(maskPrefixSum.data_ptr<int64_t>());
  // ...
  thrust::exclusive_scan(
#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
```
We don't support CUDA_VERSION less than 9, so this `#if` is not necessary.
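With the guard dropped, the call reduces to the unconditional Thrust path; a sketch reusing the names from the surrounding hunks:

```cpp
// No #if needed: CUDA >= 9 (and ROCm) always take the Thrust path.
thrust::exclusive_scan(
    thrust::cuda::par(allocator).on(c10::cuda::getCurrentCUDAStream()),
    maskData, maskData + mask.numel(),
    maskPrefixSumData);
```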
```cpp
  auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
  // ...
  thrust::device_ptr<int64_t>
```
This assumes `maskLong` is contiguous, but now that `to` preserves memory format, that's no longer guaranteed. If you get rid of `maskLong`, there are no guarantees on `mask` at all. Please explicitly call `contiguous()` on `mask` or `maskLong`.
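A minimal sketch of making the contiguity explicit before handing raw storage to Thrust (names assumed from the hunk above):

```cpp
// to() may preserve a non-dense memory format, so force a dense copy first.
auto maskLong = mask.to(at::kLong).contiguous();
thrust::device_ptr<int64_t> maskData(maskLong.data_ptr<int64_t>());
```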
```cpp
  }
  // ...
  // `mask` and `self` must have the same number of elements
  TORCH_CHECK(self.numel() == b_mask.numel(),
              "Number of elements of self != Number of elements in mask");
```
Similar to the CPU PR, this check does not seem to be necessary. Consider unifying the common checks in a follow-up PR.
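One shape the follow-up could take: a single helper holding the shared preconditions, used by both the CPU and CUDA paths. A hedged sketch with a hypothetical helper name:

```cpp
// Hypothetical shared precondition check for masked_scatter_ on all backends.
static void masked_scatter_shape_check(const Tensor& self, const Tensor& mask) {
  TORCH_CHECK(self.numel() == mask.numel(),
              "masked_scatter_: self and mask must have the same number of elements");
}
```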
@ngimel Gentle ping.
```cpp
      maskPrefixSum.data_ptr<int64_t>());
  // ...
  thrust::exclusive_scan(
      thrust::cuda::par(allocator).on(c10::cuda::getCurrentCUDAStream()),
```
Can you please check whether it works for numel > 2**31? There used to be some bugs with large cumsums.
Edit: actually, never mind; a LongTensor with that many elements would be more than 16 GB, so it's not easy to test.
```python
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    for maskType in [torch.uint8, torch.bool]:
        for dt in torch.testing.get_all_dtypes():
```
Please use the `@dtypes` decorator instead of a loop.
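A sketch of what the `@dtypes`-based version might look like; the class and method names are illustrative, and the decorator comes from PyTorch's device-type test harness:

```python
import torch
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_device_type import dtypes

class TestMaskedScatter(TestCase):  # illustrative class name
    @dtypes(*torch.testing.get_all_dtypes())
    def test_masked_scatter(self, device, dtype):
        # the harness instantiates one test per dtype, so the dtype loop goes away
        for maskType in [torch.uint8, torch.bool]:
            ...
```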
```python
self.assertEqual(dest, dest2, atol=0, rtol=0)
# ...
# make source bigger than number of 1s in mask
src = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=dt, device=device)
```
Your src was bigger in the original test case; what's the point of this one?
facebook-github-bot left a comment:

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Fixes pytorch#49542
Pull Request resolved: pytorch#50039
Reviewed By: heitorschueroff
Differential Revision: D26096247
Pulled By: ngimel
fbshipit-source-id: ec1810d3412e0d7ab6b950265a3123519ad886c1
Fixes #49542