Conversation
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Force-pushed from dbf653d to 50677d1.
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@pytorchbot ciflow rerun -l ciflow/cuda
Summary: This adds an apex-inspired fast layer-norm forward kernel to PyTorch (it is a significant rewrite, though). It is much faster than the current implementation: for a typical transformer size (32*196, 1024), forward time goes down from ~180 µs to ~49 µs on Volta. Compared to apex, it also produces bitwise-identical results for fp16 inputs and for float inputs exactly representable in fp16. It produces slightly different results than the current implementation, though, because Welford summation is implemented differently. It is slower than LightSeq (~37 µs), but LightSeq uses an inaccurate variance approximation and does not guarantee float/fp16 bitwise accuracy.
Pull Request resolved: #67977
Reviewed By: mruberry
Differential Revision: D32285331
Pulled By: ngimel
fbshipit-source-id: a8b876a9cf3133daacfe0ce3a37e3ad566f4b6a8
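For context, the "Welford summation" mentioned above is the classic single-pass running mean/variance update, sketched below in plain Python (illustrative only, not the kernel's actual code; the CUDA kernel combines such partial statistics across threads, and a different combination order is enough to change results in the last bits):

```
import torch

def welford(xs):
    # Single-pass running mean/variance (Welford's update).
    mean, m2, n = 0.0, 0.0, 0
    for x in xs:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)  # note: uses the already-updated mean
    return mean, m2 / n           # population variance, as layer norm uses

xs = torch.randn(1024, dtype=torch.float64).tolist()
mean, var = welford(xs)
ref = torch.tensor(xs)
print(abs(mean - ref.mean().item()), abs(var - ref.var(unbiased=False).item()))
```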
Summary:
Benchmarks (times in µs; each row is batch size, feature size)
At this PR:
```
[------------------------------------------------------ ln ------------------------------------------------------]
| fwd, torch.float32 | fwdbwd, torch.float32 | fwd, torch.float16 | fwdbwd, torch.float16
1 threads: -------------------------------------------------------------------------------------------------------
200, 256 | 17.5 | 106.6 | 18.1 | 94.7
1000, 256 | 18.7 | 116.6 | 18.7 | 110.7
6000, 256 | 28.1 | 111.8 | 19.4 | 92.3
6272, 256 | 29.3 | 108.5 | 20.1 | 92.7
200, 512 | 19.3 | 83.8 | 19.1 | 116.3
1000, 512 | 17.9 | 88.0 | 17.9 | 93.0
6000, 512 | 36.9 | 141.2 | 27.4 | 103.3
6272, 512 | 38.2 | 146.5 | 28.1 | 107.9
200, 1024 | 18.1 | 89.5 | 21.1 | 102.7
1000, 1024 | 17.9 | 88.7 | 18.5 | 92.5
6000, 1024 | 77.6 | 277.5 | 40.3 | 148.5
6272, 1024 | 80.7 | 288.1 | 42.0 | 154.0
200, 1536 | 17.9 | 117.3 | 18.1 | 88.1
1000, 1536 | 22.9 | 92.0 | 19.4 | 89.0
6000, 1536 | 123.4 | 436.3 | 61.7 | 228.5
6272, 1536 | 129.1 | 457.3 | 64.3 | 238.5
200, 2048 | 18.0 | 90.5 | 19.1 | 101.6
1000, 2048 | 31.1 | 109.8 | 25.3 | 107.9
6000, 2048 | 174.5 | 589.8 | 87.1 | 310.5
6272, 2048 | 182.2 | 617.0 | 91.2 | 316.7
200, 3072 | 19.8 | 96.4 | 19.4 | 89.3
1000, 3072 | 48.1 | 168.7 | 23.5 | 100.9
6000, 3072 | 267.1 | 930.0 | 134.8 | 519.2
6272, 3072 | 278.2 | 971.2 | 140.7 | 540.2
```
Pre-#67977:
```
[------------------------------------------------------- ln -------------------------------------------------------]
| fwd, torch.float32 | fwdbwd, torch.float32 | fwd, torch.float16 | fwdbwd, torch.float16
1 threads: ---------------------------------------------------------------------------------------------------------
200, 256 | 20.9 | 92.6 | 21.3 | 110.1
1000, 256 | 20.3 | 91.8 | 28.1 | 115.6
6000, 256 | 93.0 | 310.7 | 86.3 | 299.8
6272, 256 | 97.3 | 323.5 | 90.0 | 314.1
200, 512 | 20.9 | 110.2 | 21.1 | 95.0
1000, 512 | 24.0 | 102.8 | 22.2 | 95.9
6000, 512 | 121.7 | 367.2 | 105.6 | 337.4
6272, 512 | 127.0 | 382.3 | 111.3 | 352.0
200, 1024 | 21.0 | 131.8 | 20.4 | 93.3
1000, 1024 | 35.5 | 108.7 | 27.7 | 99.4
6000, 1024 | 170.4 | 495.5 | 137.7 | 411.4
6272, 1024 | 177.5 | 517.6 | 143.6 | 428.6
200, 1536 | 21.9 | 97.6 | 20.8 | 92.7
1000, 1536 | 44.3 | 129.7 | 33.9 | 100.1
6000, 1536 | 215.8 | 619.2 | 167.2 | 480.9
6272, 1536 | 225.0 | 646.9 | 174.8 | 505.9
200, 2048 | 21.8 | 100.8 | 20.7 | 96.7
1000, 2048 | 53.7 | 152.4 | 41.4 | 118.3
6000, 2048 | 267.0 | 753.6 | 220.4 | 571.5
6272, 2048 | 278.6 | 785.8 | 211.4 | 589.2
200, 3072 | 20.9 | 103.7 | 21.9 | 104.6
1000, 3072 | 71.4 | 201.1 | 53.1 | 148.3
6000, 3072 | 365.7 | 1040.3 | 262.0 | 731.5
6272, 3072 | 382.0 | 1084.4 | 273.3 | 766.3
```
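For example, for the (6272, 1024) shape called out in the summary, the fp32 forward drops from 177.5 µs to 80.7 µs (~2.2x) and the fp16 forward from 143.6 µs to 42.0 µs (~3.4x).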
Benchmarking script:
```
import torch
from torch.utils.benchmark import Timer, Compare

results = []
for dtype in (torch.float, torch.half):
    for fs in (256, 512, 1024, 1536, 2048, 3072):
        for bs in (200, 1000, 6000, 196 * 32):
            ln = torch.nn.LayerNorm((fs,), device="cuda", dtype=dtype)
            X = torch.randn(bs, fs, device="cuda", dtype=dtype, requires_grad=True)
            gO = torch.rand_like(X)
            # Time the forward pass alone, and forward + backward.
            stmtfwd = "ln(X)"
            stmtfwdbwd = "X.grad=None; ln.zero_grad(set_to_none=True); out = ln(X); out.backward(gO)"
            tfwd = Timer(stmt=stmtfwd, label="ln", sub_label=f"{bs:5}, {fs:5}",
                         description=f"fwd, {dtype}", globals=globals())
            tfwdbwd = Timer(stmt=stmtfwdbwd, label="ln", sub_label=f"{bs:5}, {fs:5}",
                            description=f"fwdbwd, {dtype}", globals=globals())
            for t in (tfwd, tfwdbwd):
                results.append(t.blocked_autorange())
        print(fs, end='\r')  # progress indicator

c = Compare(results)
c.print()
```
Pull Request resolved: #68238
Reviewed By: mruberry
Differential Revision: D32469450
Pulled By: ngimel
fbshipit-source-id: 08fe755c156d3d5c366c966cb808bf0f3e74c050
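One way to probe the bitwise-accuracy claim from the #67977 summary (a sketch, assuming a CUDA device; the shape is an arbitrary choice, and the claim is read here as "the fp16 kernel's output equals the fp32 kernel's output rounded to fp16"):

```
import torch

fs = 1024
ln16 = torch.nn.LayerNorm((fs,), device="cuda", dtype=torch.float16)
ln32 = torch.nn.LayerNorm((fs,), device="cuda", dtype=torch.float32)
# Default affine params (weight=1, bias=0) are exact in both dtypes.
x16 = torch.randn(6272, fs, device="cuda").half()  # values exactly representable in fp16
x32 = x16.float()                                  # the same values as float32
with torch.no_grad():
    y16 = ln16(x16)
    y32 = ln32(x32)
# Expect True if the fp16 and fp32 paths agree bit for bit after rounding.
print(torch.equal(y16, y32.half()))
```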