[pytorch] Accelerate indexing_backward_kernel with duplicates#99441
[pytorch] Accelerate indexing_backward_kernel with duplicates#99441valentinandrei wants to merge 21 commits intopytorch:mainfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99441
Note: Links to docs will display an error until the docs builds have been completed. ❗ 2 Active SEVsThere are 2 currently active SEVs. If your PR is affected, please view them below: ✅ No FailuresAs of commit a007248: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…es ... attempt 10
…mparing to reference ... attempt 11
…eterministic ... attempt 12
|
@pytorchbot label "topic: performance" |
|
❌ 🤖 pytorchbot command failed: Try |
|
/easycla |
1 similar comment
|
/easycla |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
There was a problem hiding this comment.
@ngimel I'd suggest reverting this PR, since it's consistently breaking ROCm CI as seen here: https://hud.pytorch.org/ci/pytorch/pytorch/main?name_filter=rocm
I have added some comments on this PR where I think the code is very CUDA-specific.
| } | ||
| WARP_SYNC(); | ||
| for (int offset = 16; offset > 0; offset /= 2) { | ||
| gradient += WARP_SHFL_DOWN(gradient, offset); |
There was a problem hiding this comment.
@ngimel Would this work correctly for a warp_size of 64 (AMD GPUs)?
There was a problem hiding this comment.
Yeah we shouldn't hardcode offset here, you are right. Again, we'd need compile-time-known warp width.
|
Sorry about that @jithunnair-amd the tests were clean. |
|
@pytorchbot revert -c nosignal -m "breaks ROCM" |
|
@pytorchbot successfully started a revert job. Check the current status here. |
|
@valentinandrei your PR has been successfully reverted. |
…#99441)" This reverts commit 97afbcb. Reverted #99441 on behalf of https://github.com/ngimel due to breaks ROCM ([comment](#99441 (comment)))
…pytorch#99441)" This reverts commit 97afbcb. Reverted pytorch#99441 on behalf of https://github.com/ngimel due to breaks ROCM ([comment](pytorch#99441 (comment)))
…t broke ROCm CI
…code that broke ROCm CI" This reverts commit 1eb0fb7.
…t broke ROCm CI ... attempt 2
…attempt 2) (#100505) By knowing the stride value ahead of time, we can simplify the kernel code as follows: If stride == 1 we can use the whole warp to reduce the gradients If stride < warp_size we don't need the internal while (start_feature < stride) loop as blockDim.x is always 32 This changes improve the performance of the kernel when duplicates are present and do not affect the performance with low amount of duplicates. The implementation is deterministic. The proposed implementation uses opmath_t to accumulate in registers the gradient values so when using FP16/BF16 it may overflow if the number of elements is large. This is different from the initial implementation who accumulates in scalar_t and does not overflow. In addition, when the stride is 1, we are using warp shuffles to sum the gradient so the order of the addition is slightly different than a reference implementation which causes some minor numerical differences when compared to a reference. TEST CODE: ``` # The first element is the number of iterations. # The second represents the number of unique elements. If # set to 0, the number of unique elements is equal to the # number of elements. # The remaining elements are the tensor dimensions. basic_indexing_tests = [ [10, 0, 12345], [10, 4, 12345], [10, 16, 512, 512, 32], [10, 0, 4, 4], [10, 0, 32, 32], [10, 8, 32, 32], [10, 8, 64, 32, 16], [10, 0, 64, 32, 16], [10, 16, 512, 512, 32], [10, 0, 675, 999, 13], [10, 0, 123, 456, 31], [10, 0, 512, 512, 32], [10, 4, 512, 512, 32], [10, 2, 512, 512, 32], [10, 0, 128, 128, 16, 16], [10, 8, 128, 126, 16, 16], [10, 4, 128, 126, 16, 16], [10, 0, 64, 64, 16, 16, 16], [10, 8, 64, 64, 16, 16, 16], [10, 2, 64, 64, 16, 16, 16], [10, 1, 64, 64, 16, 16, 16], ] def run_basic_indexing_on_device(x, index, expected, device_string, iters): x_dev = x.to(device_string) x_dev = x_dev.detach().requires_grad_() index_dev = index.to(device_string) # Run backward pass; keep gradients and measure time torch.cuda.synchronize() t_bw_s = time() for _ in range(iters): y = x_dev[index_dev] z = y.sum() z.backward() torch.cuda.synchronize() t_bw_s = (time() - t_bw_s) / iters return (x_dev.grad, t_bw_s) def run_basic_indexing_test(test_input): tensor_size = tuple(test_input[:5]) niters = test_input[0] num_unique = test_input[1] tensor_size = tuple(test_input[2:]) numel = 1 for dim in tensor_size: numel *= dim if num_unique == 0: num_unique = numel index = torch.randint(0, num_unique, tensor_size, dtype=torch.long, device="cpu") x = torch.randn((numel,), dtype=torch.float32, device="cuda") index = index.detach() x = x.detach().requires_grad_() (cpu_grad, t_bw_cpu) = run_basic_indexing_on_device(x, index, numel / 2, "cpu", 1) (gpu_grad, t_bw_gpu) = run_basic_indexing_on_device(x, index, numel / 2, "cuda", 1) max_delta = torch.max(torch.abs(cpu_grad - gpu_grad.to("cpu"))) missmatches = torch.nonzero(torch.abs(cpu_grad - gpu_grad.to("cpu"))) (gpu_grad_perf, t_gpu) = run_basic_indexing_on_device( x, index, numel / 2, "cuda", niters ) print( "test = {}, delta = {:.5f}, missmatches = {} duration_ms = {:.3f}".format( tuple(test_input), max_delta, missmatches, t_gpu * 1000.0 ) ) if torch.numel(missmatches) > 0: print("cpu grad = {}", cpu_grad[missmatches]) print("gpu grad = {}", gpu_grad[missmatches]) ``` RESULTS: ``` Default Implementation test = (1, 0, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.726 test = (1, 4, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.867 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 80.514 test = (1, 0, 4, 4), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.689 test = (1, 0, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.547 test = (1, 8, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.537 test = (1, 8, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1.199 test = (1, 0, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.584 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 80.055 test = (1, 0, 675, 999, 13), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8.411 test = (1, 0, 123, 456, 31), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2.419 test = (1, 0, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8.048 test = (1, 4, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 307.633 test = (1, 2, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 606.403 test = (1, 0, 128, 128, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 4.099 test = (1, 8, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 76.813 test = (1, 4, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 148.760 test = (1, 0, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 16.547 test = (1, 8, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 317.583 test = (1, 2, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1204.800 test = (1, 1, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2412.133 Small Stride Kernel Version test = (1, 0, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.904 test = (1, 4, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2.156 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 308.878 test = (1, 0, 4, 4), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.566 test = (1, 0, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.540 test = (1, 8, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.550 test = (1, 8, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2.868 test = (1, 0, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.656 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 307.856 test = (1, 0, 675, 999, 13), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 6.624 test = (1, 0, 123, 456, 31), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1.837 test = (1, 0, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 6.274 test = (1, 4, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1127.040 test = (1, 2, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2123.942 test = (1, 0, 128, 128, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 3.282 test = (1, 8, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 288.997 test = (1, 4, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 547.267 test = (1, 0, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 12.844 test = (1, 8, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1178.934 test = (1, 2, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 4262.042 test = (1, 1, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8172.318 Stride 1 Kernel Version test = (1, 0, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.692 test = (1, 4, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.834 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 81.023 test = (1, 0, 4, 4), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.631 test = (100, 0, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.491 test = (100, 8, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.477 test = (50, 8, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.561 test = (50, 0, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.516 test = (16, 10, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 126.455 test = (10, 0, 675, 999, 13), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8.238 test = (10, 0, 123, 456, 31), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1.520 test = (10, 0, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 7.854 test = (10, 4, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 306.327 test = (10, 2, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 610.498 test = (5, 0, 128, 128, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 3.684 test = (5, 8, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 75.604 test = (5, 4, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 148.679 test = (1, 0, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 16.525 test = (1, 8, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 315.095 test = (1, 2, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1214.715 ``` Pull Request resolved: #100505 Approved by: https://github.com/ngimel
…h#99441) By knowing the stride value ahead of time, we can simplify the kernel code as follows: If `stride == 1` we can use the whole warp to reduce the gradients If `stride < warp_size` we don't need the internal `while (start_feature < stride)` loop as `blockDim.x` is always 32 This changes improve the performance of the kernel when duplicates are present and do not affect the performance with low amount of duplicates. The implementation is deterministic. The proposed implementation uses `opmath_t` to accumulate in registers the gradient values so when using FP16/BF16 it may overflow if the number of elements is large. This is different from the initial implementation who accumulates in `scalar_t` and does not overflow. In addition, when the stride is 1, we are using warp shuffles to sum the gradient so the order of the addition is slightly different than a reference implementation which causes some minor numerical differences when compared to a reference. TEST CODE: ``` # The first element is the number of iterations. # The second represents the number of unique elements. If # set to 0, the number of unique elements is equal to the # number of elements. # The remaining elements are the tensor dimensions. basic_indexing_tests = [ [10, 0, 12345], [10, 4, 12345], [10, 16, 512, 512, 32], [10, 0, 4, 4], [10, 0, 32, 32], [10, 8, 32, 32], [10, 8, 64, 32, 16], [10, 0, 64, 32, 16], [10, 16, 512, 512, 32], [10, 0, 675, 999, 13], [10, 0, 123, 456, 31], [10, 0, 512, 512, 32], [10, 4, 512, 512, 32], [10, 2, 512, 512, 32], [10, 0, 128, 128, 16, 16], [10, 8, 128, 126, 16, 16], [10, 4, 128, 126, 16, 16], [10, 0, 64, 64, 16, 16, 16], [10, 8, 64, 64, 16, 16, 16], [10, 2, 64, 64, 16, 16, 16], [10, 1, 64, 64, 16, 16, 16], ] def run_basic_indexing_on_device(x, index, expected, device_string, iters): x_dev = x.to(device_string) x_dev = x_dev.detach().requires_grad_() index_dev = index.to(device_string) # Run backward pass; keep gradients and measure time torch.cuda.synchronize() t_bw_s = time() for _ in range(iters): y = x_dev[index_dev] z = y.sum() z.backward() torch.cuda.synchronize() t_bw_s = (time() - t_bw_s) / iters return (x_dev.grad, t_bw_s) def run_basic_indexing_test(test_input): tensor_size = tuple(test_input[:5]) niters = test_input[0] num_unique = test_input[1] tensor_size = tuple(test_input[2:]) numel = 1 for dim in tensor_size: numel *= dim if num_unique == 0: num_unique = numel index = torch.randint(0, num_unique, tensor_size, dtype=torch.long, device="cpu") x = torch.randn((numel,), dtype=torch.float32, device="cuda") index = index.detach() x = x.detach().requires_grad_() (cpu_grad, t_bw_cpu) = run_basic_indexing_on_device(x, index, numel / 2, "cpu", 1) (gpu_grad, t_bw_gpu) = run_basic_indexing_on_device(x, index, numel / 2, "cuda", 1) max_delta = torch.max(torch.abs(cpu_grad - gpu_grad.to("cpu"))) missmatches = torch.nonzero(torch.abs(cpu_grad - gpu_grad.to("cpu"))) (gpu_grad_perf, t_gpu) = run_basic_indexing_on_device( x, index, numel / 2, "cuda", niters ) print( "test = {}, delta = {:.5f}, missmatches = {} duration_ms = {:.3f}".format( tuple(test_input), max_delta, missmatches, t_gpu * 1000.0 ) ) if torch.numel(missmatches) > 0: print("cpu grad = {}", cpu_grad[missmatches]) print("gpu grad = {}", gpu_grad[missmatches]) ``` RESULTS: ``` Default Implementation test = (1, 0, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.726 test = (1, 4, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.867 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 80.514 test = (1, 0, 4, 4), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.689 test = (1, 0, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.547 test = (1, 8, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.537 test = (1, 8, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1.199 test = (1, 0, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.584 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 80.055 test = (1, 0, 675, 999, 13), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8.411 test = (1, 0, 123, 456, 31), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2.419 test = (1, 0, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8.048 test = (1, 4, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 307.633 test = (1, 2, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 606.403 test = (1, 0, 128, 128, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 4.099 test = (1, 8, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 76.813 test = (1, 4, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 148.760 test = (1, 0, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 16.547 test = (1, 8, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 317.583 test = (1, 2, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1204.800 test = (1, 1, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2412.133 Small Stride Kernel Version test = (1, 0, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.904 test = (1, 4, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2.156 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 308.878 test = (1, 0, 4, 4), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.566 test = (1, 0, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.540 test = (1, 8, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.550 test = (1, 8, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2.868 test = (1, 0, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.656 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 307.856 test = (1, 0, 675, 999, 13), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 6.624 test = (1, 0, 123, 456, 31), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1.837 test = (1, 0, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 6.274 test = (1, 4, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1127.040 test = (1, 2, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2123.942 test = (1, 0, 128, 128, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 3.282 test = (1, 8, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 288.997 test = (1, 4, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 547.267 test = (1, 0, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 12.844 test = (1, 8, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1178.934 test = (1, 2, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 4262.042 test = (1, 1, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8172.318 Stride 1 Kernel Version test = (1, 0, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.692 test = (1, 4, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.834 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 81.023 test = (1, 0, 4, 4), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.631 test = (100, 0, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.491 test = (100, 8, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.477 test = (50, 8, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.561 test = (50, 0, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.516 test = (16, 10, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 126.455 test = (10, 0, 675, 999, 13), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8.238 test = (10, 0, 123, 456, 31), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1.520 test = (10, 0, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 7.854 test = (10, 4, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 306.327 test = (10, 2, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 610.498 test = (5, 0, 128, 128, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 3.684 test = (5, 8, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 75.604 test = (5, 4, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 148.679 test = (1, 0, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 16.525 test = (1, 8, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 315.095 test = (1, 2, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1214.715 ``` Pull Request resolved: pytorch#99441 Approved by: https://github.com/ngimel
…pytorch#99441)" This reverts commit 0d2f3ae. Reverted pytorch#99441 on behalf of https://github.com/ngimel due to breaks ROCM ([comment](pytorch#99441 (comment)))
…ch#99441 attempt 2) (pytorch#100505) By knowing the stride value ahead of time, we can simplify the kernel code as follows: If stride == 1 we can use the whole warp to reduce the gradients If stride < warp_size we don't need the internal while (start_feature < stride) loop as blockDim.x is always 32 This changes improve the performance of the kernel when duplicates are present and do not affect the performance with low amount of duplicates. The implementation is deterministic. The proposed implementation uses opmath_t to accumulate in registers the gradient values so when using FP16/BF16 it may overflow if the number of elements is large. This is different from the initial implementation who accumulates in scalar_t and does not overflow. In addition, when the stride is 1, we are using warp shuffles to sum the gradient so the order of the addition is slightly different than a reference implementation which causes some minor numerical differences when compared to a reference. TEST CODE: ``` # The first element is the number of iterations. # The second represents the number of unique elements. If # set to 0, the number of unique elements is equal to the # number of elements. # The remaining elements are the tensor dimensions. basic_indexing_tests = [ [10, 0, 12345], [10, 4, 12345], [10, 16, 512, 512, 32], [10, 0, 4, 4], [10, 0, 32, 32], [10, 8, 32, 32], [10, 8, 64, 32, 16], [10, 0, 64, 32, 16], [10, 16, 512, 512, 32], [10, 0, 675, 999, 13], [10, 0, 123, 456, 31], [10, 0, 512, 512, 32], [10, 4, 512, 512, 32], [10, 2, 512, 512, 32], [10, 0, 128, 128, 16, 16], [10, 8, 128, 126, 16, 16], [10, 4, 128, 126, 16, 16], [10, 0, 64, 64, 16, 16, 16], [10, 8, 64, 64, 16, 16, 16], [10, 2, 64, 64, 16, 16, 16], [10, 1, 64, 64, 16, 16, 16], ] def run_basic_indexing_on_device(x, index, expected, device_string, iters): x_dev = x.to(device_string) x_dev = x_dev.detach().requires_grad_() index_dev = index.to(device_string) # Run backward pass; keep gradients and measure time torch.cuda.synchronize() t_bw_s = time() for _ in range(iters): y = x_dev[index_dev] z = y.sum() z.backward() torch.cuda.synchronize() t_bw_s = (time() - t_bw_s) / iters return (x_dev.grad, t_bw_s) def run_basic_indexing_test(test_input): tensor_size = tuple(test_input[:5]) niters = test_input[0] num_unique = test_input[1] tensor_size = tuple(test_input[2:]) numel = 1 for dim in tensor_size: numel *= dim if num_unique == 0: num_unique = numel index = torch.randint(0, num_unique, tensor_size, dtype=torch.long, device="cpu") x = torch.randn((numel,), dtype=torch.float32, device="cuda") index = index.detach() x = x.detach().requires_grad_() (cpu_grad, t_bw_cpu) = run_basic_indexing_on_device(x, index, numel / 2, "cpu", 1) (gpu_grad, t_bw_gpu) = run_basic_indexing_on_device(x, index, numel / 2, "cuda", 1) max_delta = torch.max(torch.abs(cpu_grad - gpu_grad.to("cpu"))) missmatches = torch.nonzero(torch.abs(cpu_grad - gpu_grad.to("cpu"))) (gpu_grad_perf, t_gpu) = run_basic_indexing_on_device( x, index, numel / 2, "cuda", niters ) print( "test = {}, delta = {:.5f}, missmatches = {} duration_ms = {:.3f}".format( tuple(test_input), max_delta, missmatches, t_gpu * 1000.0 ) ) if torch.numel(missmatches) > 0: print("cpu grad = {}", cpu_grad[missmatches]) print("gpu grad = {}", gpu_grad[missmatches]) ``` RESULTS: ``` Default Implementation test = (1, 0, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.726 test = (1, 4, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.867 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 80.514 test = (1, 0, 4, 4), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.689 test = (1, 0, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.547 test = (1, 8, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.537 test = (1, 8, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1.199 test = (1, 0, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.584 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 80.055 test = (1, 0, 675, 999, 13), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8.411 test = (1, 0, 123, 456, 31), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2.419 test = (1, 0, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8.048 test = (1, 4, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 307.633 test = (1, 2, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 606.403 test = (1, 0, 128, 128, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 4.099 test = (1, 8, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 76.813 test = (1, 4, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 148.760 test = (1, 0, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 16.547 test = (1, 8, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 317.583 test = (1, 2, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1204.800 test = (1, 1, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2412.133 Small Stride Kernel Version test = (1, 0, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.904 test = (1, 4, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2.156 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 308.878 test = (1, 0, 4, 4), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.566 test = (1, 0, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.540 test = (1, 8, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.550 test = (1, 8, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2.868 test = (1, 0, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.656 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 307.856 test = (1, 0, 675, 999, 13), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 6.624 test = (1, 0, 123, 456, 31), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1.837 test = (1, 0, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 6.274 test = (1, 4, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1127.040 test = (1, 2, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2123.942 test = (1, 0, 128, 128, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 3.282 test = (1, 8, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 288.997 test = (1, 4, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 547.267 test = (1, 0, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 12.844 test = (1, 8, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1178.934 test = (1, 2, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 4262.042 test = (1, 1, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8172.318 Stride 1 Kernel Version test = (1, 0, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.692 test = (1, 4, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.834 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 81.023 test = (1, 0, 4, 4), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.631 test = (100, 0, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.491 test = (100, 8, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.477 test = (50, 8, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.561 test = (50, 0, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.516 test = (16, 10, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 126.455 test = (10, 0, 675, 999, 13), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8.238 test = (10, 0, 123, 456, 31), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1.520 test = (10, 0, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 7.854 test = (10, 4, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 306.327 test = (10, 2, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 610.498 test = (5, 0, 128, 128, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 3.684 test = (5, 8, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 75.604 test = (5, 4, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 148.679 test = (1, 0, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 16.525 test = (1, 8, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 315.095 test = (1, 2, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1214.715 ``` Pull Request resolved: pytorch#100505 Approved by: https://github.com/ngimel
By knowing the stride value ahead of time, we can simplify the kernel code as follows:
If
stride == 1we can use the whole warp to reduce the gradientsIf
stride < warp_sizewe don't need the internalwhile (start_feature < stride)loop asblockDim.xis always 32This changes improve the performance of the kernel when duplicates are present and do not affect the performance with low amount of duplicates. The implementation is deterministic.
The proposed implementation uses
opmath_tto accumulate in registers the gradient values so when using FP16/BF16 it may overflow if the number of elements is large. This is different from the initial implementation who accumulates inscalar_tand does not overflow. In addition, when the stride is 1, we are using warp shuffles to sum the gradient so the order of the addition is slightly different than a reference implementation which causes some minor numerical differences when compared to a reference.TEST CODE:
RESULTS: