Skip to content

Commit eaf5ca0

Browse files
kshitij12345 authored and facebook-github-bot committed
Migrate masked_scatter_ CUDA to ATen (#50039)
Summary: Fixes #49542 Pull Request resolved: #50039 Reviewed By: heitorschueroff Differential Revision: D26096247 Pulled By: ngimel fbshipit-source-id: ec1810d3412e0d7ab6b950265a3123519ad886c1
1 parent 1c8d11c commit eaf5ca0

8 files changed

Lines changed: 153 additions & 396 deletions

File tree

aten/src/ATen/LegacyTHFunctionsCUDA.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ namespace cuda {
2020

2121
Tensor & _th_masked_fill_(Tensor & self, const Tensor & mask, Scalar value);
2222
Tensor & _th_masked_fill_bool_(Tensor & self, const Tensor & mask, Scalar value);
23-
Tensor & _th_masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & source);
24-
Tensor & _th_masked_scatter_bool_(Tensor & self, const Tensor & mask, const Tensor & source);
2523
Tensor & _th_index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source);
2624
Tensor & _th_take_out(Tensor & result, const Tensor & self, const Tensor & index);
2725
Tensor _th_take(const Tensor & self, const Tensor & index);

aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp

Lines changed: 0 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -200,166 +200,6 @@ Tensor & _th_masked_fill_bool_(Tensor & self, const Tensor & mask, Scalar value)
200200
}
201201
return self;
202202
}
203-
Tensor & _th_masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & source) {
204-
// DeviceGuard omitted
205-
auto dispatch_scalar_type = infer_scalar_type(self);
206-
207-
switch (dispatch_scalar_type) {
208-
case ScalarType::Bool: {
209-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
210-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
211-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
212-
THCudaBoolTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
213-
break;
214-
}
215-
case ScalarType::Byte: {
216-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
217-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
218-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
219-
THCudaByteTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
220-
break;
221-
}
222-
case ScalarType::Char: {
223-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
224-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
225-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
226-
THCudaCharTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
227-
break;
228-
}
229-
case ScalarType::Double: {
230-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
231-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
232-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
233-
THCudaDoubleTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
234-
break;
235-
}
236-
case ScalarType::Float: {
237-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
238-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
239-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
240-
THCudaTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
241-
break;
242-
}
243-
case ScalarType::Int: {
244-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
245-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
246-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
247-
THCudaIntTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
248-
break;
249-
}
250-
case ScalarType::Long: {
251-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
252-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
253-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
254-
THCudaLongTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
255-
break;
256-
}
257-
case ScalarType::Short: {
258-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
259-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
260-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
261-
THCudaShortTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
262-
break;
263-
}
264-
case ScalarType::Half: {
265-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
266-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
267-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
268-
THCudaHalfTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
269-
break;
270-
}
271-
case ScalarType::BFloat16: {
272-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
273-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
274-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
275-
THCudaBFloat16Tensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
276-
break;
277-
}
278-
default:
279-
AT_ERROR("_th_masked_scatter_ not supported on CUDAType for ", dispatch_scalar_type);
280-
}
281-
return self;
282-
}
283-
Tensor & _th_masked_scatter_bool_(Tensor & self, const Tensor & mask, const Tensor & source) {
284-
// DeviceGuard omitted
285-
auto dispatch_scalar_type = infer_scalar_type(self);
286-
287-
switch (dispatch_scalar_type) {
288-
case ScalarType::Bool: {
289-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
290-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
291-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
292-
THCudaBoolTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
293-
break;
294-
}
295-
case ScalarType::Byte: {
296-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
297-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
298-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
299-
THCudaByteTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
300-
break;
301-
}
302-
case ScalarType::Char: {
303-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
304-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
305-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
306-
THCudaCharTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
307-
break;
308-
}
309-
case ScalarType::Double: {
310-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
311-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
312-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
313-
THCudaDoubleTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
314-
break;
315-
}
316-
case ScalarType::Float: {
317-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
318-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
319-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
320-
THCudaTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
321-
break;
322-
}
323-
case ScalarType::Int: {
324-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
325-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
326-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
327-
THCudaIntTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
328-
break;
329-
}
330-
case ScalarType::Long: {
331-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
332-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
333-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
334-
THCudaLongTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
335-
break;
336-
}
337-
case ScalarType::Short: {
338-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
339-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
340-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
341-
THCudaShortTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
342-
break;
343-
}
344-
case ScalarType::Half: {
345-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
346-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
347-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
348-
THCudaHalfTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
349-
break;
350-
}
351-
case ScalarType::BFloat16: {
352-
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
353-
auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
354-
auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
355-
THCudaBFloat16Tensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
356-
break;
357-
}
358-
default:
359-
AT_ERROR("_th_masked_scatter_bool_ not supported on CUDAType for ", dispatch_scalar_type);
360-
}
361-
return self;
362-
}
363203
Tensor & _th_index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
364204
// DeviceGuard omitted
365205
auto dispatch_scalar_type = infer_scalar_type(self);

aten/src/ATen/native/cuda/IndexKernel.cu

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@
1010
#include <ATen/cuda/detail/OffsetCalculator.cuh>
1111
#include <ATen/ExpandUtils.h>
1212
#include <ATen/MemoryOverlap.h>
13+
#include <ATen/native/cuda/Loops.cuh>
1314
#include <THC/THCTensorInfo.cuh>
15+
#include <THC/THCThrustAllocator.cuh>
16+
17+
#include <thrust/execution_policy.h>
18+
#include <thrust/device_ptr.h>
19+
#include <thrust/scan.h>
1420

1521
namespace at { namespace native {
1622

@@ -252,6 +258,103 @@ Tensor& take_out_cuda(Tensor& out, const Tensor& self, const Tensor& index) {
252258
return out;
253259
}
254260

261+
namespace {
262+
263+
template <typename mask_t>
264+
void masked_scatter_cuda_impl(Tensor& self, const Tensor& mask, const Tensor& source){
265+
auto srcSize = source.numel();
266+
267+
// Determine our output size
268+
auto totalElements = mask.sum().item<int64_t>();
269+
270+
// The number of `1` elements present in the mask must be <= the
271+
// number of elements available in `src`
272+
TORCH_CHECK(totalElements <= srcSize, "source nElements must be == mask `1` elements");
273+
274+
auto mask_cont = mask.contiguous();
275+
276+
// Use a prefix sum to determine the output locations of the masked elements
277+
auto maskPrefixSum = at::empty_like(mask, mask.options().dtype(kLong));
278+
279+
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
280+
281+
thrust::device_ptr<mask_t> maskData(mask_cont.data_ptr<mask_t>());
282+
thrust::device_ptr<int64_t> maskPrefixSumData(
283+
maskPrefixSum.data_ptr<int64_t>());
284+
285+
thrust::exclusive_scan(
286+
thrust::cuda::par(allocator).on(c10::cuda::getCurrentCUDAStream()),
287+
maskData,
288+
maskData + mask_cont.numel(),
289+
maskPrefixSumData);
290+
291+
// We are getting elements from `src` based on an offset from
292+
// `maskPrefixSum`, so that should be made contiguous too
293+
auto source_contig = source.contiguous();
294+
295+
auto iter = TensorIteratorConfig()
296+
.set_check_mem_overlap(false)
297+
.check_all_same_dtype(false)
298+
.resize_outputs(false)
299+
.add_output(self)
300+
.add_input(self)
301+
.add_input(mask_cont)
302+
.add_input(maskPrefixSum)
303+
.build();
304+
305+
AT_DISPATCH_ALL_TYPES_AND3(
306+
ScalarType::Bool,
307+
ScalarType::BFloat16,
308+
ScalarType::Half,
309+
self.scalar_type(),
310+
"masked_scatter_",
311+
[&]() {
312+
auto source_ptr = source_contig.data_ptr<scalar_t>();
313+
gpu_kernel(
314+
iter, [=] GPU_LAMBDA(scalar_t a, mask_t mask, int64_t maskPrefixSum) -> scalar_t {
315+
if (mask) {
316+
return source_ptr[maskPrefixSum];
317+
}
318+
return a;
319+
});
320+
cudaGetLastError();
321+
});
322+
}
323+
324+
} // anonymous namespace
325+
326+
Tensor & masked_scatter__cuda(Tensor& self, const Tensor& mask, const Tensor& source) {
327+
at::assert_no_internal_overlap(self);
328+
TORCH_CHECK(
329+
self.scalar_type() == source.scalar_type(),
330+
"masked_scatter: expected self and source to have same dtypes but got",
331+
self.scalar_type(),
332+
" and ",
333+
source.scalar_type());
334+
335+
TensorArg self_arg{self, "self", 1};
336+
TensorArg mask_arg{mask, "mask", 2};
337+
TensorArg source_arg{source, "source", 3};
338+
checkAllSameGPU("masked_scatter_", {self_arg, mask_arg, source_arg});
339+
340+
Tensor b_mask;
341+
std::tie(b_mask) = expand_inplace(self, mask, "masked_scatter_");
342+
343+
if (b_mask.dtype() == ScalarType::Byte) {
344+
TORCH_WARN("masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
345+
"please use a mask with dtype torch.bool instead.");
346+
}
347+
348+
auto mask_dtype = b_mask.scalar_type();
349+
if (mask_dtype == ScalarType::Bool) {
350+
masked_scatter_cuda_impl<bool>(self, b_mask, source);
351+
} else {
352+
masked_scatter_cuda_impl<uint8_t>(self, b_mask, source);
353+
}
354+
355+
return self;
356+
}
357+
255358
REGISTER_DISPATCH(index_stub, &index_kernel);
256359
REGISTER_DISPATCH(index_put_stub, &index_put_kernel);
257360

aten/src/ATen/native/cuda/LegacyDefinitions.cpp

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,19 +61,4 @@ Tensor & masked_fill__cuda(Tensor& self, const Tensor & mask, const Tensor & val
6161
return self;
6262
}
6363

64-
Tensor & masked_scatter__cuda(Tensor& self, const Tensor & mask, const Tensor & source) {
65-
at::assert_no_internal_overlap(self);
66-
Tensor b_mask;
67-
std::tie(b_mask) = expand_inplace(self, mask, "masked_scatter_");
68-
// As we dispatch on self and TH is type-checked, we need different definitions.
69-
// This can be fixed by moving to ATen.
70-
if (b_mask.dtype() == at::ScalarType::Byte) {
71-
TORCH_WARN("masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
72-
"please use a mask with dtype torch.bool instead.");
73-
return legacy::cuda::_th_masked_scatter_(self, b_mask, source);
74-
} else {
75-
return legacy::cuda::_th_masked_scatter_bool_(self, b_mask, source);
76-
}
77-
}
78-
7964
}} // namespace at::native

aten/src/THC/THCTensorMasked.cuh

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,6 @@ struct TensorMaskedFillOp {
2525
T value;
2626
};
2727

28-
template <typename T, typename MaskT, typename MaskPrefixSumT>
29-
struct TensorMaskedCopyOp {
30-
TensorMaskedCopyOp(T* s) : in(s) {}
31-
32-
__device__ inline void operator()(T* out,
33-
MaskT* mask,
34-
MaskPrefixSumT* maskPrefixSum) {
35-
if (*mask) {
36-
*out = in[*maskPrefixSum];
37-
}
38-
}
39-
40-
// Where we are copying from
41-
T* in;
42-
};
43-
4428
template <typename T, typename MaskT, typename MaskPrefixSumT>
4529
struct TensorMaskedSelectOp {
4630
TensorMaskedSelectOp(T* t) : out(t) {}

0 commit comments

Comments (0)