
Add AVX optimizations for pdist#11230

Closed
erikbrinkman wants to merge 3 commits into pytorch:master from erikbrinkman:avx

Conversation

@erikbrinkman
Contributor

Added AVX optimizations for pdist using Vec256. This brings single threaded performance up to speed with scipy, but the current implementation greatly hurts performance without AVX enabled. Is there a way to special case out AVX on dispatch and call the non Vec256 code? Or is the way I used Vec256 completely wrong?

Single threaded comparison to scipy

This is the time to compute the pdist of a 2048 x 2048 float matrix with only one thread for various values of p between torch and scipy. p = 3 is the code path for arbitrary p, and so is much slower than the other values.

p | torch | scipy
-----|-----------|------
0 | 6.27 s ± 393 ms | 7.23 s ± 498 ms
1 | 5.49 s ± 201 ms | 43.4 s ± 1.09 s
2 | 5.74 s ± 474 ms | 53.8 s ± 3.52 s
∞ | 5.59 s ± 292 ms | 47.4 s ± 2.03 s
3 | really slow | gave up
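For reference, the semantics being benchmarked can be sketched as a pure-Python stand-in (an unvectorized illustration, not the C++ kernel in this PR): pdist returns the condensed list of distances between all unordered row pairs, where p = 0 counts differing coordinates and p = ∞ is the Chebyshev distance.

```python
import math

def pdist(x, p):
    """Reference (unvectorized) pdist: condensed distances between all
    unordered row pairs of x."""
    out = []
    n = len(x)
    for i in range(n):
        for j in range(i + 1, n):
            diffs = [abs(a - b) for a, b in zip(x[i], x[j])]
            if p == 0:              # count of differing coordinates
                out.append(sum(d != 0 for d in diffs))
            elif p == math.inf:     # Chebyshev distance
                out.append(max(diffs))
            else:                   # general Minkowski p-norm
                out.append(sum(d ** p for d in diffs) ** (1 / p))
    return out

x = [[0.0, 1.0], [3.0, 5.0], [0.0, 1.0]]
print(pdist(x, 2))  # [5.0, 0.0, 5.0]
print(pdist(x, 0))  # [2, 0, 2]
```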

Result by AVX support

This is the time to compute the distance and gradient of a 2048 x 2048 float matrix with all threads, broken down by AVX support. `before` is the old code, `default` is no AVX support, etc. Interestingly, the AVX optimizations provided a great benefit over the old unoptimized code, but drastically hurt performance when compiled without AVX support. p = 3 is the code path for arbitrary p, and so is much slower than the other values.

Results for p = 0

avx | dist | grad
----|------|-----
before | 514 ms ± 87.5 ms | 191 µs ± 35 µs
default | 3.47 s ± 183 ms | 201 µs ± 24.6 µs
avx | 123 ms ± 18.2 ms | 281 µs ± 130 µs
avx2 | 103 ms ± 11.4 ms | 216 µs ± 74.4 µs

Results for p = 1

avx | dist | grad
----|------|-----
before | 426 ms ± 35 ms | 6.21 s ± 187 ms
default | 2.6 s ± 123 ms | 5.62 s ± 273 ms
avx | 104 ms ± 6.37 ms | 833 ms ± 44.3 ms
avx2 | 106 ms ± 3.59 ms | 924 ms ± 86.2 ms

Results for p = 2

avx | dist | grad
----|------|-----
before | 425 ms ± 45.4 ms | 6.31 s ± 125 ms
default | 3.04 s ± 187 ms | 3.55 s ± 242 ms
avx | 110 ms ± 3.66 ms | 896 ms ± 21.8 ms
avx2 | 113 ms ± 4.68 ms | 934 ms ± 25.2 ms

Results for p = ∞

avx | dist | grad
----|------|-----
before | 501 ms ± 39.5 ms | 6.64 s ± 321 ms
default | 2.15 s ± 92.9 ms | 8.43 s ± 355 ms
avx | 104 ms ± 5.52 ms | 835 ms ± 36.7 ms
avx2 | 100 ms ± 3.41 ms | 864 ms ± 67 ms

Results for p = 3

avx | dist | grad
----|------|-----
before | 22.6 s ± 413 ms | 11.1 s ± 242 ms
default | 24.9 s ± 1 s | 11.2 s ± 293 ms
avx | 2.69 s ± 148 ms | 5.63 s ± 88.4 ms
avx2 | 2.48 s ± 31.8 ms | 5.61 s ± 114 ms

@erikbrinkman
Contributor Author

@pytorchbot retest this please


@cpuhrsch
Contributor

cpuhrsch commented Sep 5, 2018

What's the command you used to profile this? And also the script.


@erikbrinkman
Contributor Author

erikbrinkman commented Sep 5, 2018

@cpuhrsch In general I've just been using IPython's `%timeit` to run the command at least 10 times and report the timing.

The script for the comparison to scipy looks something like:

import numpy as np
import torch
import scipy.spatial.distance as spd

xt = torch.randn(2048, 2048)
xn = xt.numpy()

for p in [0, 1, 2, 3, float('inf')]:
    if p == 0:
        spdist = lambda x: spd.pdist(x, 'hamming') * x.shape[1]
    elif p == float('inf'):
        spdist = lambda x: spd.pdist(x, lambda row, col: np.abs(row - col).max())
    else:
        spdist = lambda x: spd.pdist(x, 'minkowski', p=p)

    print(p)
    %timeit torch.pdist(xt, p)
    %timeit spdist(xn)
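Outside IPython, a comparable measurement can be made with the stdlib `timeit` module; this sketch times a placeholder workload in place of `torch.pdist` so it is self-contained:

```python
import timeit

def workload():
    # placeholder standing in for torch.pdist(xt, p)
    return sum(i * i for i in range(1000))

# best of 5 repeats of 100 calls each, similar to %timeit's reporting
per_call = min(timeit.repeat(workload, number=100, repeat=5)) / 100
print(f"{per_call * 1e6:.1f} µs per call")
```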

@cpuhrsch
Contributor

cpuhrsch commented Sep 5, 2018

@erikbrinkman - In general it should be worth running this on a few different magnitudes of tensors (so 10^3/4/5 elements) to see how it scales.
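For scaling tests like this it helps to keep the output size in mind: pdist on an (n, m) matrix emits one distance per unordered row pair, so cost grows quadratically in the number of rows and linearly in the number of columns. A quick sanity check:

```python
def pdist_output_size(n):
    # number of unordered row pairs, i.e. length of the condensed output
    return n * (n - 1) // 2

for n in (10**3, 10**4, 10**5):
    print(n, pdist_output_size(n))
```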

Contributor

@ezyang ezyang left a comment


Looks pretty solid.

@ezyang
Contributor

ezyang commented Sep 5, 2018

Contributor

@cpuhrsch cpuhrsch left a comment


Still need to resolve the slowdown on DEFAULT capability and resolve some of the comments. We're talking offline as well. I'll approve on the author's request; however, I'd like the changes to vec256 to be as specific to the required usage sites as possible (so, don't change functions you're not adding), just in case it causes slowdowns, since we're not looking at those here and it's out of scope.

@erikbrinkman
Contributor Author

@ezyang Unfortunately, with regards to pairwise distance, this uses a completely different loop / function call. There are ways they could be combined, but it's non-obvious to someone with my knowledge of the code base. Also, for the purposes of this function it makes sense to mandate contiguous memory, since each row is read (n - 1) times, but for pairwise distance rows are only read once, need to be broadcast, etc.

Also, pairwise_distance calls norm, which I imagine is already vectorized.

@erikbrinkman erikbrinkman force-pushed the avx branch 2 times, most recently from 97f2f76 to 4bb3c3b Compare September 5, 2018 22:15
@erikbrinkman
Contributor Author

@cpuhrsch The default dispatch is still slower than it was before this change. The cause isn't clear, other than it seems the compiler is not eliding all of the Vec256 operations. If the answer is that lack of AVX is an uncommon use case and we don't care about performance there, then this seems ready to push. If the answer is that we still want this to be reasonably fast, then there are two ways to fix it: having the dispatcher fall back to the old code, or introducing tricks to get the compiler to handle the vec256_base functions appropriately. The latter seems far outside the scope of this PR.
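The first option (falling back to the old scalar code when AVX is absent) could look roughly like this hypothetical sketch, with a boolean capability flag standing in for PyTorch's actual dispatch machinery:

```python
def scalar_l1(a, b):
    # plain scalar kernel, analogous to the pre-PR code path
    return sum(abs(x - y) for x, y in zip(a, b))

def vec256_l1(a, b):
    # stand-in for the Vec256/AVX kernel; same result, faster in C++
    return sum(abs(x - y) for x, y in zip(a, b))

def dispatch(has_avx):
    # choose the implementation once, at dispatch time
    return vec256_l1 if has_avx else scalar_l1

dist = dispatch(has_avx=False)
print(dist([1.0, 2.0, 3.0], [2.0, 2.0, 1.0]))  # 3.0
```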

In response to the request for timing at various sizes: single threaded, usually an average of 7 runs. These two graphs fix the other dimension at 2048. Note that when changing the number of rows, the computation grows as n^2. The width of each bar corresponds to the uncertainty. The scaling in each case seems appropriate.

[Graphs: runtime vs. num_cols and runtime vs. num_rows]

@erikbrinkman erikbrinkman force-pushed the avx branch 2 times, most recently from 55bdf58 to 7a07d79 Compare September 6, 2018 06:04
Contributor

@cpuhrsch cpuhrsch left a comment


Looks good! I'll look at the DEFAULT branch in another PR.

@erikbrinkman erikbrinkman force-pushed the avx branch 4 times, most recently from 3a73260 to 260c91f Compare September 7, 2018 18:43
@erikbrinkman
Contributor Author

@pytorchbot retest this please

Contributor

@facebook-github-bot facebook-github-bot left a comment


erikbrinkman is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Sep 9, 2018
Summary: Added AVX optimizations for pdist using Vec256. (Benchmark tables identical to the PR description above.)
Pull Request resolved: pytorch/pytorch#11230

Differential Revision: D9735503

Pulled By: erikbrinkman

fbshipit-source-id: a9da619249e4ca2625b39ca1ca7f5543c3086bfb
PenghuiCheng pushed a commit to PenghuiCheng/pytorch that referenced this pull request Sep 11, 2018
Summary: Added AVX optimizations for pdist using Vec256. (Benchmark tables identical to the PR description above.)
Pull Request resolved: pytorch#11230

Differential Revision: D9735503

Pulled By: erikbrinkman

fbshipit-source-id: a9da619249e4ca2625b39ca1ca7f5543c3086bfb
@ezyang ezyang added the merged label Jun 26, 2019
