Migrate sin and sin_ from the TH to Aten (CUDA)#28237

Closed
xuhdev wants to merge 7 commits into gh/xuhdev/43/base from gh/xuhdev/43/head

Conversation

@xuhdev xuhdev commented Oct 17, 2019

Stack from ghstack:

Benchmark (RHEL 7, gcc 8.3.1, P1000):

```python
import timeit

for n, t in [(10_000, 20000),
             (100_000, 20000)]:
    for dtype in ('torch.half', 'torch.float', 'torch.double'):
        print(f'torch.sin(a) a.numel() == {n} for {t} times {dtype}')
        print(timeit.timeit('torch.sin(a); torch.cuda.synchronize()', setup=f'import torch; a=torch.arange({n}, dtype={dtype}, device="cuda")', number=t))
```

Before:

```
torch.sin(a) a.numel() == 10000 for 20000 times torch.half
0.4649172620011086
torch.sin(a) a.numel() == 10000 for 20000 times torch.float
0.4616892600006395
torch.sin(a) a.numel() == 10000 for 20000 times torch.double
0.5166665920005471
torch.sin(a) a.numel() == 100000 for 20000 times torch.half
0.5376560490003612
torch.sin(a) a.numel() == 100000 for 20000 times torch.float
0.6207812359989475
torch.sin(a) a.numel() == 100000 for 20000 times torch.double
1.873208982999131
```

After:

```
torch.sin(a) a.numel() == 10000 for 20000 times torch.half
0.4796977340010926
torch.sin(a) a.numel() == 10000 for 20000 times torch.float
0.48329569199995603
torch.sin(a) a.numel() == 10000 for 20000 times torch.double
0.5380683220009814
torch.sin(a) a.numel() == 100000 for 20000 times torch.half
0.5299932739999349
torch.sin(a) a.numel() == 100000 for 20000 times torch.float
0.6144487999990815
torch.sin(a) a.numel() == 100000 for 20000 times torch.double
1.8838113630008593
```

Close #24627

Differential Revision: D18089072
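Comparing the two tables shows the migration is roughly performance-neutral. As an illustration only (the before/after numbers are copied from the tables above), the relative deltas can be computed like this:

```python
# Before/after timings (seconds) copied from the benchmark tables:
# key = (numel, dtype), value = (before, after)
pairs = {
    ('10000', 'half'):    (0.4649172620011086, 0.4796977340010926),
    ('10000', 'float'):   (0.4616892600006395, 0.48329569199995603),
    ('10000', 'double'):  (0.5166665920005471, 0.5380683220009814),
    ('100000', 'half'):   (0.5376560490003612, 0.5299932739999349),
    ('100000', 'float'):  (0.6207812359989475, 0.6144487999990815),
    ('100000', 'double'): (1.873208982999131, 1.8838113630008593),
}

# Relative change of each timing: positive means the ATen port is slower.
deltas = {k: (after - before) / before for k, (before, after) in pairs.items()}
for k, d in sorted(deltas.items()):
    print(k, f'{d:+.2%}')
```

Every delta is within about ±5%, and the 100k-element cases are flat or slightly faster, which is consistent with the "no regression" conclusion below.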

xuhdev added a commit that referenced this pull request Oct 17, 2019
Close #24627

ghstack-source-id: c6d9d9f
Pull Request resolved: #28237
xuhdev added a commit that referenced this pull request Oct 18, 2019
Close #24627

ghstack-source-id: 0632ddf
Pull Request resolved: #28237
ifedan commented Oct 18, 2019

Please provide performance metrics for this function before and after your change.

xuhdev commented Oct 21, 2019

Benchmark added. The timing seems pretty unstable, though; I got numbers that differ considerably between runs. Also see the benchmarks on the previous two PRs in this stack. Still, this should be sufficient to show that there is likely no performance regression.

ifedan commented Oct 22, 2019

> Benchmark added. The timing seems pretty unstable, though; I got numbers that differ considerably between runs. Also see the benchmarks on the previous two PRs in this stack. Still, this should be sufficient to show that there is likely no performance regression.

Did you use OMP_NUM_THREADS=1 and MKL_NUM_THREADS=1?

xuhdev commented Oct 22, 2019

> Did you use OMP_NUM_THREADS=1 and MKL_NUM_THREADS=1?

I didn't enable OpenMP (USE_OPENMP=0) and didn't have MKL installed.

xuhdev commented Oct 22, 2019

@ifedan I ended up benchmarking on a different machine over which I have more control of the hardware (e.g., turning off turbo, warming up the GPU, and making sure it's completely unused by others). The results are pretty close and stable now. Please see my update in the benchmark results.
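The stabilization steps described above (warming up before measuring, taking several repeats) can be sketched in a CPU-only harness. This is an illustration rather than the benchmark actually used here: `math.sin` stands in for the CUDA kernel so the sketch runs anywhere, and the `bench` helper is hypothetical.

```python
import timeit

def bench(stmt, setup, number, repeats=5, warmup=1):
    """Time `stmt` after a warm-up pass and report the minimum of several
    repeats, which is less sensitive to run-to-run noise than one measurement."""
    timer = timeit.Timer(stmt, setup=setup)
    for _ in range(warmup):
        timer.timeit(number=number)  # warm-up: populate caches, settle clocks
    return min(timer.repeat(repeat=repeats, number=number))

best = bench('math.sin(0.5)', 'import math', number=10_000)
print(f'best of 5: {best:.6f} s')
```

For a GPU benchmark the same shape applies, with a `torch.cuda.synchronize()` after the warm-up and after each timed run so kernel launches aren't measured asynchronously.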

zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 30, 2019
Summary:
Pull Request resolved: pytorch/pytorch#28237

Test Plan: Imported from OSS

Differential Revision: D18089072

Pulled By: VitalyFedyunin

fbshipit-source-id: 4824804960309fe7fdb16073d021388704986993
@VitalyFedyunin merged this pull request in d0bd8a3.

@facebook-github-bot facebook-github-bot deleted the gh/xuhdev/43/head branch November 3, 2019 15:15
