Migrate equal from TH to ATen (CPU) #33286
Conversation
Should we also support Half here?
Please use TensorIterator instead. This is a perfect case of reduction.
VitalyFedyunin
left a comment
CPU_tensor_apply2 is going away eventually, please use TensorIterator instead.
Is it possible to move this part to aten/src/ATen/native/cpu/BinaryOpsKernel.cpp? There you can take advantage of SIMD instructions and the code structure is cleaner (especially after the CUDA version is also migrated). You can look into the other binary operators as examples.
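For context, here is a minimal sketch of what a kernel in that file typically looks like — hedged: `eq_kernel_sketch` is a hypothetical name and this is not the PR's actual code; it just follows the conventions of the other binary kernels, where `cpu_kernel` (from Loops.h) drives the loop through TensorIterator and can vectorize the contiguous inner loop:

```cpp
// A minimal sketch, assuming the conventions of the other binary kernels in
// aten/src/ATen/native/cpu/BinaryOpsKernel.cpp. eq_kernel_sketch is a
// hypothetical name, not the PR's actual code.
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>

namespace at { namespace native {

void eq_kernel_sketch(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES(iter.common_dtype(), "eq_cpu_sketch", [&]() {
    // cpu_kernel lets TensorIterator handle broadcasting, strides, and
    // parallelization; the contiguous inner loop can be vectorized.
    cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool {
      return a == b;
    });
  });
}

}} // namespace at::native
```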
Thanks for the quick review.
I moved this to BinaryOpsKernel.
You shouldn't specify a return type here (it's causing build errors). The capture can be more specific: [&equal]
It's weird that the unit test fails when no return type is specified here, but passes with a return type.
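For readers unfamiliar with the two suggestions, here is a hypothetical standalone illustration (the names `equal` and `compare` are made up for this example): let the lambda's return type be deduced rather than written out, and capture only the variable the lambda actually uses.

```cpp
#include <iostream>

int main() {
  bool equal = true;
  // Return type is deduced; the capture names only `equal` instead of [&].
  auto compare = [&equal](int a, int b) {
    if (a != b) {
      equal = false;
    }
  };
  compare(1, 2);
  std::cout << std::boolalpha << equal << '\n';  // prints: false
}
```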
VitalyFedyunin
left a comment
Please add bfloat16 support and add tests to cover it. Everything else looks good.
|
@xcnick The other way is to reuse foreach_reduced_elt(func(subiter)) in TensorIterator.
glaringlee
left a comment
@xcnick
I commented on this PR above; let me know if you need more info on how to do tensor reduction.
Sorry for the late response. I would like to use the first way, so I added the changes accordingly. Am I doing it right?

@xcnick You'll need some tweaks in that function to support more inputs.
Hey @xcnick, if you think changing the reduction code is too much for you, you can do the same thing as this PR: at::eq supports both CPU and CUDA. Then rename your equal() function to cpu_equal (change it in all related places), and you are done :)
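In other words, the fallback being suggested boils down to something like the sketch below — hedged: `equal_sketch` is a hypothetical name, not the PR's code; at::eq produces an elementwise bool tensor on both CPU and CUDA, and all() reduces it to a single flag.

```cpp
#include <ATen/ATen.h>

// A minimal sketch of the suggested fallback; equal_sketch is a hypothetical
// name. Shape check first, then elementwise compare and reduce.
bool equal_sketch(const at::Tensor& self, const at::Tensor& other) {
  if (!self.is_same_size(other)) {
    return false;
  }
  return at::eq(self, other).all().item<bool>();
}
```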
Hi, I modified the code as you said. The benchmark results have been updated.
ping @glaringlee

@xcnick @VitalyFedyunin
@xcnick
Based on your benchmark, I found one problem when calling at::native::eq directly.
For the contiguous non-equal test, the performance of at::native::eq is very bad. The reason is that at::native::eq compares every element of the two tensors; even when an unequal element has already been found, it won't stop. This also slows down the non-contiguous non-equal case, but not as badly as the contiguous case.
Let's do something else here:
Instead of calling eq directly, let's implement our own loop function. You can put the following code here and remove the 'return at::native::eq' line.
```cpp
std::atomic<bool> result{true};
auto iter = TensorIteratorConfig()
  .add_input(self)
  .add_input(other)
  .allow_cpu_scalars(true)
  .promote_inputs_to_common_dtype(true)
  .build();
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "equal_cpu", [&] {
  iter.for_each([&](char** data, const int64_t* strides, int64_t dim_size) {
    // Another thread may have already found a mismatch; bail out early.
    if (!result) {
      return;
    }
    char* self_data = data[0];
    char* other_data = data[1];
    // Walk one row, stepping each input by its own stride.
    for (int64_t i = 0; i < dim_size; ++i) {
      if (*((scalar_t*)self_data) != *((scalar_t*)other_data)) {
        result = false;
        return;
      }
      self_data += strides[0];
      other_data += strides[1];
    }
  });
});
return result.load();
```
Let me explain a little bit. Since we just need to return a bool, there is no need to even have an output tensor. So we initialize a tensor iterator with two inputs only, and keep track of a global boolean 'result'.
The tensor iterator has a for_each function which takes a function reference.
This reference has a signature like: loop(char** data, const int64_t* strides, int64_t size);
Internally, the tensor iterator chunks the whole tensor into many 2D planes and processes them row by row; this loop function is used to iterate over each row. The 'data' array contains the starting point of the current row for each input tensor, the 'strides' array contains the stride of the current row (dimension) for each input tensor, and the tensor order in both 'data' and 'strides' is the same as the order in which you initialized the tensor iterator. 'size' is the row length.
This code dispatches iter.for_each() over all the supported data types, and if any non-equal element is found between the two tensors, it stops immediately. For a contiguous tensor this means quitting the whole iteration, since we have a single continuous data array in that case. If the tensor is not contiguous there is still overhead, since the program will jump to the next row, immediately break, jump to the next row, break again, and so on until it has walked the entire tensors. But it already saves a lot of time compared to calling at::native::eq directly, and on top of that there is no tensor write.
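As a usage example — hedged: a small standalone C++ program against the public API, not part of the PR — the early exit is exactly what makes the mismatching case cheap:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  auto a = torch::rand({1000, 1000});
  auto b = a.clone();
  b[0][0] += 1;  // introduce a single mismatch
  // With the early-exit loop above, torch::equal can stop at the first
  // unequal element instead of comparing the whole tensor.
  std::cout << std::boolalpha << torch::equal(a, b) << '\n';  // false
}
```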
Please make this change, update the benchmark, and rebase. I think we are good to go then.
```cpp
if (!self.is_same_size(other)) {
  return false;
}
bool result = true;
```
@xcnick
Should this be atomic? I updated my comments around 1 hr ago; I think your GitHub page is cached...
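The change being asked for, spelled out (the same flag as in the suggested snippet earlier):

```cpp
#include <atomic>

// for_each may execute the loop body on several threads in parallel, so a
// plain bool write is a data race; use std::atomic<bool> instead.
std::atomic<bool> result{true};
```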
@glaringlee Thanks for your help!
facebook-github-bot
left a comment
@glaringlee has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
glaringlee
left a comment
LGTM Now.
@xcnick Thanks a lot for your contribution!!
@glaringlee merged this pull request in 72f2c47.
Summary: pytorch#24697
@VitalyFedyunin @glaringlee

Test script:

```Python
import timeit

setup_ones = """
import torch
a = torch.ones(({n}, {n}), dtype={dtype})
b = torch.ones(({n}, {n}), dtype={dtype})
"""
for n, t in [(1000, 10000), (2000, 10000)]:
    for dtype in ('torch.bool', 'torch.int', 'torch.long', 'torch.bfloat16', 'torch.float', 'torch.double'):
    #for dtype in ('torch.bool', 'torch.int', 'torch.long', 'torch.float', 'torch.double'):
        print('torch.ones(({n}, {n})) equal for {t} times {dtype}'.format(n=n, t=t, dtype=dtype))
        print(timeit.timeit(stmt='torch.equal(a, b)', setup=setup_ones.format(n=n, dtype=dtype), number=t))

setup_rand = """
import torch
a = torch.rand(({n}, {n}), dtype={dtype})
b = a.clone()
"""
for n, t in [(1000, 10000), (2000, 10000)]:
    for dtype in ('torch.float', 'torch.double'):
        print('torch.rand(({n}, {n})) for {t} times {dtype}'.format(n=n, t=t, dtype=dtype))
        print(timeit.timeit(stmt='torch.equal(a, b)', setup=setup_rand.format(n=n, dtype=dtype), number=t))

setup_non_contiguous = """
import torch
a = torch.rand(({n}, {n}), dtype={dtype})
a2 = a[:, 500:]
a3 = a2.clone()
torch.equal(a2, a3)
"""
for n, t in [(1000, 10000), (2000, 10000)]:
    for dtype in ('torch.float', 'torch.double'):
        print('non_contiguous torch.rand(({n}, {n})) for {t} times {dtype}'.format(n=n, t=t, dtype=dtype))
        print(timeit.timeit(stmt='torch.equal(a2, a3)', setup=setup_non_contiguous.format(n=n, dtype=dtype), number=t))

setup_not_equal = """
import torch
a = torch.rand(({n}, {n}), dtype={dtype})
b = torch.rand(({n}, {n}), dtype={dtype})
torch.equal(a, b)
"""
for n, t in [(1000, 10000), (2000, 10000)]:
    for dtype in ('torch.float', 'torch.double'):
        print('not equal torch.rand(({n}, {n})) for {t} times {dtype}'.format(n=n, t=t, dtype=dtype))
        print(timeit.timeit(stmt='torch.equal(a, b)', setup=setup_not_equal.format(n=n, dtype=dtype), number=t))
```

TH

```
torch.ones((1000, 1000)) equal for 10000 times torch.bool 1.8391206220258027
torch.ones((1000, 1000)) equal for 10000 times torch.int 1.8877864250680432
torch.ones((1000, 1000)) equal for 10000 times torch.long 1.938108820002526
torch.ones((1000, 1000)) equal for 10000 times torch.bfloat16 3.184849138953723
torch.ones((1000, 1000)) equal for 10000 times torch.float 1.8825413499725983
torch.ones((1000, 1000)) equal for 10000 times torch.double 2.7266416549682617
torch.ones((2000, 2000)) equal for 10000 times torch.bool 7.227149627986364
torch.ones((2000, 2000)) equal for 10000 times torch.int 7.76215292501729
torch.ones((2000, 2000)) equal for 10000 times torch.long 9.631909006042406
torch.ones((2000, 2000)) equal for 10000 times torch.bfloat16 8.097328286035918
torch.ones((2000, 2000)) equal for 10000 times torch.float 5.5739822529722005
torch.ones((2000, 2000)) equal for 10000 times torch.double 8.444009944912978
torch.rand((1000, 1000)) for 10000 times torch.float 1.168096570065245
torch.rand((1000, 1000)) for 10000 times torch.double 1.6577326939441264
torch.rand((2000, 2000)) for 10000 times torch.float 5.49395391496364
torch.rand((2000, 2000)) for 10000 times torch.double 8.507486199960113
non_contiguous torch.rand((1000, 1000)) for 10000 times torch.float 6.074504268006422
non_contiguous torch.rand((1000, 1000)) for 10000 times torch.double 6.1426916810451075
non_contiguous torch.rand((2000, 2000)) for 10000 times torch.float 37.501055537955835
non_contiguous torch.rand((2000, 2000)) for 10000 times torch.double 44.6880351039581
not equal torch.rand((1000, 1000)) for 10000 times torch.float 0.029356416082009673
not equal torch.rand((1000, 1000)) for 10000 times torch.double 0.025421109050512314
not equal torch.rand((2000, 2000)) for 10000 times torch.float 0.026333761983551085
not equal torch.rand((2000, 2000)) for 10000 times torch.double 0.02748022007290274
```

ATen

```
torch.ones((1000, 1000)) equal for 10000 times torch.bool 0.7961567062884569
torch.ones((1000, 1000)) equal for 10000 times torch.int 0.49172434909269214
torch.ones((1000, 1000)) equal for 10000 times torch.long 0.9459248608909547
torch.ones((1000, 1000)) equal for 10000 times torch.bfloat16 2.0877483217045665
torch.ones((1000, 1000)) equal for 10000 times torch.float 0.606857153121382
torch.ones((1000, 1000)) equal for 10000 times torch.double 1.1388208279386163
torch.ones((2000, 2000)) equal for 10000 times torch.bool 2.0329296849668026
torch.ones((2000, 2000)) equal for 10000 times torch.int 3.534358019940555
torch.ones((2000, 2000)) equal for 10000 times torch.long 8.19841272290796
torch.ones((2000, 2000)) equal for 10000 times torch.bfloat16 6.595649406313896
torch.ones((2000, 2000)) equal for 10000 times torch.float 4.193911510054022
torch.ones((2000, 2000)) equal for 10000 times torch.double 7.931309659034014
torch.rand((1000, 1000)) for 10000 times torch.float 0.8877940969541669
torch.rand((1000, 1000)) for 10000 times torch.double 1.4142901846207678
torch.rand((2000, 2000)) for 10000 times torch.float 4.010025603231043
torch.rand((2000, 2000)) for 10000 times torch.double 8.126411964651197
non_contiguous torch.rand((1000, 1000)) for 10000 times torch.float 0.602473056409508
non_contiguous torch.rand((1000, 1000)) for 10000 times torch.double 0.6784545010887086
non_contiguous torch.rand((2000, 2000)) for 10000 times torch.float 3.0991827426478267
non_contiguous torch.rand((2000, 2000)) for 10000 times torch.double 5.719010795000941
not equal torch.rand((1000, 1000)) for 10000 times torch.float 0.046060710679739714
not equal torch.rand((1000, 1000)) for 10000 times torch.double 0.036034489050507545
not equal torch.rand((2000, 2000)) for 10000 times torch.float 0.03686975734308362
not equal torch.rand((2000, 2000)) for 10000 times torch.double 0.04189508780837059
```

Pull Request resolved: pytorch#33286
Differential Revision: D22211962
Pulled By: glaringlee
fbshipit-source-id: a5c48f328432c1996f28e19bc75cb495fb689f6b