move softmax/logsoftmax to ATen #6786
Conversation
@onnxbot retest this please

a naming question: we have

@vadimkantorov good point, I can rename if there's agreement from the core devs.
CC @colesbury on the legacy test issues. Last time I chatted with folks about it, we weren't planning to delete the legacy support, so something might have to be done.
Yes, that would be great. I'd probably make another version of this template for "AccAccumulateType" or similar (following the TH convention).
I'm a little perplexed by the double-backwards situation; CC @gchanan about scalars.
Re 3: that makes sense, but IIRC that's not what we used to do some time ago. Isn't it possible to rewrite the derivative of
@apaszke No, the issue is different; none of the gradient formulas reference self. I think we should leave it for now, but eventually we should figure out what exactly is going on here. I would have thought that autograd should work out that if there's a differential on output, then you need to in turn compute the differential on self, because output was computed from self.
If they don't reference
It's passed around so that double backward on self can be defined, as in the 3-line snippet I posted above. It's no worse than it is now (THNN/THCUNN also does not use input in backward, yet it is there as an argument).
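The point that the gradient formulas never reference self can be illustrated with a small host-side sketch (plain C++ with illustrative names, not the actual ATen kernels): softmax's backward is expressible purely in terms of the forward output and grad_output.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical sketch: the softmax gradient needs only `output` and
// `grad_output`; the original input never appears in the formula:
//   grad_input[i] = output[i] * (grad_output[i] - sum_j grad_output[j] * output[j])
std::vector<double> softmax_backward(const std::vector<double>& grad_output,
                                     const std::vector<double>& output) {
  double dot = 0.0;
  for (std::size_t j = 0; j < output.size(); ++j)
    dot += grad_output[j] * output[j];
  std::vector<double> grad_input(output.size());
  for (std::size_t i = 0; i < output.size(); ++i)
    grad_input[i] = output[i] * (grad_output[i] - dot);
  return grad_input;
}
```

Because only output appears on the right-hand side, a double backward on self can be defined on top of this without ever materializing the original input.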
Hmm, I see now. Alright, that sounds a bit complicated, so we don't have to block on that.
Rebased on master, fixed legacy by calling
@apaszke what's your view on making log_softmax and logsigmoid naming consistent? #6786 (comment)
If there's something to be renamed, it should be
apaszke left a comment:
Not a complete review by any means, but it looks OK at a glance. Just curious why some T changed to AccumT in the kernels.
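The T-versus-AccumT question comes down to accumulation precision: reductions over low-precision values should accumulate in a wider type so rounding error does not swamp the result. A minimal host-side sketch of the idea (the trait name here is illustrative, not the actual ATen/TH AccumulateType machinery, and float stands in for half since fp16 is not a host type):

```cpp
#include <cassert>
#include <cmath>

// Maps an element type to the wider type its reductions should accumulate in.
template <typename T> struct AccumulateType { using type = T; };
template <> struct AccumulateType<float> { using type = double; };

// Sum the same value n times, accumulating in the widened type.
template <typename T>
typename AccumulateType<T>::type repeated_sum(T value, long n) {
  using AccumT = typename AccumulateType<T>::type;
  AccumT acc = 0;
  for (long i = 0; i < n; ++i) acc += static_cast<AccumT>(value);
  return acc;
}
```

The same motivation is behind keeping more intermediate values in AccumT inside the kernels: a long sum of exponentials in half or float precision drifts measurably, while the widened accumulator stays close to the true value.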
(Inline review comments on torch/nn/functional.py and aten/src/ATen/native/cuda/SoftMax.cu were marked as off-topic.)
@ngimel can you kindly rebase please? :)

I don't see an actual error on the failed trusty build, other than gcc exiting with error code 1; did it time out? For the failed xenial-cuda build, the failure also does not look related (there's no softmax in the test script).
Yeah, the gcc 7.2 failure is known flakiness (#7202). @pytorchbot retest this please

Hmm, cuda9-cudnn7 failed again. How do I run the test/cpp/ tests locally?
Interestingly, I can make adagrad fail on master by commenting out integration.cpp from the test_api sources (because I was too lazy to wait for mnist training); in this case adagrad runs for 3000 epochs with the loss stuck at 0.48. The optimizer tests seem kind of flaky.

So given that the failing test is flaky (#7288), and this PR does not touch any of the things the failing test exercises, this should be good to go?
Hey @ngimel, are you planning on further changing the CPU implementation? I'm currently working on vectorizing softmax and I have just finished what you merged on my end as well, haha.

No, I did not plan on further optimizing the CPU path, but now that it's in ATen it should be easier for you to work on.

Thanks @ngimel!
* move softmax/logsoftmax to ATen
* specify cpu and gpu accum types
* use accreal for CPU
* expose softmax backward to python, fix legacy interface
* fix Distributions.cu to use common AccumulateType
* fix cuda 8 build
* delete commented out lines
* rebase on master, fix breakages
Summary: **Summary**: This PR is a followup of mruberry's #9318. It tries to achieve the following:
- Specialize std common math functions for the `at::Half` type.
- Create `CUDANumerics.cuh` to contain the necessary parts from `THCNumerics.cuh`.
- Update `THCNumerics.cuh` with new usage and comments to demonstrate the best practice for developers, making way for its deprecation.
- Remove legacy/redundant code paths.
- Remove unused CUDA HALF macros (see separate PR #10147).

**Comments**: `CUDANumerics.cuh` contains mathematical functions that are either not in the std namespace or are specialized for compilation with CUDA NVCC or CUDA NVRTC. This header is derived from the legacy `THCNumerics.cuh`. Here is some of the rationale behind why some functions were kept while others were removed:
- All arithmetic can now be done in ATen using binary CUDA kernels or CUDA tensor pointwise apply (see #8919 and `CUDAApplyUtils`). `at::Half` comparisons rely on implicit conversion to float.
- Functions that are C/C++ standard compliant have been specialized for user-defined types; for instance, the std namespace has been opened up for `at::Half`, with math function definitions for `at::Half`. See `Half-inl.h`.
- Some standard-compliant functions are specialized here for performance reasons. For instance, `powi` is used for `pow` calculation on integral types. Moreover, `abs`, `isinf`, and `isnan` are specialized to save one API call versus going through std. This is subject to change, depending on whether we really care about saving one API call.
- Numeric limits such as `max`/`min` are removed since they call standard defines. Moreover, numeric limits for `at::Half` are present in `Half-inl.h`. I understood that HIP has some issue with `std::numeric_limits`; the related GitHub issue I found is ROCm/hip#374. AlexVlx mentions that the issue can be avoided by invoking `std::numeric_limits` in `__device__` code. Since we are launching lambdas with device contexts, I don't see why `std::numeric_limits` wouldn't compile on HIP if used within a kernel, unless I am not aware of the real reason why `max`/`min` were in THCNumerics in the first place. (I haven't tried a build with HIP.)

Here are some reference PRs that were handy in refactoring TH into ATen: #6786, #5475, #9401, #8689, #8919.

Pull Request resolved: #10301
Differential Revision: D9204758
Pulled By: soumith
fbshipit-source-id: 09f489c1656458c02367b6cd31c3eeeca5acdc8a
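As a concrete illustration of the numeric-limits point above: specializing `std::numeric_limits` for a program-defined type is permitted by the standard, which is how a Half-like type can advertise its range. A rough sketch, where the `Half` struct here is an illustrative stand-in for `at::Half`, not the real type from `Half-inl.h`:

```cpp
#include <cstdint>
#include <limits>

// Illustrative stand-in for at::Half: just the raw fp16 bit pattern.
struct Half {
  uint16_t bits;
};

namespace std {
// Specializing numeric_limits for a program-defined type is legal C++;
// the bit patterns below encode the finite fp16 extremes.
template <>
class numeric_limits<Half> {
 public:
  static constexpr bool is_specialized = true;
  static constexpr Half lowest() { return Half{0xFBFF}; }  // -65504, most negative finite fp16
  static constexpr Half max() { return Half{0x7BFF}; }     // +65504, largest finite fp16
};
}  // namespace std
```

If HIP's trouble with `std::numeric_limits` is only about host/device contexts, a specialization like this should be usable from device lambdas as discussed above.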
THCUNN kernels are mostly unchanged, with minimal changes to types so that more intermediate values are preserved in AccumT.
Softmax/LogSoftmax from THNN are combined into a single templated function.
Remaining issues