
Migrate asin and asin_ from TH to ATen (CUDA) #28482

Closed

xuhdev wants to merge 2 commits into gh/xuhdev/44/base from gh/xuhdev/44/head

Conversation

@xuhdev (Collaborator) commented on Oct 23, 2019

Stack from ghstack:

Benchmark (Debian Buster, CUDA 9.2, Quadro P400, turbo off, Release, gcc 7.4):

```python
import timeit

for n, t in [(10_000, 20000),
             (100_000, 20000)]:
    for dtype in ('torch.half', 'torch.float', 'torch.double'):
        print(f'torch.asin(a) a.numel() == {n} for {t} times {dtype}')
        print(timeit.timeit('torch.asin(a); torch.cuda.synchronize()',
                            setup=f'import torch; a=torch.arange({n}, dtype={dtype}, device="cuda")',
                            number=t))
```

Before:

```
torch.asin(a) a.numel() == 10000 for 20000 times torch.half
0.37638233399957244
torch.asin(a) a.numel() == 10000 for 20000 times torch.float
0.377967665000142
torch.asin(a) a.numel() == 10000 for 20000 times torch.double
0.9213513669992608
torch.asin(a) a.numel() == 100000 for 20000 times torch.half
0.8473557480001546
torch.asin(a) a.numel() == 100000 for 20000 times torch.float
1.0340258509995692
torch.asin(a) a.numel() == 100000 for 20000 times torch.double
6.045749833000627
```

After:

```
torch.asin(a) a.numel() == 10000 for 20000 times torch.half
0.37552232399957575
torch.asin(a) a.numel() == 10000 for 20000 times torch.float
0.3796842550000292
torch.asin(a) a.numel() == 10000 for 20000 times torch.double
0.9202746920000209
torch.asin(a) a.numel() == 100000 for 20000 times torch.half
0.8276844379997783
torch.asin(a) a.numel() == 100000 for 20000 times torch.float
1.014272968000114
torch.asin(a) a.numel() == 100000 for 20000 times torch.double
6.039313418999882
```
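To make the comparison above easier to read, the relative change between each before/after pair can be computed with a short script. The timings are copied verbatim from the output above; the script itself is illustrative and not part of this PR:

```python
# Before/after timings quoted above (seconds for 20000 calls on a Quadro P400).
# Keys are (numel, dtype); values are copied from the benchmark output.
before = {
    (10_000, 'half'): 0.37638233399957244,
    (10_000, 'float'): 0.377967665000142,
    (10_000, 'double'): 0.9213513669992608,
    (100_000, 'half'): 0.8473557480001546,
    (100_000, 'float'): 1.0340258509995692,
    (100_000, 'double'): 6.045749833000627,
}
after = {
    (10_000, 'half'): 0.37552232399957575,
    (10_000, 'float'): 0.3796842550000292,
    (10_000, 'double'): 0.9202746920000209,
    (100_000, 'half'): 0.8276844379997783,
    (100_000, 'float'): 1.014272968000114,
    (100_000, 'double'): 6.039313418999882,
}
for key in sorted(before):
    # Negative delta means the migrated (ATen) kernel is faster.
    delta = (after[key] - before[key]) / before[key]
    print(f'{key}: {delta:+.2%}')
```

Every delta is within about 2.5%, i.e. the migration is performance-neutral on this GPU.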

Close #24537

Differential Revision: D18089074

xuhdev added a commit that referenced this pull request Oct 23, 2019
Benchmark (RHEL 7.3, Release, P1000, gcc 8.3):

```python
import timeit

for n, t in [(10_000, 20000),
             (100_000, 20000)]:
    for dtype in ('torch.half', 'torch.float', 'torch.double'):
        print(f'torch.asin(a) a.numel() == {n} for {t} times {dtype}')
        print(timeit.timeit(f'torch.asin(a); torch.cuda.synchronize()', setup=f'import torch; a=torch.arange({n}, dtype={dtype}, device="cuda")', number=t))
```

Before:

```
torch.asin(a) a.numel() == 10000 for 20000 times torch.half
0.475854377997166
torch.asin(a) a.numel() == 10000 for 20000 times torch.float
0.4772826389998954
torch.asin(a) a.numel() == 10000 for 20000 times torch.double
0.6297428649995709
torch.asin(a) a.numel() == 100000 for 20000 times torch.half
0.5475849750000634
torch.asin(a) a.numel() == 100000 for 20000 times torch.float
0.6156488769993302
torch.asin(a) a.numel() == 100000 for 20000 times torch.double
2.728912709000724
```

After:

```
torch.asin(a) a.numel() == 10000 for 20000 times torch.half
0.5107104659982724
torch.asin(a) a.numel() == 10000 for 20000 times torch.float
0.509122366001975
torch.asin(a) a.numel() == 10000 for 20000 times torch.double
0.6929216960015765
torch.asin(a) a.numel() == 100000 for 20000 times torch.half
0.5914848840002378
torch.asin(a) a.numel() == 100000 for 20000 times torch.float
0.6518679289983993
torch.asin(a) a.numel() == 100000 for 20000 times torch.double
2.916458261999651
```

Close #24537

ghstack-source-id: 36ddd47
Pull Request resolved: #28482
@xuhdev xuhdev requested a review from VitalyFedyunin October 23, 2019 01:05
@vishwakftw (Contributor) commented:
It seems like there's a regression; any idea why?

@xuhdev (Collaborator, Author) commented on Oct 23, 2019

> It seems like there's a regression; any idea why?

It's likely some sort of hardware instability. It's pretty hard to get consistent CUDA benchmarks (as it turns out, even the power source can affect the outcome). I can try rerunning the benchmark at another time, but the CUDA build is really time-consuming.
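Beyond fixing clocks and power, a common way to reduce that kind of noise is to add warm-up iterations and report the median of several repeats instead of a single total. A minimal sketch of that idea, shown with a CPU stand-in rather than a CUDA kernel (the `bench` helper and its parameters are illustrative, not code from this PR; for CUDA, each timed region would additionally end with `torch.cuda.synchronize()`):

```python
import math
import statistics
import time

def bench(fn, warmup=100, repeats=5, number=2000):
    """Time fn: warm up first so one-time setup costs are excluded,
    then take the median of several repeats so one noisy run
    doesn't skew the result."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        for _ in range(number):
            fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)

# CPU stand-in for the kernel under test.
t = bench(lambda: math.asin(0.5))
print(f'median of 5 repeats: {t:.6f} s')
```

The median is less sensitive than the mean to a single slow repeat caused by a frequency or power transient.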

@xuhdev (Collaborator, Author) commented on Oct 23, 2019

I added more stabilization to the benchmarking environment, and the before/after runs now produce very close results. See the updated benchmark results above.

@vishwakftw (Contributor) commented:

Thanks @xuhdev!

zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 30, 2019
Summary:
Pull Request resolved: pytorch/pytorch#28482


Test Plan: Imported from OSS

Differential Revision: D18089074

Pulled By: VitalyFedyunin

fbshipit-source-id: f27515dd1ee73b6e2391ebcc0004af28bcb82234
@facebook-github-bot (Contributor) commented:
@VitalyFedyunin merged this pull request in a7166ae.

@facebook-github-bot facebook-github-bot deleted the gh/xuhdev/44/head branch November 3, 2019 15:15