
Fast cuda layer norm #67977

Closed
ngimel wants to merge 9 commits into master from ngimel/layer_norm

Conversation

@ngimel
Collaborator

@ngimel ngimel commented Nov 8, 2021

This adds an apex-inspired fast layer norm forward kernel to PyTorch (it is a significant rewrite, though).
It is much faster than the current implementation: for a typical transformer size (32*196, 1024), time goes down from ~180 us to ~49 us on Volta. Compared to apex, it also produces bitwise-identical results between float inputs that are exactly representable in fp16 and the corresponding fp16 inputs. It produces slightly different results than the current implementation, though, because Welford summation is implemented differently.
It is slower than LightSeq (~37 us), but LightSeq uses an inaccurate variance approximation and does not guarantee float/fp16 bitwise accuracy.
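The Welford summation mentioned above is a one-pass, numerically stable way of computing mean and variance; because the updates are accumulated differently than in a naive sum-of-squares reduction, results can differ in the last bits. A minimal Python sketch of the scalar recurrence (for illustration only, not the PR's CUDA implementation):

```python
def welford(xs):
    """Welford's online algorithm: one pass, numerically stable."""
    mean, m2, n = 0.0, 0.0, 0
    for x in xs:
        n += 1
        delta = x - mean
        mean += delta / n          # update the running mean
        m2 += delta * (x - mean)   # accumulate squared deviations
    return mean, m2 / n            # mean and biased variance

mean, var = welford([0.1, 0.2, 0.3, 0.4])
```

The CUDA kernel additionally has to merge per-thread and per-warp partial (mean, m2, n) triples, which is where implementations can legitimately diverge from one another bit-for-bit.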

@pytorch-probot

pytorch-probot Bot commented Nov 8, 2021

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/50677d1885b26ea97b1e663930a033d4c2fce788/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default,ciflow/cuda

Workflows Labels (bold enabled) Status
Triggered Workflows
libtorch-linux-xenial-cuda10.2-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux ✅ triggered
libtorch-linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux ✅ triggered
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow ✅ triggered
linux-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/xla ✅ triggered
linux-vulkan-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-dynamic ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3.6-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers ✅ triggered
linux-xenial-py3.6-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx ✅ triggered
linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled ✅ triggered
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck ✅ triggered
periodic-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled ✅ triggered
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/win ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped
docker-builds ciflow/all 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-full-jit ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-x86-64-full-jit ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
linux-xenial-py3-clang5-mobile-code-analysis ciflow/all, ciflow/linux, ciflow/mobile 🚫 skipped
macos-10-15-py3-arm64 ciflow/all, ciflow/macos 🚫 skipped
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos 🚫 skipped
macos-10-15-py3-x86-64 ciflow/all, ciflow/macos 🚫 skipped
parallelnative-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped

You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and trigger the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.

@facebook-github-bot
Contributor

facebook-github-bot commented Nov 8, 2021


💊 CI failures summary and remediations

As of commit 50677d1 (more details on the Dr. CI page):


  • 2/2 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build linux-bionic-cuda10.2-py3.9-gcc7 / test (slow, 1, 1, linux.4xlarge.nvidia.gpu) (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2021-11-10T23:06:42.1879366Z RuntimeError: test_jit_fuser_te failed! Received signal: SIGIOT
2021-11-10T23:06:39.4293972Z Generating XML reports...
2021-11-10T23:06:39.4408150Z Generated XML report: test-reports/python-unittest/test_jit_fuser_te/TEST-TestTEFuser-20211110230015.xml
2021-11-10T23:06:39.4411437Z Generated XML report: test-reports/python-unittest/test_jit_fuser_te/TEST-jit.test_fuser_common.TestFuserCommon-20211110230015.xml
2021-11-10T23:06:39.4921700Z Generated XML report: test-reports/python-unittest/test_jit_fuser_te/TEST-TestNNCOpInfoCUDA-20211110230015.xml
2021-11-10T23:06:39.6696469Z double free or corruption (out)
2021-11-10T23:06:42.1874864Z Traceback (most recent call last):
2021-11-10T23:06:42.1875753Z   File "/var/lib/jenkins/workspace/test/run_test.py", line 1041, in <module>
2021-11-10T23:06:42.1876812Z     main()
2021-11-10T23:06:42.1877478Z   File "/var/lib/jenkins/workspace/test/run_test.py", line 1019, in main
2021-11-10T23:06:42.1878661Z     raise RuntimeError(err_message)
2021-11-10T23:06:42.1879366Z RuntimeError: test_jit_fuser_te failed! Received signal: SIGIOT
2021-11-10T23:06:42.8099634Z 
2021-11-10T23:06:42.8100798Z real	32m15.220s
2021-11-10T23:06:42.8101457Z user	35m14.846s
2021-11-10T23:06:42.8102012Z sys	8m12.207s
2021-11-10T23:06:42.8102721Z + cleanup
2021-11-10T23:06:42.8103383Z + retcode=1
2021-11-10T23:06:42.8103790Z + set +x
2021-11-10T23:06:42.8148412Z ##[error]Process completed with exit code 1.
2021-11-10T23:06:42.8202770Z ##[group]Run # Ensure the working directory gets chowned back to the current user

1 failure not recognized by patterns:

Job Step Action
GitHub Actions win-vs2019-cuda11.3-py3 / test (smoke_tests, 1, 1, windows.8xlarge.nvidia.gpu) Unknown 🔁 rerun

This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@facebook-github-bot
Contributor

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Collaborator

@mruberry mruberry left a comment


Stamped!

@facebook-github-bot
Contributor

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngimel
Collaborator Author

ngimel commented Nov 10, 2021

@pytorchbot ciflow rerun -l ciflow/cuda

@facebook-github-bot
Contributor

@ngimel merged this pull request in 84d3df8.

desertfire pushed a commit that referenced this pull request Nov 15, 2021
Summary:
This adds an apex-inspired fast layer norm forward kernel to PyTorch (it is a significant rewrite, though).
It is much faster than the current implementation: for a typical transformer size (32*196, 1024), time goes down from ~180 us to ~49 us on Volta. Compared to apex, it also produces bitwise-identical results between float inputs that are exactly representable in fp16 and the corresponding fp16 inputs. It produces slightly different results than the current implementation, though, because Welford summation is implemented differently.
It is slower than LightSeq (~37 us), but LightSeq uses an inaccurate variance approximation and does not guarantee float/fp16 bitwise accuracy.

Pull Request resolved: #67977

Reviewed By: mruberry

Differential Revision: D32285331

Pulled By: ngimel

fbshipit-source-id: a8b876a9cf3133daacfe0ce3a37e3ad566f4b6a8
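For reference, the forward computation this kernel accelerates can be sketched per row in plain Python (a toy version with made-up names, following PyTorch's biased-variance convention for layer norm):

```python
import math

def layer_norm_fwd(x, weight, bias, eps=1e-5):
    # Normalize one row over its feature dimension, then apply the
    # elementwise affine parameters.
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n   # biased variance, as in PyTorch
    rstd = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * rstd * w + b
            for v, w, b in zip(x, weight, bias)]

out = layer_norm_fwd([1.0, 2.0, 3.0, 4.0], [1.0] * 4, [0.0] * 4)
```

The kernel's job is to fuse the mean/variance reduction and the normalization into one pass over each row; the per-row mean and rstd are also what the backward pass consumes.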
facebook-github-bot pushed a commit that referenced this pull request Nov 17, 2021
Summary:
Benchmarks
At this PR
```
[------------------------------------------------------ ln ------------------------------------------------------]
                  |  fwd, torch.float32  |  fwdbwd, torch.float32  |  fwd, torch.float16  |  fwdbwd, torch.float16
1 threads: -------------------------------------------------------------------------------------------------------
      200, 256    |         17.5         |          106.6          |         18.1         |           94.7
      1000, 256   |         18.7         |          116.6          |         18.7         |          110.7
      6000, 256   |         28.1         |          111.8          |         19.4         |           92.3
      6272, 256   |         29.3         |          108.5          |         20.1         |           92.7
      200, 512    |         19.3         |           83.8          |         19.1         |          116.3
      1000, 512   |         17.9         |           88.0          |         17.9         |           93.0
      6000, 512   |         36.9         |          141.2          |         27.4         |          103.3
      6272, 512   |         38.2         |          146.5          |         28.1         |          107.9
      200, 1024   |         18.1         |           89.5          |         21.1         |          102.7
      1000, 1024  |         17.9         |           88.7          |         18.5         |           92.5
      6000, 1024  |         77.6         |          277.5          |         40.3         |          148.5
      6272, 1024  |         80.7         |          288.1          |         42.0         |          154.0
      200, 1536   |         17.9         |          117.3          |         18.1         |           88.1
      1000, 1536  |         22.9         |           92.0          |         19.4         |           89.0
      6000, 1536  |        123.4         |          436.3          |         61.7         |          228.5
      6272, 1536  |        129.1         |          457.3          |         64.3         |          238.5
      200, 2048   |         18.0         |           90.5          |         19.1         |          101.6
      1000, 2048  |         31.1         |          109.8          |         25.3         |          107.9
      6000, 2048  |        174.5         |          589.8          |         87.1         |          310.5
      6272, 2048  |        182.2         |          617.0          |         91.2         |          316.7
      200, 3072   |         19.8         |           96.4          |         19.4         |           89.3
      1000, 3072  |         48.1         |          168.7          |         23.5         |          100.9
      6000, 3072  |        267.1         |          930.0          |        134.8         |          519.2
      6272, 3072  |        278.2         |          971.2          |        140.7         |          540.2
```
Pre-#67977
```
[------------------------------------------------------- ln -------------------------------------------------------]
                    |  fwd, torch.float32  |  fwdbwd, torch.float32  |  fwd, torch.float16  |  fwdbwd, torch.float16
1 threads: ---------------------------------------------------------------------------------------------------------
        200,   256  |         20.9         |            92.6         |         21.3         |          110.1
       1000,   256  |         20.3         |            91.8         |         28.1         |          115.6
       6000,   256  |         93.0         |           310.7         |         86.3         |          299.8
       6272,   256  |         97.3         |           323.5         |         90.0         |          314.1
        200,   512  |         20.9         |           110.2         |         21.1         |           95.0
       1000,   512  |         24.0         |           102.8         |         22.2         |           95.9
       6000,   512  |        121.7         |           367.2         |        105.6         |          337.4
       6272,   512  |        127.0         |           382.3         |        111.3         |          352.0
        200,  1024  |         21.0         |           131.8         |         20.4         |           93.3
       1000,  1024  |         35.5         |           108.7         |         27.7         |           99.4
       6000,  1024  |        170.4         |           495.5         |        137.7         |          411.4
       6272,  1024  |        177.5         |           517.6         |        143.6         |          428.6
        200,  1536  |         21.9         |            97.6         |         20.8         |           92.7
       1000,  1536  |         44.3         |           129.7         |         33.9         |          100.1
       6000,  1536  |        215.8         |           619.2         |        167.2         |          480.9
       6272,  1536  |        225.0         |           646.9         |        174.8         |          505.9
        200,  2048  |         21.8         |           100.8         |         20.7         |           96.7
       1000,  2048  |         53.7         |           152.4         |         41.4         |          118.3
       6000,  2048  |        267.0         |           753.6         |        220.4         |          571.5
       6272,  2048  |        278.6         |           785.8         |        211.4         |          589.2
        200,  3072  |         20.9         |           103.7         |         21.9         |          104.6
       1000,  3072  |         71.4         |           201.1         |         53.1         |          148.3
       6000,  3072  |        365.7         |          1040.3         |        262.0         |          731.5
       6272,  3072  |        382.0         |          1084.4         |        273.3         |          766.3
```
Benchmarking script
```
import torch
from torch.utils.benchmark import Timer, Compare

results = []
for dtype in (torch.float, torch.half):
    for fs in (256, 512, 1024, 1536, 2048, 3072):
        for bs in (200, 1000, 6000, 196*32):
            ln = torch.nn.LayerNorm((fs,), device="cuda", dtype=dtype)
            X = torch.randn(bs, fs, device="cuda", dtype=dtype, requires_grad=True)
            gO = torch.rand_like(X)
            stmtfwd = "ln(X)"
            stmtfwdbwd = "X.grad=None; ln.zero_grad(set_to_none=True); out = ln(X); out.backward(gO)"
            tfwd = Timer(stmt=stmtfwd, label="ln", sub_label=f"{bs:5}, {fs:5}", description=f"fwd, {dtype}", globals=globals())
            tfwdbwd = Timer(stmt=stmtfwdbwd, label="ln", sub_label=f"{bs:5}, {fs:5}", description=f"fwdbwd, {dtype}", globals=globals())
            for t in (tfwd, tfwdbwd):
                results.append(t.blocked_autorange())
        print(fs, end='\r')
c = Compare(results)
c.print()
```

Pull Request resolved: #68238

Reviewed By: mruberry

Differential Revision: D32469450

Pulled By: ngimel

fbshipit-source-id: 08fe755c156d3d5c366c966cb808bf0f3e74c050
@ngimel ngimel deleted the ngimel/layer_norm branch December 26, 2021 06:43
@xuzhao9 xuzhao9 restored the ngimel/layer_norm branch January 11, 2022 00:28
@github-actions github-actions Bot deleted the ngimel/layer_norm branch February 13, 2024 02:05
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026