Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
f9bbbf8
COO intersection primitives: performance improvement
nikitaved Jan 25, 2023
e0b2908
Update on "COO intersection primitives: performance improvement"
nikitaved Jan 25, 2023
44048e7
Update on "COO intersection primitives: performance improvement"
nikitaved Jan 25, 2023
425c0eb
Update on "COO intersection primitives: performance improvement"
nikitaved Jan 25, 2023
389dcd7
Update on "COO intersection primitives: performance improvement"
nikitaved Jan 26, 2023
8ff3749
Update on "COO intersection primitives: performance improvement"
nikitaved Jan 26, 2023
10f370b
Update on "COO intersection primitives: performance improvement"
nikitaved Jan 26, 2023
32634e6
Update on "COO intersection primitives: performance improvement"
nikitaved Jan 26, 2023
eadebea
Update on "COO intersection primitives: performance improvement"
nikitaved Jan 26, 2023
74e788c
Update on "COO intersection primitives: performance improvement"
nikitaved Jan 27, 2023
7e8bbbc
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Jan 30, 2023
c364d57
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Jan 30, 2023
4086e4c
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Jan 31, 2023
54727f7
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Jan 31, 2023
67ee800
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Jan 31, 2023
ebab869
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 1, 2023
a212efd
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 1, 2023
a604cb8
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 1, 2023
7557a25
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 1, 2023
239c487
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 2, 2023
0370622
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 2, 2023
4f0d616
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 2, 2023
b82189b
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 3, 2023
0ad045d
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 3, 2023
1047b52
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 3, 2023
b22bb05
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 3, 2023
0b3e027
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 3, 2023
820f309
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 3, 2023
6d19559
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 3, 2023
1754314
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 6, 2023
a2bb926
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 6, 2023
05f45b9
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 6, 2023
dbe8602
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 6, 2023
3c3f4ca
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 6, 2023
00214dd
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 7, 2023
b06e928
Update on "[WIP, do not review] COO intersection primitives: performa…
nikitaved Feb 7, 2023
c832dad
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 7, 2023
334bbad
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 7, 2023
65ea434
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 7, 2023
5132216
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 8, 2023
a70c4bb
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 8, 2023
422f813
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 8, 2023
b5ced03
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 8, 2023
3e9f650
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 8, 2023
a24ebfa
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 9, 2023
d2566d2
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 10, 2023
336b8bb
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 13, 2023
efec509
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 13, 2023
04614ca
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 15, 2023
444142d
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 16, 2023
60565e0
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 17, 2023
5003428
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 19, 2023
5b2de2e
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 19, 2023
c1a747a
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 20, 2023
477d00a
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 20, 2023
f3758d7
Update on "COO intersection primitives: performance improvement"
nikitaved Feb 21, 2023
72be903
Update on "COO intersection primitives: performance improvement"
nikitaved Mar 2, 2023
2b4760a
Update on "COO intersection primitives: performance improvement"
nikitaved Mar 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions aten/src/ATen/SparseTensorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/tensor.h>
#endif

namespace at {
Expand Down Expand Up @@ -119,5 +120,65 @@ TORCH_API Tensor flatten_indices_by_dims(
// Find the CSR representation for a row `indices` from the COO format
TORCH_API Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz);

// Holds a copy of a tensor's sizes and strides by value, in two
// fixed-capacity std::array<int64_t, static_shape_max_len> buffers.
//
// Dereferencing the holder (operator*) yields a tuple of copies of both
// arrays, i.e. the geometry is passed around by value.
//
// Precondition (NOT checked here): sizes.size() and strides.size() must not
// exceed static_shape_max_len, because the constructor copies every element
// into the fixed-size arrays starting at index 0.
template <size_t static_shape_max_len>
class TensorGeometryHolder {
  using geometry_holder_t = std::array<int64_t, static_shape_max_len>;

 public:
  // Copies `sizes` and `strides` into the internal arrays.
  // `options` is not used by this implementation; the TensorGeometryHolder<0>
  // specialization takes the same three arguments.
  explicit TensorGeometryHolder(
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options = {}) {
    (void)options; // intentionally unused
    std::copy(sizes.begin(), sizes.end(), t_sizes.begin());
    std::copy(strides.begin(), strides.end(), t_strides.begin());
  }

  // Convenience constructor: takes the geometry straight from a tensor.
  explicit TensorGeometryHolder(const Tensor& t)
      : TensorGeometryHolder(t.sizes(), t.strides()) {}

  // Returns (sizes, strides) as a tuple of std::array copies.
  auto operator*() const {
    return std::make_tuple(t_sizes, t_strides);
  }

 private:
  // Value-initialized so that slots beyond the copied prefix hold zeros.
  // Without this they would be indeterminate, and operator*() copies the
  // whole arrays, which would read those indeterminate values.
  geometry_holder_t t_sizes{};
  geometry_holder_t t_strides{};
};

// Specialization for static_shape_max_len == 0: the geometry is kept in
// Tensors rather than in fixed-size arrays.
//
// The constructor packs sizes and strides into a single 2 x ndims int64 CPU
// tensor (row 0 = sizes, row 1 = strides), converts it with one `.to()` call
// to `options.device()`, and keeps the two rows as separate views.
// Dereferencing the holder yields raw int64_t pointers to those rows.
template <>
class TensorGeometryHolder<0> {
  using geometry_holder_t = Tensor;

 public:
  explicit TensorGeometryHolder(
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options) {
    // Staging buffer options: same as `options`, but int64 on the CPU.
    const auto staging_options =
        TensorOptions(options).dtype(kLong).device(kCPU);
    const int64_t ndims = sizes.size();

    // Fill both rows on the CPU first.
    Tensor staging = at::empty({2, ndims}, staging_options);
    auto sizes_row = staging.select(0, 0);
    sizes_row.copy_(at::tensor(sizes, staging_options));
    auto strides_row = staging.select(0, 1);
    strides_row.copy_(at::tensor(strides, staging_options));

    // Move the packed buffer to the requested device, then split it.
    const Tensor packed = staging.to(options.device());
    t_sizes = packed.select(0, 0);
    t_strides = packed.select(0, 1);
  }

  // Convenience constructor: takes geometry and options from a tensor.
  explicit TensorGeometryHolder(const Tensor& t)
      : TensorGeometryHolder(t.sizes(), t.strides(), t.options()) {}

  // Returns (sizes_ptr, strides_ptr): int64_t data pointers of the two rows.
  auto operator*() const {
    auto* sizes_ptr = t_sizes.data_ptr<int64_t>();
    auto* strides_ptr = t_strides.data_ptr<int64_t>();
    return std::make_tuple(sizes_ptr, strides_ptr);
  }

 private:
  geometry_holder_t t_sizes;
  geometry_holder_t t_strides;
};

} // namespace sparse
} // namespace at
54 changes: 32 additions & 22 deletions aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/AccumulateType.h>

namespace at::native {

Expand All @@ -28,10 +29,10 @@ FUNCAPI INLINE bool MulOp::apply(bool a, bool b) {
return a && b;
}

struct LhsProjOp {
struct RhsProjOp {
template <typename scalar_t>
static FUNCAPI scalar_t apply(scalar_t a, scalar_t b) {
return a;
return b;
}
};

Expand Down Expand Up @@ -68,11 +69,12 @@ template <typename binary_op_t, typename scalar_t, typename index_t>
void binary_op_intersection_kernel(
TensorIterator& iter,
int64_t lhs_nnz_stride,
int64_t rhs_nnz_stride) {
int64_t rhs_nnz_stride,
const Tensor& argsort) {
if (!iter.can_use_32bit_indexing()) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
binary_op_intersection_kernel<binary_op_t, scalar_t, index_t>(
sub_iter, lhs_nnz_stride, rhs_nnz_stride);
sub_iter, lhs_nnz_stride, rhs_nnz_stride, argsort);
}
return;
}
Expand All @@ -82,7 +84,8 @@ void binary_op_intersection_kernel(
const auto* RESTRICT ptr_lhs_select_idx_bytes = reinterpret_cast<char*>(iter.data_ptr(2));
const auto* RESTRICT ptr_rhs_values_bytes = reinterpret_cast<char*>(iter.data_ptr(3));
const auto* RESTRICT ptr_rhs_select_idx_bytes = reinterpret_cast<char*>(iter.data_ptr(4));
const auto* RESTRICT ptr_match_bytes = reinterpret_cast<char*>(iter.data_ptr(5));
const auto* RESTRICT ptr_intersction_counts_bytes = reinterpret_cast<char*>(iter.data_ptr(5));
const auto* RESTRICT ptr_argsort = argsort.data_ptr<index_t>();

auto offset_calc = make_offset_calculator<6>(iter);
auto loop = [=] FUNCAPI (int i) {
Expand All @@ -93,15 +96,22 @@ void binary_op_intersection_kernel(
const auto lhs_nnz_idx = *reinterpret_cast<const index_t*>(ptr_lhs_select_idx_bytes + offsets[2]);
const auto* RESTRICT ptr_rhs_values = reinterpret_cast<const scalar_t*>(ptr_rhs_values_bytes + offsets[3]);
const auto rhs_nnz_idx = *reinterpret_cast<const index_t*>(ptr_rhs_select_idx_bytes + offsets[4]);
const auto match = *reinterpret_cast<const bool*>(ptr_match_bytes + offsets[5]);

if (match) {
*ptr_res_values = binary_op_t::apply(
*(ptr_lhs_values + lhs_nnz_idx * lhs_nnz_stride),
*(ptr_rhs_values + rhs_nnz_idx * rhs_nnz_stride));
} else {
*ptr_res_values = 0;
const auto count = *reinterpret_cast<const int64_t*>(ptr_intersction_counts_bytes + offsets[5]);

const auto* RESTRICT ptr_lhs_begin = ptr_lhs_values + lhs_nnz_idx * lhs_nnz_stride;
const auto* RESTRICT ptr_rhs_sorted_nnz_idx = ptr_argsort + rhs_nnz_idx;

using accscalar_t = at::acc_type<scalar_t, /*is_gpu=*/true>;
accscalar_t res_values = 0;
accscalar_t lhs_values = static_cast<accscalar_t>(*ptr_lhs_begin);
accscalar_t rhs_values;
index_t rhs_sorted_nnz_idx;
for (int64_t c = 0; c < count; ++c) {
rhs_sorted_nnz_idx = *ptr_rhs_sorted_nnz_idx++;
rhs_values = static_cast<accscalar_t>(*(ptr_rhs_values + rhs_sorted_nnz_idx * rhs_nnz_stride));
res_values += binary_op_t::apply(lhs_values, rhs_values);
}
*ptr_res_values = static_cast<scalar_t>(res_values);
};

launch_kernel<num_threads(), thread_work_size()>(iter.numel(), loop);
Expand All @@ -115,13 +125,14 @@ struct CUDAValueSelectionIntersectionKernel {
const Tensor& lhs_select_idx,
const Tensor& rhs_values,
const Tensor& rhs_select_idx,
const c10::optional<Tensor>& match_mask = c10::nullopt) {
const Tensor& intersection_counts,
const Tensor& argsort) {
auto iter = make_value_selection_intersection_iter(
lhs_values,
lhs_select_idx,
rhs_values,
rhs_select_idx,
match_mask);
intersection_counts);
auto res_values = iter.tensor(0);

// If res_values is empty, we can return it right away.
Expand All @@ -136,11 +147,10 @@ struct CUDAValueSelectionIntersectionKernel {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, res_values.scalar_type(),
"binary_op_intersection_cpu", [&] {
AT_DISPATCH_INDEX_TYPES(lhs_select_idx.scalar_type(),
"binary_op_intersection_cpu", [&] {
binary_op_intersection_kernel<binary_op_t, scalar_t, index_t>(
iter, lhs_nnz_stride, rhs_nnz_stride);
});
// COO indices are only 64-bit for now.
using index_t = int64_t;
binary_op_intersection_kernel<binary_op_t, scalar_t, index_t>(
iter, lhs_nnz_stride, rhs_nnz_stride, argsort);
});

return res_values;
Expand All @@ -161,8 +171,8 @@ void sparse_mask_intersection_out_cuda_kernel(
Tensor& result,
const Tensor& x,
const Tensor& y) {
using CUDAValueLhsProjKernel = CUDAValueSelectionIntersectionKernel<LhsProjOp>;
_sparse_binary_op_intersection_kernel_out<CUDAKernelLauncher, CUDAValueLhsProjKernel>(
using CUDAValueRhsProjKernel = CUDAValueSelectionIntersectionKernel<RhsProjOp>;
_sparse_binary_op_intersection_kernel_out<CUDAKernelLauncher, CUDAValueRhsProjKernel>(
result, x, y, true
);
}
Expand Down
Loading