Migrate masked_scatter_ CPU to ATen #49732
kshitij12345 wants to merge 8 commits into pytorch:master
Conversation
💊 CI failures summary and remediations
As of commit cc003e6 (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patterns
The following CI failures do not appear to be due to upstream breakages.
Benchmark: Before PR vs. After PR

Benchmarking code:

```python
import torch
import itertools
import time
import sys
from torch.utils.benchmark import Timer
from torch.utils.benchmark import Compare

print('Using pytorch %s' % (torch.__version__))

shapes = [(32,), (32, 32), (2, 16, 32), (2, 16, 32, 32), (4, 2, 16, 32, 32)]
results = []
repeats = 10
for m_dtype in [torch.uint8, torch.bool]:
    for dtype in [torch.int32, torch.float, torch.double]:
        for mat1_shape in shapes:
            mat1 = torch.zeros(*mat1_shape, dtype=dtype, device='cpu')
            mask = (torch.randn(mat1.shape) > 0).to(m_dtype)
            source = torch.ones(mask.sum(), dtype=dtype, device='cpu')
            tasks = [("mat1.masked_scatter_(mask, source)", "torch.masked_scatter_ CPU")]
            timers = [Timer(stmt=stmt,
                            label=f"masked_scatter_ mdtype {m_dtype} input dtype {dtype}",
                            sub_label=f"{mat1_shape}",
                            description=label,
                            globals=globals())
                      for stmt, label in tasks]
            for i, timer in enumerate(timers * repeats):
                results.append(timer.blocked_autorange())
                print(f"\r{i + 1} / {len(timers) * repeats}", end="")
                sys.stdout.flush()

comparison = Compare(results)
comparison.print()
```
@VitalyFedyunin Please review
kshitij12345
left a comment
I don't think the CUDA variant is tested.
Lines 1198 to 1239 in 963f762
```python
with self.assertRaises(RuntimeError):
    dest.masked_scatter_(mask, src)
self.assertEqual(len(w), 27)
self.assertEqual(len(w), 20)
```
@ngimel Please review
ngimel
left a comment
This looks good, I left minor comments. For the future, you can serialize benchmark results and then load and compare them, that will make looking at the results easier.
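The serialize-then-compare workflow suggested here can be sketched roughly as follows. This is a hypothetical example, not from the PR: the file name and the toy statement are made up; the point is that the `Measurement` objects returned by `blocked_autorange` can be pickled, reloaded later (e.g. after rebuilding PyTorch), and fed to `Compare` for a single side-by-side table.

```python
import pickle
import torch
from torch.utils.benchmark import Timer, Compare

# Run the benchmark once per build and pickle the Measurement objects.
results = [
    Timer(stmt="x + x",
          globals={"x": torch.ones(32)},
          label="add",
          description="before").blocked_autorange(min_run_time=0.05)
]
with open("before.pkl", "wb") as f:   # "before.pkl" is a made-up name
    pickle.dump(results, f)

# ... later, possibly in a different process / after rebuilding pytorch ...
with open("before.pkl", "rb") as f:
    loaded = pickle.load(f)
Compare(loaded).print()
```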
```cpp
    "please use a mask with dtype torch.bool instead.");
}
```

```cpp
TORCH_CHECK(self.numel() == b_mask.numel(), "Number of elements of self != Number of elements in mask");
```
Why is this check needed? `expand_inplace` should have failed if it could not satisfy this invariant.
Right. This is redundant. Thanks!
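A quick illustration of why the explicit numel check is redundant, as a minimal sketch (not from the PR's test suite): the mask is broadcast against `self` before the kernel runs, so a mask that cannot be broadcast already raises, and one that can is expanded to matching numel automatically.

```python
import torch

# A (4,) bool mask broadcasts against a (4, 4) dest, so after expansion
# mask.numel() == dest.numel() holds automatically.
dest = torch.zeros(4, 4)
mask = torch.tensor([True, False, True, False])  # expanded to (4, 4)
src = torch.ones(8)                              # 8 True positions after expansion
dest.masked_scatter_(mask, src)
print(dest.sum().item())  # 8.0

# A mask that cannot be broadcast fails before any explicit numel check:
try:
    dest.masked_scatter_(torch.tensor([True, False, True]), src)
except RuntimeError:
    print("RuntimeError: shapes not broadcastable")
```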
```cpp
for (int64_t i = 0; i < n; i++) {
  mask_t mask_value = *(mask_t*)(mask + mask_stride * i);
  if (!is_mask_bool) {
    TORCH_CHECK(mask_value == 0 || mask_value == 1, "Mask tensor can take 0 and 1 values only");
```
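For reference, the semantics this loop implements: elements of `source` are consumed in order and written into `self` at the positions where the mask is true. A minimal sketch from the Python side:

```python
import torch

dest = torch.zeros(5)
mask = torch.tensor([True, False, True, False, True])
src = torch.tensor([1., 2., 3.])

# src elements are consumed left-to-right and written where mask is True.
dest.masked_scatter_(mask, src)
print(dest)  # tensor([1., 0., 2., 0., 3.])
```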
```python
with self.assertRaises(RuntimeError):
    dest.masked_scatter_(mask, src)
self.assertEqual(len(w), 27)
self.assertEqual(len(w), 20)
```
Yeah, the CUDA variant is not tested here. Why did the number of warnings change?
I think this happens because the test has a bug.
In the snippet below, src is always of dtype=torch.float, while the dtype of dest varies with the loop.

```python
# make src smaller. this should fail
src = torch.randn(num_copy - 1)
with self.assertRaises(RuntimeError):
    dest.masked_scatter_(mask, src)
```

So except for the case where both dest and src have dtype torch.float32, all other cases fail with an error due to the mismatch in the dtypes of dest and src.
In the new implementation, we first check whether the dtypes of src and dest match, and emit the warning only then. In the old implementation, however, the warning condition was checked first, and the dtypes of src and dest afterwards.
Hence the discrepancy.
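A minimal standalone reproduction of that ordering effect (the shapes here are made up for illustration): when `dest` is not float32, the call already raises for the dtype mismatch, so the undersized-`src` warning path is never reached.

```python
import torch

dest = torch.zeros(6, dtype=torch.int32)
mask = torch.tensor([True, True, False, True, False, False])
src = torch.randn(2)  # float32, and one element short of the 3 True positions

try:
    dest.masked_scatter_(mask, src)
except RuntimeError:
    # Raised for the dtype mismatch before any warning about src's size.
    print("RuntimeError: dtype mismatch")
```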
pytorch/aten/src/ATen/native/TensorAdvancedIndexing.cpp
Lines 1124 to 1141 in cc003e6
Actually, I will fix the bug. Thanks!
Codecov Report

```
@@            Coverage Diff             @@
##           master   #49732      +/-   ##
==========================================
- Coverage   80.65%   80.64%   -0.01%
==========================================
  Files        1913     1913
  Lines      208121   208017     -104
==========================================
- Hits       167859   167764      -95
+ Misses      40262    40253       -9
```
@ngimel PTAL :)
@ngimel Gentle ping
facebook-github-bot
left a comment
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
Fixes pytorch#49541
Reference: pytorch#24507
Pull Request resolved: pytorch#49732
Reviewed By: ejguan
Differential Revision: D25991438
Pulled By: ngimel
fbshipit-source-id: a43bd0bfe043d8e32a6cadbbf736a0eaa697e7ec
Fixes #49541
Reference: #24507