Port THCS to ATen. #8689

Merged
ezyang merged 34 commits into pytorch:master from ezyang:pr/thcs-to-aten
Jun 24, 2018

Conversation

@ezyang ezyang commented Jun 20, 2018

General structure of the sparse implementation:

  • SparseCUDATensor.{cpp, cu} and SparseCUDATensorMath.cu contain
    the same functions as their CPU analogues
  • SparseCUDAApplyUtils.cuh contains what used to be in
    THCSTensor.cu
  • SparseCUDABlas.cu contains what used to be THCSparse.cu

Unrelated improvements:

  • Forward declared CUDA types in Context.h are now moved
    exclusively to CUDAHooks
  • New getCurrentCUDASparseHandle in Context
  • Support for printing CUSPARSE_STATUS_ZERO_PIVOT error message
    directly

Some unusual pieces:

  • get_device got the LegacyBridge makeover, as it needs special
    logic on sparse tensors (defer to the inner tensors).
  • I needed to turn off device_guard codegen for many functions in
    sparse; I noticed this because get_device became a native function,
    which resulted in an infinite recursion. This was done by adding
    device_guard: False to the native definitions. An alternative
    strategy might be to make the heuristic for deciding when to insert
    a device guard smarter.
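The device_guard: False mechanism described above is a per-function flag on the native definitions; a minimal sketch of such an entry (the surrounding schema syntax is illustrative and may differ from the exact native_functions.yaml format of this era):

```yaml
# Hypothetical illustration: disable device-guard codegen for a native
# function whose dispatch must not re-enter get_device.
- func: native_get_device(Tensor self) -> int64_t
  device_guard: False
```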

Scaffolding removal:

  • LegacyBridge now special-cases only on sparse versus dense;
    no more CUDA test (hooray!)
  • Native bindings get CUDA/SparseCUDA dispatch entries.

CPU sparse refactoring:

  • New SparseUtils.h header, with all of the utility functions that
    used to live in SparseTensor.cpp
  • new_with_tensor_sparse now correctly handles both CPU and CUDA
  • transpose functions in sparse/ turned out to be dead, so I killed them

Bugs I noticed while working on this:

  • I used accessor<...>() on a CUDA tensor, because I thought it does
    the CUDA-CPU sync. It does not.

TODO:

  • For sparse only methods, we can now remove the TH binding
    entirely

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

@ezyang ezyang requested a review from gchanan June 20, 2018 15:25

ezyang commented Jun 20, 2018

@pytorchbot retest this please

} else {
return th_add(self, other, alpha);
}
// See Note [CPU sparse is globally native] and Note [Multiple dispatch to sparse]

} else {
return th_add_(self, other, alpha);
}
// See Note [CPU sparse is globally native] and Note [Multiple dispatch to sparse]

} else {
return th_addmm_out(result, self, mat1, mat2, beta, alpha);
}
// See Note [CPU sparse is globally native] and Note [Multiple dispatch to sparse]

} else {
return th_addmm(self, mat1, mat2, beta, alpha);
}
// See Note [CPU sparse is globally native] and Note [Multiple dispatch to sparse]

device_guard: False


- func: native_get_device(Tensor self) -> int64_t


// TODO: Expose this for real in ATen, some day?
// NB: Doesn't preserve data.
inline Tensor _new_values_with_size_of(const Tensor& values, int64_t nnz) {

", mask is on device ", mask.get_device(), ", out is on device ", r.get_device());
resize_as_sparse_(r, mask);
if (mask._nnz() == 0) {
r.zero_();

_get_sparse_impl(r)->set_coalesced(mask.is_coalesced());
_get_sparse_impl(r)->set_nnz(mask._nnz());

LongTensor indices = at::zeros({mask._nnz()}, mask_indices.type());

if(n == 1)
*ldb = k;
}
THError("Internal error! This API is deprecated. Shout if you need it.");

TensorInfo<indexT, IndexType> indices,
TensorInfo<Real, IndexType> values,
const IndexType nnz) {
IndexType indskip = indices.strides[0];


@ezyang ezyang force-pushed the pr/thcs-to-aten branch 2 times, most recently from 2d1d33b to d9cc44a Compare June 21, 2018 13:24
@ezyang ezyang force-pushed the pr/thcs-to-aten branch 2 times, most recently from 319903b to 5d61466 Compare June 22, 2018 13:09
ezyang added 18 commits June 22, 2018 08:49
ezyang added 10 commits June 22, 2018 08:49
@ezyang ezyang force-pushed the pr/thcs-to-aten branch from d326be0 to c13b092 Compare June 22, 2018 15:49

@gchanan gchanan left a comment

hooray! Some minor comments / nits, but this looks ready to go once those are addressed.

// NB: Doesn't preserve data.
inline Tensor _new_values_with_size_of(const Tensor& values, int64_t nnz) {
if (values.numel() == 0) { // values tensor uninitialized
// TODO: This logic looks bogus; if we have an uninitialized

// TODO: This error message seems awfully opaque
AT_CHECK(sparse_._sparseDims() == 2, "matrices expected, got ", sparse_._sparseDims(), "D tensor");
AT_CHECK(sparse_._denseDims() == 0, "scalar values expected, got ", sparse_._denseDims(), "D values");
AT_CHECK(dense.dim() == 2, "matrices expected, got ", dense.dim(), "D tensor");

"Argument #2: matrices expected, got ", sparse_._sparseDims(), "D tensor");
AT_CHECK(sparse_._denseDims() == 0,
"Argument #2: scalar values expected, got ", sparse_._denseDims(), "D values");
AT_CHECK(dense.dim() == 2,

SparseTensor& s_mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, const SparseTensor& src_) {
#ifndef __HIP_PLATFORM_HCC__
AT_CHECK(_check_device({r_, t_, src_}));
AT_CHECK(t_.sizes().equals(src_.sizes()), "mul operands have incompatible sizes");

ezyang added 5 commits June 22, 2018 12:56
…elds in env when they are dead.

@ezyang ezyang merged commit 3598356 into pytorch:master Jun 24, 2018
facebook-github-bot pushed a commit that referenced this pull request Aug 24, 2018
Summary:
**Summary**: This PR is a follow-up to mruberry's #9318. It aims to achieve the following:
- Specializing std common math functions for `at::Half` type.
- Create `CUDANumerics.cuh` to contain necessary parts from `THCNumerics.cuh`.
- Update `THCNumerics.cuh` with new usage and comments to  demonstrate the best practice for developers and hence, making way for its deprecation.
- Remove legacy/redundant code path.
- Remove unused CUDA HALF macros (see separate PR #10147)

**Comments**: `CUDANumerics.cuh` contains mathematical functions that are either not in the std namespace or are specialized for compilation with CUDA NVCC or CUDA NVRTC. This header is derived from the legacy `THCNumerics.cuh`. Here is some of the rationale for which functions were kept and which were removed:
- All arithmetic can now be done in ATen using a binary CUDA kernel or CUDA tensor pointwise apply (check #8919 and `CUDAApplyUtils`). `at::Half` comparisons rely on implicit conversion to float.
- Functions that are C/C++ standard compliant have been specialized for user-defined types; for instance, the std namespace has been opened up for `at::Half`, with math function definitions for `at::Half`. Check `Half-inl.h`.
- Some standard-compliant functions are specialized here for performance reasons. For instance, `powi` is used for `pow` calculation on integral types. Moreover, `abs`, `isinf`, and `isnan` are specialized to save one API call versus going through std. This is subject to change, depending on whether we really care about saving one API call.
- Numeric limits such as `max`/`min` are removed since they just call the standard definitions. Moreover, numeric limits for `at::Half` are present in `Half-inl.h`. I understood that HIP has some issue with `std::numeric_limits`; this is the related GitHub issue I found: ROCm/hip#374. AlexVlx mentions that the issue can be avoided by invoking `std::numeric_limits` in `__device__` code. Since we are launching lambdas with device contexts, I don't see why `std::numeric_limits` wouldn't compile on HIP if used within a kernel, unless I am unaware of the real reason why `max`/`min` were in THCNumerics in the first place. (I haven't ever tried a build with HIP.)

Here are some reference PRs that were handy in refactoring TH into ATen:
- #6786
- #5475
- #9401
- #8689
- #8919
Pull Request resolved: #10301

Differential Revision: D9204758

Pulled By: soumith

fbshipit-source-id: 09f489c1656458c02367b6cd31c3eeeca5acdc8a
zdevito pushed a commit to zdevito/ATen that referenced this pull request Aug 25, 2018
PenghuiCheng pushed a commit to PenghuiCheng/pytorch that referenced this pull request Sep 11, 2018