
[CUDA graphs] [JIT] Capture-safe RNG in nvfuser#50148

Closed
mcarilli wants to merge 1 commit into pytorch:master from mcarilli:graphable_jit_rng

Conversation

@mcarilli
Collaborator

@mcarilli mcarilli commented Jan 6, 2021

Update: @csarofeen says he prefers nvfuser changes to move through his fork before moving into mainline. Closing to submit nvfuser changes to his fork and the non-nvfuser (CUDAGeneratorImpl) changes as a separate PR.


Eager mode RNG kernels needed some minor changes to interact safely with cuda graphs. This PR extends those changes to the kernels generated by nvfuser.

One thing I'm unclear on is the best way to let NVRTC know the definition of `PhiloxCudaState` (defined in `ATen/CUDAGeneratorImpl.h`). I suggested two options in comments (1, 2), but I'm not sure which is better.

Another thing I'm unclear on is the best way to test these diffs.

Unrelated to graphs, this PR also fixes what I'm fairly sure is a subtle bug with custom `Philox` class usage in jitted kernels. `Philox` constructors in kernels take the CUDA RNG generator's current offset. The Philox constructor then carries out `offset/4` (a uint64_t division) to compute its internal offset in its virtual Philox bitstream of 128-bit chunks. In other words, it assumes the incoming offset is a multiple of 4. But (in current code) that's not guaranteed. For example, the increments used by these eager kernels could easily make the offset not divisible by 4. I figured the easiest fix** was to round all incoming increments up to the nearest multiple of 4 in CUDAGeneratorImpl itself.

** Another option would be to round the current offset up to the next multiple of 4 at the jit point of use. But that would be a jit-specific offset jump, so jit rng kernels wouldn't have a prayer of being bitwise accurate with eager rng kernels that used non-multiple-of-4 offsets. Restricting the offset to multiples of 4 for everyone at least gives jit rng the chance to match eager rng. (Of course, there are still many other ways the numerics could diverge, like if a jit kernel launches a different number of threads than an eager kernel, or assigns threads to data elements differently.)

@mcarilli mcarilli requested a review from csarofeen January 6, 2021 18:18
@facebook-github-bot facebook-github-bot added cla signed oncall: jit Add this issue/PR to JIT oncall triage queue labels Jan 6, 2021
@facebook-github-bot
Contributor

facebook-github-bot commented Jan 6, 2021

💊 CI failures summary and remediations

As of commit 05c6dbe (more details on the Dr. CI page):



🕵️ 4 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test1 (1/4)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Jan 06 19:17:53 AssertionError: mypy failed: test/test_dataset.py:93: error: Need type annotation for 'collate_ds_nolen' [var-annotated]
Jan 06 19:17:32   test_run_mypy (__main__.TestTypeHints) ... FAIL (63.492s)
Jan 06 19:17:34   test_run_mypy_strict (__main__.TestTypeHints) ... ok (2.798s)
Jan 06 19:17:53   test_type_hint_examples (__main__.TestTypeHints) ... ok (18.634s)
Jan 06 19:17:53 
Jan 06 19:17:53 ======================================================================
Jan 06 19:17:53 FAIL [63.492s]: test_run_mypy (__main__.TestTypeHints)
Jan 06 19:17:53 ----------------------------------------------------------------------
Jan 06 19:17:53 Traceback (most recent call last):
Jan 06 19:17:53   File "test_type_hints.py", line 214, in test_run_mypy
Jan 06 19:17:53     self.fail(f"mypy failed: {stdout} {stderr}")
Jan 06 19:17:53 AssertionError: mypy failed: test/test_dataset.py:93: error: Need type annotation for 'collate_ds_nolen'  [var-annotated]
Jan 06 19:17:53 test/test_dataset.py:147: error: Need type annotation for 'sampled_ds'  [var-annotated]
Jan 06 19:17:53 test/test_dataset.py:155: error: Need type annotation for 'random_sampled_ds'  [var-annotated]
Jan 06 19:17:53 Found 3 errors in 1 file (checked 1189 source files)
Jan 06 19:17:53  
Jan 06 19:17:53 
Jan 06 19:17:53 ----------------------------------------------------------------------
Jan 06 19:17:53 Ran 4 tests in 96.196s
Jan 06 19:17:53 
Jan 06 19:17:53 FAILED (failures=1)
Jan 06 19:17:53 

See CircleCI build pytorch_linux_bionic_py3_6_clang9_test (2/4)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Jan 06 19:00:16 AssertionError: mypy failed: test/test_dataset.py:93: error: Need type annotation for 'collate_ds_nolen' [var-annotated]
Jan 06 18:59:54   test_run_mypy (__main__.TestTypeHints) ... FAIL (59.361s)
Jan 06 18:59:57   test_run_mypy_strict (__main__.TestTypeHints) ... ok (2.815s)
Jan 06 19:00:16   test_type_hint_examples (__main__.TestTypeHints) ... ok (19.311s)
Jan 06 19:00:16 
Jan 06 19:00:16 ======================================================================
Jan 06 19:00:16 FAIL [59.361s]: test_run_mypy (__main__.TestTypeHints)
Jan 06 19:00:16 ----------------------------------------------------------------------
Jan 06 19:00:16 Traceback (most recent call last):
Jan 06 19:00:16   File "test_type_hints.py", line 214, in test_run_mypy
Jan 06 19:00:16     self.fail(f"mypy failed: {stdout} {stderr}")
Jan 06 19:00:16 AssertionError: mypy failed: test/test_dataset.py:93: error: Need type annotation for 'collate_ds_nolen'  [var-annotated]
Jan 06 19:00:16 test/test_dataset.py:147: error: Need type annotation for 'sampled_ds'  [var-annotated]
Jan 06 19:00:16 test/test_dataset.py:155: error: Need type annotation for 'random_sampled_ds'  [var-annotated]
Jan 06 19:00:16 Found 3 errors in 1 file (checked 1189 source files)
Jan 06 19:00:16  
Jan 06 19:00:16 
Jan 06 19:00:17 ----------------------------------------------------------------------
Jan 06 19:00:17 Ran 4 tests in 92.200s
Jan 06 19:00:17 
Jan 06 19:00:17 FAILED (failures=1)
Jan 06 19:00:17 

See CircleCI build pytorch_linux_xenial_py3_clang5_asan_test1 (3/4)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Jan 06 19:00:10 AssertionError: mypy failed: test/test_dataset.py:93: error: Need type annotation for 'collate_ds_nolen' [var-annotated]
Jan 06 18:59:44   test_run_mypy (__main__.TestTypeHints) ... FAIL (71.418s)
Jan 06 18:59:47   test_run_mypy_strict (__main__.TestTypeHints) ... ok (3.346s)
Jan 06 19:00:10   test_type_hint_examples (__main__.TestTypeHints) ... ok (23.155s)
Jan 06 19:00:10 
Jan 06 19:00:10 ======================================================================
Jan 06 19:00:10 FAIL [71.418s]: test_run_mypy (__main__.TestTypeHints)
Jan 06 19:00:10 ----------------------------------------------------------------------
Jan 06 19:00:10 Traceback (most recent call last):
Jan 06 19:00:10   File "test_type_hints.py", line 214, in test_run_mypy
Jan 06 19:00:10     self.fail(f"mypy failed: {stdout} {stderr}")
Jan 06 19:00:10 AssertionError: mypy failed: test/test_dataset.py:93: error: Need type annotation for 'collate_ds_nolen'  [var-annotated]
Jan 06 19:00:10 test/test_dataset.py:147: error: Need type annotation for 'sampled_ds'  [var-annotated]
Jan 06 19:00:10 test/test_dataset.py:155: error: Need type annotation for 'random_sampled_ds'  [var-annotated]
Jan 06 19:00:10 Found 3 errors in 1 file (checked 1189 source files)
Jan 06 19:00:10  
Jan 06 19:00:10 
Jan 06 19:00:10 ----------------------------------------------------------------------
Jan 06 19:00:10 Ran 4 tests in 111.342s
Jan 06 19:00:10 
Jan 06 19:00:10 FAILED (failures=1)
Jan 06 19:00:10 

See CircleCI build pytorch_linux_bionic_py3_8_gcc9_coverage_test1 (4/4)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Jan 06 19:06:10 AssertionError: mypy failed: test/test_dataset.py:93: error: Need type annotation for 'collate_ds_nolen' [var-annotated]
Jan 06 19:05:48   test_type_hint_examples (__main__.TestTypeHints)
Jan 06 19:06:10 Runs mypy over all the test examples present in ... ok (21.970s)
Jan 06 19:06:10 
Jan 06 19:06:10 ======================================================================
Jan 06 19:06:10 FAIL [79.903s]: test_run_mypy (__main__.TestTypeHints)
Jan 06 19:06:10 Runs mypy over all files specified in mypy.ini
Jan 06 19:06:10 ----------------------------------------------------------------------
Jan 06 19:06:10 Traceback (most recent call last):
Jan 06 19:06:10   File "test_type_hints.py", line 214, in test_run_mypy
Jan 06 19:06:10     self.fail(f"mypy failed: {stdout} {stderr}")
Jan 06 19:06:10 AssertionError: mypy failed: test/test_dataset.py:93: error: Need type annotation for 'collate_ds_nolen'  [var-annotated]
Jan 06 19:06:10 test/test_dataset.py:147: error: Need type annotation for 'sampled_ds'  [var-annotated]
Jan 06 19:06:10 test/test_dataset.py:155: error: Need type annotation for 'random_sampled_ds'  [var-annotated]
Jan 06 19:06:10 Found 3 errors in 1 file (checked 1189 source files)
Jan 06 19:06:10  
Jan 06 19:06:10 
Jan 06 19:06:10 ----------------------------------------------------------------------
Jan 06 19:06:10 Ran 4 tests in 119.347s
Jan 06 19:06:10 
Jan 06 19:06:10 FAILED (failures=1)
Jan 06 19:06:10 

❄️ 1 failure tentatively classified as flaky

but reruns have not yet been triggered to confirm:

See CircleCI build pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test2 (1/1)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun) ❄️

Jan 06 21:23:33 unknown file: Failure
Jan 06 21:23:27 [       OK ] NVFuserTest.FusionBCastConcretizeRfactor_CUDA (1 ms)
Jan 06 21:23:27 [ RUN      ] NVFuserTest.FusionProveIdEqBasic_CUDA
Jan 06 21:23:27 [       OK ] NVFuserTest.FusionProveIdEqBasic_CUDA (0 ms)
Jan 06 21:23:27 [ RUN      ] NVFuserTest.FusionProveIdEqRfactor_CUDA
Jan 06 21:23:27 [       OK ] NVFuserTest.FusionProveIdEqRfactor_CUDA (0 ms)
Jan 06 21:23:27 [ RUN      ] NVFuserTest.FusionScalarInputs_CUDA
Jan 06 21:23:27 [       OK ] NVFuserTest.FusionScalarInputs_CUDA (289 ms)
Jan 06 21:23:27 [ RUN      ] NVFuserTest.FusionLoopUnroll_CUDA
Jan 06 21:23:28 [       OK ] NVFuserTest.FusionLoopUnroll_CUDA (352 ms)
Jan 06 21:23:28 [ RUN      ] NVFuserTest.FusionUnaryOps_CUDA
Jan 06 21:23:33 unknown file: Failure
Jan 06 21:23:33 C++ exception with description "false INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/torch/csrc/jit/codegen/cuda/executor_utils.cpp":337, please report a bug to PyTorch. namespace CudaCodeGen {
Jan 06 21:23:33 
Jan 06 21:23:33 #define __HALF_TO_US(var) *(reinterpret_cast<unsigned short*>(&(var)))
Jan 06 21:23:33 #define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short*>(&(var)))
Jan 06 21:23:33 
Jan 06 21:23:33 struct __align__(2) __half {
Jan 06 21:23:33   __host__ __device__ __half() {}
Jan 06 21:23:33 
Jan 06 21:23:33  protected:
Jan 06 21:23:33   unsigned short __x;
---

### Extra GitHub checks: 1 failed

* **Failed:** [GitHub Actions - `clang-format`](https://github.com/pytorch/pytorch/runs/1658141489)

@mcarilli mcarilli requested a review from ngimel January 6, 2021 18:18
@mcarilli mcarilli closed this Jan 6, 2021
facebook-github-bot pushed a commit that referenced this pull request Jan 11, 2021
…kernels (#50169)

Summary:
Immediately-upstreamable part of #50148.

This PR fixes what I'm fairly sure is a subtle bug with custom `Philox` class usage in jitted kernels.  `Philox` [constructors in kernels](https://github.com/pytorch/pytorch/blob/68a6e4637903dba279c60daae5cff24e191ff9b4/torch/csrc/jit/codegen/cuda/codegen.cpp#L102) take the cuda rng generator's current offset.  The Philox constructor then carries out [`offset/4`](https://github.com/pytorch/pytorch/blob/74c055b24065d0202aecdf4bc837d3698d1639e1/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu#L13) (a uint64_t division) to compute its internal offset in its virtual Philox bitstream of 128-bit chunks.  In other words, it assumes the incoming offset is a multiple of 4.  But (in current code) that's not guaranteed.  For example, the increments used by [these eager kernels](https://github.com/pytorch/pytorch/blob/74c055b24065d0202aecdf4bc837d3698d1639e1/aten/src/ATen/native/cuda/Distributions.cu#L171-L216) could easily make offset not divisible by 4.

I figured the easiest fix was to round all incoming increments up to the nearest multiple of 4 in CUDAGeneratorImpl itself.

Another option would be to round the current offset up to the next multiple of 4 at the jit point of use.  But that would be a jit-specific offset jump, so jit rng kernels wouldn't have a prayer of being bitwise accurate with eager rng kernels that used non-multiple-of-4 offsets.  Restricting the offset to multiples of 4 for everyone at least gives jit rng the chance to match eager rng.  (Of course, there are still many other ways the numerics could diverge, like if a jit kernel launches a different number of threads than an eager kernel, or assigns threads to data elements differently.)

Pull Request resolved: #50169

Reviewed By: mruberry

Differential Revision: D25857934

Pulled By: ngimel

fbshipit-source-id: 43a75e2d0c8565651b0f12a5694c744fd86ece99
jianyuh added a commit to jianyuh/FBGEMM that referenced this pull request Jan 24, 2021
Summary:
Follow up on the failure case on FP16 stochastic rounding:
- pytorch/pytorch#50148
- D26006041 (pytorch@ceb16c9)

From Natalia:
- pytorch/pytorch#50916 is the fix, philox_engine_inputs is deprecated btw so if you could refactor it to use philox_cuda_state that would be great.
- instructions to change the call https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/CUDAGeneratorImpl.h#L48-L83, it will be important to use philox_cuda_state with graph capture.

Differential Revision: D26038596

fbshipit-source-id: f42bce9e5893eb9478b63cd4ca45c4da003dcf2d
jianyuh added a commit to jianyuh/FBGEMM that referenced this pull request Jan 25, 2021
Summary:
Pull Request resolved: pytorch/pytorch#51004

Pull Request resolved: pytorch#493

Follow up on the failure case on FP16 stochastic rounding:
- pytorch/pytorch#50148
- D26006041 (pytorch@ceb16c9)

From Natalia:
- pytorch/pytorch#50916 is the fix, philox_engine_inputs is deprecated btw so if you could refactor it to use philox_cuda_state that would be great.
- instructions to change the call https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/CUDAGeneratorImpl.h#L48-L83, it will be important to use philox_cuda_state with graph capture.

Benchmark:
- Before this Diff:
```
(base) [jianyuhuang@devgpu017.atn5.facebook.com: ~/fbsource/fbcode/hpc/ops/benchmarks] $  buck run mode/opt //hpc/ops/benchmarks:split_table_batched_embeddings_benchmark device -- --fp16 --stoc 2>&1 | tee before_diff.log
PARSING BUCK FILES: FINISHED IN 0.4s
CREATING ACTION GRAPH: FINISHED IN 0.0s
DOWNLOADED 0 ARTIFACTS, 0.00 BYTES, 0.0% CACHE MISS
BUILDING: FINISHED IN 5.3s (100%) 6474/6474 JOBS, 0 UPDATED
BUILD SUCCEEDED
DEBUG:root:Using fused exact_row_wise_adagrad with optimizer_args=OptimizerArgs(stochastic_rounding=True, gradient_clipping=False, max_gradient=1.0, learning_rate=0.1, eps=0.1, beta1=0.9, beta2=0.999, weight_decay=0.0, eta=0.001, momentum=0.9)
INFO:root:Embedding parameters:  0.41 GParam,  0.82GB
INFO:root:Accessed weights per batch:  83.89MB
INFO:root:Forward, B: 512, E: 100000, T: 32, D: 128, L: 20, W: False, BW:  607.48GB/s, T: 138us
INFO:root:ForwardBackward, B: 512, E: 100000, T: 32, D: 128, L: 20, BW:  220.85GB/s, T: 1139us
```

- After this Diff:
```
(base) [jianyuhuang@devgpu017.atn5.facebook.com: ~/fbsource/fbcode/hpc/ops/benchmarks] $  buck run mode/opt //hpc/ops/[5/1935]
ks:split_table_batched_embeddings_benchmark device -- --fp16 --stoc 2>&1 | tee after_diff.log
PARSING BUCK FILES: FINISHED IN 1.1s
CREATING ACTION GRAPH: FINISHED IN 0.0s
DEBUG:root:Using fused exact_row_wise_adagrad with optimizer_args=OptimizerArgs(stochastic_rounding=True, gradient_clipping=False, max_gradient=1.0, learning_rate=0.1, eps=0.1, beta1=0.9, beta2=0.999, weight_decay=0.0, eta=0.001, momentum=0.9)
INFO:root:Embedding parameters:  0.41 GParam,  0.82GB
INFO:root:Accessed weights per batch:  83.89MB
INFO:root:Forward, B: 512, E: 100000, T: 32, D: 128, L: 20, W: False, BW:  608.80GB/s, T: 138us
INFO:root:ForwardBackward, B: 512, E: 100000, T: 32, D: 128, L: 20, BW:  229.17GB/s, T: 1098us
```

Differential Revision: D26038596

fbshipit-source-id: 0154ba22688747e717d4c630e958938eff739b24
facebook-github-bot pushed a commit to pytorch/FBGEMM that referenced this pull request Feb 1, 2021
Summary:
Pull Request resolved: pytorch/pytorch#51004

Pull Request resolved: #493

Follow up on the failure case on FP16 stochastic rounding:
- pytorch/pytorch#50148
- D26006041 (ceb16c9)

From Natalia:
- pytorch/pytorch#50916 is the fix, philox_engine_inputs is deprecated btw so if you could refactor it to use philox_cuda_state that would be great.
- instructions to change the call https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/CUDAGeneratorImpl.h#L48-L83, it will be important to use philox_cuda_state with graph capture.

Benchmark:
- Before this Diff:
```
(base) [jianyuhuang@devgpu017.atn5.facebook.com: ~/fbsource/fbcode/hpc/ops/benchmarks] $  buck run mode/opt //hpc/ops/benchmarks:split_table_batched_embeddings_benchmark device -- --fp16 --stoc 2>&1 | tee before_diff.log
PARSING BUCK FILES: FINISHED IN 0.4s
CREATING ACTION GRAPH: FINISHED IN 0.0s
DOWNLOADED 0 ARTIFACTS, 0.00 BYTES, 0.0% CACHE MISS
BUILDING: FINISHED IN 5.3s (100%) 6474/6474 JOBS, 0 UPDATED
BUILD SUCCEEDED
DEBUG:root:Using fused exact_row_wise_adagrad with optimizer_args=OptimizerArgs(stochastic_rounding=True, gradient_clipping=False, max_gradient=1.0, learning_rate=0.1, eps=0.1, beta1=0.9, beta2=0.999, weight_decay=0.0, eta=0.001, momentum=0.9)
INFO:root:Embedding parameters:  0.41 GParam,  0.82GB
INFO:root:Accessed weights per batch:  83.89MB
INFO:root:Forward, B: 512, E: 100000, T: 32, D: 128, L: 20, W: False, BW:  607.48GB/s, T: 138us
INFO:root:ForwardBackward, B: 512, E: 100000, T: 32, D: 128, L: 20, BW:  220.85GB/s, T: 1139us
```

- After this Diff:
```
(base) [jianyuhuang@devgpu017.atn5.facebook.com: ~/fbsource/fbcode/hpc/ops/benchmarks] $  buck run mode/opt //hpc/ops/[5/1935]
ks:split_table_batched_embeddings_benchmark device -- --fp16 --stoc 2>&1 | tee after_diff.log
PARSING BUCK FILES: FINISHED IN 1.1s
CREATING ACTION GRAPH: FINISHED IN 0.0s
DEBUG:root:Using fused exact_row_wise_adagrad with optimizer_args=OptimizerArgs(stochastic_rounding=True, gradient_clipping=False, max_gradient=1.0, learning_rate=0.1, eps=0.1, beta1=0.9, beta2=0.999, weight_decay=0.0, eta=0.001, momentum=0.9)
INFO:root:Embedding parameters:  0.41 GParam,  0.82GB
INFO:root:Accessed weights per batch:  83.89MB
INFO:root:Forward, B: 512, E: 100000, T: 32, D: 128, L: 20, W: False, BW:  608.80GB/s, T: 138us
INFO:root:ForwardBackward, B: 512, E: 100000, T: 32, D: 128, L: 20, BW:  229.17GB/s, T: 1098us
```

Reviewed By: ngimel

Differential Revision: D26038596

fbshipit-source-id: 5360395c1c3b1a062b38e5695239258e892c63c4
facebook-github-bot pushed a commit that referenced this pull request Feb 1, 2021
…ding (#51004)

Summary:
Pull Request resolved: #51004

Pull Request resolved: pytorch/FBGEMM#493

Follow up on the failure case on FP16 stochastic rounding:
- #50148
- D26006041

From Natalia:
- #50916 is the fix, philox_engine_inputs is deprecated btw so if you could refactor it to use philox_cuda_state that would be great.
- instructions to change the call https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/CUDAGeneratorImpl.h#L48-L83, it will be important to use philox_cuda_state with graph capture.

Benchmark:
- Before this Diff:
```
(base) [jianyuhuang@devgpu017.atn5.facebook.com: ~/fbsource/fbcode/hpc/ops/benchmarks] $  buck run mode/opt //hpc/ops/benchmarks:split_table_batched_embeddings_benchmark device -- --fp16 --stoc 2>&1 | tee before_diff.log
PARSING BUCK FILES: FINISHED IN 0.4s
CREATING ACTION GRAPH: FINISHED IN 0.0s
DOWNLOADED 0 ARTIFACTS, 0.00 BYTES, 0.0% CACHE MISS
BUILDING: FINISHED IN 5.3s (100%) 6474/6474 JOBS, 0 UPDATED
BUILD SUCCEEDED
DEBUG:root:Using fused exact_row_wise_adagrad with optimizer_args=OptimizerArgs(stochastic_rounding=True, gradient_clipping=False, max_gradient=1.0, learning_rate=0.1, eps=0.1, beta1=0.9, beta2=0.999, weight_decay=0.0, eta=0.001, momentum=0.9)
INFO:root:Embedding parameters:  0.41 GParam,  0.82GB
INFO:root:Accessed weights per batch:  83.89MB
INFO:root:Forward, B: 512, E: 100000, T: 32, D: 128, L: 20, W: False, BW:  607.48GB/s, T: 138us
INFO:root:ForwardBackward, B: 512, E: 100000, T: 32, D: 128, L: 20, BW:  220.85GB/s, T: 1139us
```

- After this Diff:
```
(base) [jianyuhuang@devgpu017.atn5.facebook.com: ~/fbsource/fbcode/hpc/ops/benchmarks] $  buck run mode/opt //hpc/ops/[5/1935]
ks:split_table_batched_embeddings_benchmark device -- --fp16 --stoc 2>&1 | tee after_diff.log
PARSING BUCK FILES: FINISHED IN 1.1s
CREATING ACTION GRAPH: FINISHED IN 0.0s
DEBUG:root:Using fused exact_row_wise_adagrad with optimizer_args=OptimizerArgs(stochastic_rounding=True, gradient_clipping=False, max_gradient=1.0, learning_rate=0.1, eps=0.1, beta1=0.9, beta2=0.999, weight_decay=0.0, eta=0.001, momentum=0.9)
INFO:root:Embedding parameters:  0.41 GParam,  0.82GB
INFO:root:Accessed weights per batch:  83.89MB
INFO:root:Forward, B: 512, E: 100000, T: 32, D: 128, L: 20, W: False, BW:  608.80GB/s, T: 138us
INFO:root:ForwardBackward, B: 512, E: 100000, T: 32, D: 128, L: 20, BW:  229.17GB/s, T: 1098us
```

Test Plan: CI

Reviewed By: ngimel

Differential Revision: D26038596

fbshipit-source-id: 5360395c1c3b1a062b38e5695239258e892c63c4
@mcarilli mcarilli added the module: cuda graphs Ability to capture and then replay streams of CUDA kernels label Apr 28, 2021

Labels

cla signed module: cuda graphs Ability to capture and then replay streams of CUDA kernels oncall: jit Add this issue/PR to JIT oncall triage queue open source

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants