Add AVX optimizations for pdist#11230
Conversation
@pytorchbot retest this please
What's the command you used to profile this? And also the script.
@cpuhrsch In general I've just been using IPython's `%timeit` magic. The script for the comparison to scipy looks something like:

```python
import torch
import numpy as np
import scipy.spatial.distance as spd

xt = torch.randn(2048, 2048)
xn = xt.numpy()
for p in [0, 1, 2, 3, float('inf')]:
    if p == 0:
        spdist = lambda x: spd.pdist(x, 'hamming') * x.shape[1]
    elif p == float('inf'):
        spdist = lambda x: spd.pdist(x, lambda row, col: np.abs(row - col).max())
    else:
        spdist = lambda x: spd.pdist(x, 'minkowski', p=p)
    print(p)
    %timeit torch.pdist(xt, p)
    %timeit spdist(xn)
```
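For reference, the p-handling in the script above (p = 0 as hamming times the row width, p = ∞ as the max of absolute differences) matches the usual definitions. A minimal pure-Python sketch of the condensed pdist semantics (an illustration, not the actual torch or scipy implementation):

```python
import math

def pdist_ref(rows, p=2.0):
    """Reference pdist: distance for each unordered pair of rows,
    in condensed order (0,1), (0,2), ..., (n-2, n-1)."""
    n = len(rows)
    out = []
    for i in range(n):
        for j in range(i + 1, n):
            diffs = [abs(a - b) for a, b in zip(rows[i], rows[j])]
            if p == 0:
                d = float(sum(v != 0 for v in diffs))  # hamming * width
            elif p == math.inf:
                d = max(diffs)  # chebyshev
            else:
                d = sum(v ** p for v in diffs) ** (1.0 / p)  # minkowski
            out.append(d)
    return out

# 3-4-5 triangle: one pair of rows, euclidean distance 5
print(pdist_ref([[0.0, 0.0], [3.0, 4.0]], p=2.0))  # → [5.0]
```

The output has n(n-1)/2 entries, which is why the benchmarks below grow quadratically in the number of rows.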
@erikbrinkman - In general it should be worth running this on a few different magnitudes of tensors (so 10^3/4/5 elements) to see how it scales.
Can https://pytorch.org/docs/stable/nn.html?highlight=pairwise#torch.nn.functional.pairwise_distance benefit from this vectorized code now?
Still need to resolve the slowdown on DEFAULT capability and resolve some of the comments. We're talking offline as well. Will approve on the author's request; however, I'd like to see the changes to vec256 be as specific to the required usage sites as possible (so, don't change functions you're not adding), just in case it causes slowdowns, since we're not looking at those here and it's out of scope.
@ezyang Unfortunately, with regards to pairwise distance, this uses a completely different loop / function call. There are ways they could be combined, but it is nonobvious to someone with my knowledge of the code base. Also, for the purposes of this function it makes sense to mandate contiguous memory, since each row is read (n - 1) times, but for pairwise distance the rows are only read once, need to be broadcast, etc. Also, pairwise_distance calls
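To make the difference between the two loops concrete, here is a hedged pure-Python sketch of the two semantics (ignoring `pairwise_distance`'s eps term and broadcasting; function names are illustrative, not the ATen ones):

```python
def rowwise_distance(x1, x2, p=2.0):
    # Corresponding rows only: n rows in -> n distances out
    # (the pairwise_distance-style access pattern: each row read once).
    return [sum(abs(a - b) ** p for a, b in zip(r1, r2)) ** (1.0 / p)
            for r1, r2 in zip(x1, x2)]

def all_pairs_distance(x, p=2.0):
    # Every unordered pair of rows: n rows in -> n*(n-1)/2 distances out
    # (the pdist-style access pattern: each row re-read about n - 1 times,
    # which is why contiguous storage pays off here).
    n = len(x)
    return [sum(abs(a - b) ** p for a, b in zip(x[i], x[j])) ** (1.0 / p)
            for i in range(n) for j in range(i + 1, n)]
```

The very different memory-access patterns are why sharing one vectorized kernel between the two is nontrivial.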
Force-pushed from 97f2f76 to 4bb3c3b (compare)
@cpuhrsch The default dispatch is still slower than it was before introducing this change. The cause isn't clear, other than that the compiler does not seem to be eliding all of the Vec256 operations. If the answer is that lack of AVX is an uncommon use case and we don't care about performance there, then this seems ready to push. If the answer is that we still want the default path to be reasonably fast, then there are two ways to fix it: making the default dispatch fall back to the old code, or introducing tricks to get the compiler to handle the vec256_base functions appropriately. The latter seems far outside the scope of this PR.

In response to the request for timing at various sizes: single threaded, usually an average of 7 runs. These two graphs are with the other dimension set to 2048. Note that when changing the number of rows, the computation grows by n^2. The width of each bar corresponds to the uncertainty. The scaling in each case seems appropriate.
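The scaling claim can be sanity-checked with a back-of-the-envelope cost model (a sketch of the expected work, not the actual kernel): each of the n(n-1)/2 row pairs touches all m columns, so doubling the number of rows roughly quadruples the work, while doubling the number of columns only doubles it.

```python
def pdist_work(n, m):
    # Elements touched: one m-wide pass per unordered pair of rows.
    return (n * (n - 1) // 2) * m

base = pdist_work(2048, 2048)
rows_doubled = pdist_work(4096, 2048) / base  # just over 4x
cols_doubled = pdist_work(2048, 4096) / base  # exactly 2x
print(rows_doubled, cols_doubled)
```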
Force-pushed from 55bdf58 to 7a07d79 (compare)
cpuhrsch left a comment:
Looks good! I'll look at the DEFAULT branch in another PR.
Force-pushed from 3a73260 to 260c91f (compare)
@pytorchbot retest this please
facebook-github-bot left a comment:
erikbrinkman is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Added AVX optimizations for pdist using Vec256. This brings single threaded performance up to speed with scipy, but the current implementation greatly hurts performance without AVX enabled. Is there a way to special case out AVX on dispatch and call the non-Vec256 code? Or is the way I used Vec256 completely wrong?

Single threaded comparison to scipy
===================================

This is the time to compute the pdist of a 2048 x 2048 float matrix with only one thread for various values of p between torch and scipy. p = 3 is the code path for arbitrary p, and so is much slower than the other values.

| p | torch | scipy |
|---|-------|-------|
| 0 | 6.27 s ± 393 ms | 7.23 s ± 498 ms |
| 1 | 5.49 s ± 201 ms | 43.4 s ± 1.09 s |
| 2 | 5.74 s ± 474 ms | 53.8 s ± 3.52 s |
| ∞ | 5.59 s ± 292 ms | 47.4 s ± 2.03 s |
| 3 | really slow | gave up |

Result by AVX support
=====================

This is the time to compute the distance and gradient of a 2048 x 2048 float matrix with all threads, broken down by AVX support. `before` is the old code, `default` is no AVX support, etc. Interestingly the AVX optimizations provided a great benefit over the old unoptimized code, but drastically hurt performance when compiled without AVX optimizations. p = 3 is the code path for arbitrary p, and so is much slower than the other values.

Results for p = 0
-----------------

| avx | dist | grad |
|-----|------|------|
| before | 514 ms ± 87.5 ms | 191 µs ± 35 µs |
| default | 3.47 s ± 183 ms | 201 µs ± 24.6 µs |
| avx | 123 ms ± 18.2 ms | 281 µs ± 130 µs |
| avx2 | 103 ms ± 11.4 ms | 216 µs ± 74.4 µs |

Results for p = 1
-----------------

| avx | dist | grad |
|-----|------|------|
| before | 426 ms ± 35 ms | 6.21 s ± 187 ms |
| default | 2.6 s ± 123 ms | 5.62 s ± 273 ms |
| avx | 104 ms ± 6.37 ms | 833 ms ± 44.3 ms |
| avx2 | 106 ms ± 3.59 ms | 924 ms ± 86.2 ms |

Results for p = 2
-----------------

| avx | dist | grad |
|-----|------|------|
| before | 425 ms ± 45.4 ms | 6.31 s ± 125 ms |
| default | 3.04 s ± 187 ms | 3.55 s ± 242 ms |
| avx | 110 ms ± 3.66 ms | 896 ms ± 21.8 ms |
| avx2 | 113 ms ± 4.68 ms | 934 ms ± 25.2 ms |

Results for p = ∞
-----------------

| avx | dist | grad |
|-----|------|------|
| before | 501 ms ± 39.5 ms | 6.64 s ± 321 ms |
| default | 2.15 s ± 92.9 ms | 8.43 s ± 355 ms |
| avx | 104 ms ± 5.52 ms | 835 ms ± 36.7 ms |
| avx2 | 100 ms ± 3.41 ms | 864 ms ± 67 ms |

Results for p = 3
-----------------

| avx | dist | grad |
|-----|------|------|
| before | 22.6 s ± 413 ms | 11.1 s ± 242 ms |
| default | 24.9 s ± 1 s | 11.2 s ± 293 ms |
| avx | 2.69 s ± 148 ms | 5.63 s ± 88.4 ms |
| avx2 | 2.48 s ± 31.8 ms | 5.61 s ± 114 ms |

Pull Request resolved: pytorch/pytorch#11230
Differential Revision: D9735503
Pulled By: erikbrinkman
fbshipit-source-id: a9da619249e4ca2625b39ca1ca7f5543c3086bfb