
Migrate masked_scatter_ CPU to ATen #49732

Closed

kshitij12345 wants to merge 8 commits into pytorch:master from
kshitij12345:migrate/masked_scatter_cpu

Conversation

@kshitij12345
Collaborator

Fixes #49541

Reference: #24507
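For context, a minimal sketch of the operator being migrated: `masked_scatter_(mask, source)` overwrites the elements of the tensor at positions where `mask` is true with consecutive elements of `source`, in order.

```python
import torch

# Positions where mask is True are overwritten, in order,
# by consecutive elements of source.
dest = torch.tensor([1, 2, 3, 4, 5])
mask = torch.tensor([False, True, False, True, True])
source = torch.tensor([10, 20, 30])

dest.masked_scatter_(mask, source)
print(dest)  # tensor([ 1, 10,  3, 20, 30])
```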

@facebook-github-bot
Contributor

facebook-github-bot commented Dec 22, 2020

💊 CI failures summary and remediations

As of commit cc003e6 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_bionic_py3_8_gcc9_coverage_test1 (1/1)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Dec 22 16:47:39 [E request_callback_no_python.cpp:636] Received error while processing request type 258: RuntimeError: Can not pickle torch.futures.Future
Dec 22 16:47:39 At:
Dec 22 16:47:39   /opt/conda/lib/python3.8/site-packages/torch/distributed/rpc/internal.py(120): serialize
Dec 22 16:47:39   /opt/conda/lib/python3.8/site-packages/torch/distributed/rpc/internal.py(172): serialize
Dec 22 16:47:39 
Dec 22 16:47:39 [E request_callback_no_python.cpp:636] Received error while processing request type 258: RuntimeError: Can not pickle torch.futures.Future
Dec 22 16:47:39 
Dec 22 16:47:39 At:
Dec 22 16:47:39   /opt/conda/lib/python3.8/site-packages/torch/distributed/rpc/internal.py(120): serialize
Dec 22 16:47:39   /opt/conda/lib/python3.8/site-packages/torch/distributed/rpc/internal.py(172): serialize
Dec 22 16:47:39 
Dec 22 16:47:39 [E request_callback_no_python.cpp:636] Received error while processing request type 258: RuntimeError: Can not pickle torch.futures.Future
Dec 22 16:47:39 
Dec 22 16:47:39 At:
Dec 22 16:47:39   /opt/conda/lib/python3.8/site-packages/torch/distributed/rpc/internal.py(120): serialize
Dec 22 16:47:39   /opt/conda/lib/python3.8/site-packages/torch/distributed/rpc/internal.py(172): serialize
Dec 22 16:47:39 
Dec 22 16:47:39 [W tensorpipe_agent.cpp:547] RPC agent for worker3 encountered error when reading incoming request from worker2: EOF: end of file (this is expected to happen during shutdown)
Dec 22 16:47:39 [W tensorpipe_agent.cpp:547] RPC agent for worker1 encountered error when reading incoming request from worker0: EOF: end of file (this is expected to happen during shutdown)
Dec 22 16:47:40 ok (2.445s)
Dec 22 16:47:42   test_return_future_remote (__main__.TensorPipeRpcTestWithSpawn) ... [W tensorpipe_agent.cpp:547] RPC agent for worker2 encountered error when reading incoming request from worker0: EOF: end of file (this is expected to happen during shutdown)
Dec 22 16:47:42 [W tensorpipe_agent.cpp:547] RPC agent for worker1 encountered error when reading incoming request from worker0: EOF: end of file (this is expected to happen during shutdown)

1 job timed out:

  • pytorch_linux_bionic_py3_8_gcc9_coverage_test1

This comment was automatically generated by Dr. CI (expand for details). Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

This comment has been revised 9 times.

@kshitij12345
Collaborator Author

Benchmark

Before PR

[ masked_scatter_ mdtype torch.uint8 input dtype torch.int32 ]
                          |  torch.masked_scatter_ CPU
1 threads: -------------------------------------------
      (32,)               |             14.5          
      (32, 32)            |             17.3          
      (2, 16, 32)         |             17.0          
      (2, 16, 32, 32)     |            119.5          
      (4, 2, 16, 32, 32)  |            521.5          

Times are in microseconds (us).

[ masked_scatter_ mdtype torch.uint8 input dtype torch.float32 ]
                          |  torch.masked_scatter_ CPU
1 threads: -------------------------------------------
      (32,)               |             14.4          
      (32, 32)            |             16.9          
      (2, 16, 32)         |             16.9          
      (2, 16, 32, 32)     |            121.2          
      (4, 2, 16, 32, 32)  |            543.1          

Times are in microseconds (us).

[ masked_scatter_ mdtype torch.uint8 input dtype torch.float64 ]
                          |  torch.masked_scatter_ CPU
1 threads: -------------------------------------------
      (32,)               |             15.9          
      (32, 32)            |             20.5          
      (2, 16, 32)         |             20.7          
      (2, 16, 32, 32)     |            125.9          
      (4, 2, 16, 32, 32)  |            498.8          

Times are in microseconds (us).

[ masked_scatter_ mdtype torch.bool input dtype torch.int32 ]
                          |  torch.masked_scatter_ CPU
1 threads: -------------------------------------------
      (32,)               |              881.3        
      (32, 32)            |             2982.5        
      (2, 16, 32)         |             3074.9        
      (2, 16, 32, 32)     |            99410.4        
      (4, 2, 16, 32, 32)  |           461798.1        

Times are in nanoseconds (ns).

[ masked_scatter_ mdtype torch.bool input dtype torch.float32 ]
                          |  torch.masked_scatter_ CPU
1 threads: -------------------------------------------
      (32,)               |              873.4        
      (32, 32)            |             2748.0        
      (2, 16, 32)         |             3133.0        
      (2, 16, 32, 32)     |           100659.8        
      (4, 2, 16, 32, 32)  |           471495.6        

Times are in nanoseconds (ns).

[ masked_scatter_ mdtype torch.bool input dtype torch.float64 ]
                          |  torch.masked_scatter_ CPU
1 threads: -------------------------------------------
      (32,)               |              876.7        
      (32, 32)            |             3169.1        
      (2, 16, 32)         |             3235.5        
      (2, 16, 32, 32)     |            99530.6        
      (4, 2, 16, 32, 32)  |           461690.9        

Times are in nanoseconds (ns).

After PR

[ masked_scatter_ mdtype torch.uint8 input dtype torch.int32 ]
                          |  torch.masked_scatter_ CPU
1 threads: -------------------------------------------
      (32,)               |             14.6          
      (32, 32)            |             17.0          
      (2, 16, 32)         |             16.9          
      (2, 16, 32, 32)     |            135.8          
      (4, 2, 16, 32, 32)  |            523.4          

Times are in microseconds (us).

[ masked_scatter_ mdtype torch.uint8 input dtype torch.float32 ]
                          |  torch.masked_scatter_ CPU
1 threads: -------------------------------------------
      (32,)               |             14.1          
      (32, 32)            |             16.2          
      (2, 16, 32)         |             16.4          
      (2, 16, 32, 32)     |            133.4          
      (4, 2, 16, 32, 32)  |            493.5          

Times are in microseconds (us).

[ masked_scatter_ mdtype torch.uint8 input dtype torch.float64 ]
                          |  torch.masked_scatter_ CPU
1 threads: -------------------------------------------
      (32,)               |             12.8          
      (32, 32)            |             15.3          
      (2, 16, 32)         |             16.6          
      (2, 16, 32, 32)     |            134.2          
      (4, 2, 16, 32, 32)  |            515.6          

Times are in microseconds (us).

[ masked_scatter_ mdtype torch.bool input dtype torch.int32 ]
                          |  torch.masked_scatter_ CPU
1 threads: -------------------------------------------
      (32,)               |              808.3        
      (32, 32)            |             2136.1        
      (2, 16, 32)         |             2017.6        
      (2, 16, 32, 32)     |            98196.0        
      (4, 2, 16, 32, 32)  |           445502.8        

Times are in nanoseconds (ns).

[ masked_scatter_ mdtype torch.bool input dtype torch.float32 ]
                          |  torch.masked_scatter_ CPU
1 threads: -------------------------------------------
      (32,)               |              796.5        
      (32, 32)            |             3719.5        
      (2, 16, 32)         |             3760.2        
      (2, 16, 32, 32)     |           116965.8        
      (4, 2, 16, 32, 32)  |           461250.2        

Times are in nanoseconds (ns).

[ masked_scatter_ mdtype torch.bool input dtype torch.float64 ]
                          |  torch.masked_scatter_ CPU
1 threads: -------------------------------------------
      (32,)               |              792.5        
      (32, 32)            |             2102.3        
      (2, 16, 32)         |             2096.0        
      (2, 16, 32, 32)     |            97564.0        
      (4, 2, 16, 32, 32)  |           435919.5        

Times are in nanoseconds (ns).

Code

Benchmarking Code
import torch
import sys
from torch.utils.benchmark import Timer
from torch.utils.benchmark import Compare

print('Using pytorch %s' % (torch.__version__))

shapes = [(32,), (32, 32), (2, 16, 32), (2, 16, 32, 32), (4, 2, 16, 32, 32)]
results = []
repeats = 10

for m_dtype in [torch.uint8, torch.bool]:
    for dtype in [torch.int32, torch.float, torch.double]:
        for mat1_shape in shapes:
            mat1 = torch.zeros(*mat1_shape, dtype=dtype, device='cpu')
            mask = (torch.randn(mat1.shape) > 0).to(m_dtype)
            source = torch.ones(mask.sum(), dtype=dtype, device='cpu')

            tasks = [("mat1.masked_scatter_(mask, source)", "torch.masked_scatter_ CPU")]
            timers = [Timer(stmt=stmt, label=f"masked_scatter_ mdtype {m_dtype} input dtype {dtype}", sub_label=f"{mat1_shape}", description=label, globals=globals()) for stmt, label in tasks]

            for i, timer in enumerate(timers * repeats):
                results.append(
                    timer.blocked_autorange()
                )
                print(f"\r{i + 1} / {len(timers) * repeats}", end="")
                sys.stdout.flush()
comparison = Compare(results)
comparison.print()

@kshitij12345 kshitij12345 marked this pull request as ready for review December 23, 2020 04:40
@kshitij12345
Collaborator Author

@VitalyFedyunin Please review

@kshitij12345 kshitij12345 changed the title [WIP] Migrate masked_scatter_ CPU to ATen Migrate masked_scatter_ CPU to ATen Dec 23, 2020
Collaborator Author

@kshitij12345 kshitij12345 left a comment


I don't think the CUDA variant is tested.

pytorch/test/test_torch.py

Lines 1198 to 1239 in 963f762

def test_masked_scatter(self):
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        for maskType in [torch.uint8, torch.bool]:
            for dt in torch.testing.get_all_dtypes():
                num_copy, num_dest = 3, 10
                dest = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dt)
                dest2 = dest.clone()
                src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dt)
                mask = torch.tensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0), dtype=maskType)
                if dt == torch.bool:
                    # torch.bool is a special case and is being tested
                    # in a separate test
                    continue
                # TODO: update test when masked scatter is supported for complex
                if dt == torch.half or dt.is_complex:
                    self.assertRaises(RuntimeError, lambda: dest.masked_scatter_(mask, src))
                    continue
                dest.masked_scatter_(mask, src)
                j = 0
                for i in range(num_dest):
                    if mask[i]:
                        dest2[i] = src[j]
                        j += 1
                self.assertEqual(dest, dest2, atol=0, rtol=0)
                # make source bigger than number of 1s in mask
                src = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=dt)
                dest.masked_scatter_(mask, src)
                # make src smaller. this should fail
                src = torch.randn(num_copy - 1)
                with self.assertRaises(RuntimeError):
                    dest.masked_scatter_(mask, src)
        self.assertEqual(len(w), 27)
        warn = 'masked_scatter_ received a mask with dtype torch.uint8,'
        for wi in w:
            self.assertEqual(str(wi.message)[0:55], str(warn))

Comment thread: test/test_torch.py (Outdated)

     with self.assertRaises(RuntimeError):
         dest.masked_scatter_(mask, src)
-    self.assertEqual(len(w), 27)
+    self.assertEqual(len(w), 20)
Collaborator Author

Other tests for masked_scatter,

pytorch/test/test_torch.py

Lines 4590 to 4600 in 963f762

def test_masked_scatter_bool_tensor(self, device):
    src = torch.tensor([True, True, True], device=device)
    dst = torch.tensor([False, False, False], device=device)
    mask = torch.tensor([False, True, False], device=device)
    dst.masked_scatter_(mask, src)
    self.assertEqual(dst, torch.tensor([False, True, False], device=device))
    mask = torch.tensor([True, False, True], device=device)
    dst = dst.masked_scatter(mask, src)
    self.assertEqual(dst, torch.tensor([True, True, True], device=device))

pytorch/test/test_torch.py

Lines 5129 to 5144 in 963f762

@onlyOnCPUAndCUDA
def test_masked_scatter_mem_overlap(self, device):
    x = torch.rand((1,), device=device).expand((6,))
    src = torch.rand((3,), device=device)
    mask = torch.tensor([True, False, True, True, False, False], device=device)
    with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
        x.masked_scatter_(mask, src)

@onlyOnCPUAndCUDA
def test_index_select_mem_overlap(self, device):
    x = torch.rand((1, 6), device=device).expand((2, 6))
    y = torch.rand((3, 6), device=device)
    ind = torch.tensor([0, 1], dtype=torch.int64, device=device)
    with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
        torch.index_select(y, 1, ind, out=x)

@kshitij12345 kshitij12345 requested a review from ngimel December 28, 2020 10:04
@mruberry mruberry added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Dec 28, 2020
@kshitij12345
Collaborator Author

@ngimel Please review

Collaborator

@ngimel ngimel left a comment


This looks good, I left minor comments. For the future, you can serialize benchmark results and then load and compare them, that will make looking at the results easier.
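A sketch of that workflow, with hypothetical file names, assuming `torch.utils.benchmark.Measurement` objects pickle cleanly (they are designed to survive round-trips for exactly this kind of cross-build comparison):

```python
import pickle

import torch
from torch.utils.benchmark import Compare, Timer

# Hypothetical file name; use e.g. "before.pkl" on master and
# "after.pkl" on the PR branch, then load both files for one table.
RESULTS_FILE = "before.pkl"

timer = Timer(
    stmt="x.masked_scatter_(mask, src)",
    setup=(
        "x = torch.zeros(1024)\n"
        "mask = torch.randn(1024) > 0\n"
        "src = torch.ones(int(mask.sum()))"
    ),
    globals={"torch": torch},
    label="masked_scatter_",
    description="CPU",
)

# Serialize the Measurement(s) from this build.
measurements = [timer.blocked_autorange(min_run_time=0.1)]
with open(RESULTS_FILE, "wb") as f:
    pickle.dump(measurements, f)

# Later (possibly after rebuilding), reload and print one table.
with open(RESULTS_FILE, "rb") as f:
    loaded = pickle.load(f)
Compare(loaded).print()
```

Extending `loaded` with measurements from a second file produces the before/after rows side by side in a single `Compare` table.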

"please use a mask with dtype torch.bool instead.");
}

TORCH_CHECK(self.numel() == b_mask.numel(), "Number of elements of self != Number of elements in mask");
Collaborator

Why is this check needed? expand_inplace should have failed if it could not satisfy this invariant?

Collaborator Author

Right. This is redundant. Thanks!

for (int64_t i = 0; i < n; i++) {
  mask_t mask_value = *(mask_t*)(mask + mask_stride * i);
  if (!is_mask_bool) {
    TORCH_CHECK(mask_value == 0 || mask_value == 1, "Mask tensor can take 0 and 1 values only");
Collaborator

mask_value <= 1
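For readers following along, the serial kernel above can be mirrored in plain Python. This is only a sketch of the semantics (the helper name `masked_scatter_ref` is made up), not the actual ATen code:

```python
def masked_scatter_ref(dest, mask, src):
    """Mirror of the serial CPU kernel: walk dest and mask together,
    consuming the next src element wherever mask is nonzero."""
    out = list(dest)
    j = 0
    for i, m in enumerate(mask):
        if m not in (0, 1):
            raise RuntimeError("Mask tensor can take 0 and 1 values only")
        if m:
            if j >= len(src):
                raise RuntimeError("source has fewer elements than ones in mask")
            out[i] = src[j]
            j += 1
    return out

print(masked_scatter_ref([1, 2, 3, 4], [0, 1, 0, 1], [9, 8]))  # [1, 9, 3, 8]
```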

Comment thread: test/test_torch.py (Outdated)

     with self.assertRaises(RuntimeError):
         dest.masked_scatter_(mask, src)
-    self.assertEqual(len(w), 27)
+    self.assertEqual(len(w), 20)
Collaborator

Yeah, the CUDA variant is not tested here. Why did the number of warnings change?

Collaborator Author

I think this happens because the test has a bug.

In the snippet below, src is always of dtype torch.float, while the dtype of dest varies with the loop:

# make src smaller. this should fail
src = torch.randn(num_copy - 1)
with self.assertRaises(RuntimeError):
    dest.masked_scatter_(mask, src)

So except for the case where both dest and src have dtype torch.float32, every iteration fails with a RuntimeError due to the dtype mismatch between dest and src.

In the new implementation, we first check that the dtypes of src and dest match, and only then proceed to the warning. In the old implementation, the warning condition was checked first and the dtypes afterwards.

That is why the warning counts differ.
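A small sketch of the failing pattern (shapes are made up): with a non-float dest, the call dies on the dtype mismatch, so under the new check order the uint8-deprecation warning is never reached for that iteration.

```python
import torch

dest = torch.arange(10, dtype=torch.int64)
mask = (torch.arange(10) % 2 == 0).to(torch.uint8)  # deprecated mask dtype
src = torch.randn(2)  # float32, so its dtype differs from dest's int64

# The dtype mismatch makes this call fail with a RuntimeError; with the
# dtype check running first, the uint8 warning is never emitted here.
try:
    dest.masked_scatter_(mask, src)
except RuntimeError:
    print("RuntimeError raised")
```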

TORCH_CHECK(
    self.scalar_type() == source.scalar_type(),
    "masked_scatter: expected self and source to have same dtypes but got ",
    self.scalar_type(),
    " and ",
    source.scalar_type());
TORCH_CHECK(self.device().type() == at::kCPU, "device type of self (", self.device().type(), ") is not CPU");
TORCH_CHECK(mask.device().type() == at::kCPU, "device type of mask (", mask.device().type(), ") is not CPU");
TORCH_CHECK(source.device().type() == at::kCPU, "device type of source (", source.device().type(), ") is not CPU");

Tensor b_mask;
std::tie(b_mask) = expand_inplace(self, mask, "masked_scatter_");
if (b_mask.dtype() == ScalarType::Byte) {
  TORCH_WARN("masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated, "
             "please use a mask with dtype torch.bool instead.");
}

I'll actually fix the test bug as well.
Thanks!

@codecov

codecov Bot commented Jan 19, 2021

Codecov Report

Merging #49732 (5653227) into master (ce30dba) will decrease coverage by 0.00%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master   #49732      +/-   ##
==========================================
- Coverage   80.65%   80.64%   -0.01%     
==========================================
  Files        1913     1913              
  Lines      208121   208017     -104     
==========================================
- Hits       167859   167764      -95     
+ Misses      40262    40253       -9     

@kshitij12345
Collaborator Author

@ngimel PTAL :)

@kshitij12345
Collaborator Author

@ngimel Gentle Ping

Contributor

@facebook-github-bot facebook-github-bot left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@ngimel merged this pull request in a291b25.

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
Fixes pytorch#49541

Reference: pytorch#24507

Pull Request resolved: pytorch#49732

Reviewed By: ejguan

Differential Revision: D25991438

Pulled By: ngimel

fbshipit-source-id: a43bd0bfe043d8e32a6cadbbf736a0eaa697e7ec

Labels

cla signed, Merged, open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Successfully merging this pull request may close these issues.

Migrate masked_scatter from TH to ATen (CPU)

5 participants