Lower as_strided_copy use fast path with slice #8734

Merged
tengyifei merged 13 commits into master from piz/as_stride on Feb 27, 2025

@zpcore (Member) commented Feb 21, 2025

The following two code snippets for the flash attention kernel in custom_kernel.py are supposed to produce the same result.
a.

l, m = (v[..., 0] for v in aux[-2:])

b.

l, m = aux[-2:]
l = torch.ops.aten.slice(l, -1, 0, 1)
m = torch.ops.aten.slice(m, -1, 0, 1)

Both are lowered through

at::Tensor XLANativeFunctions::as_strided_copy(

The difference is that the stride and size arguments have one fewer element in snippet a than in snippet b. Because of this difference, snippet a falls back to aten::take, which can trigger the following error when called with SPMD:

F0223 07:18:45.157172  842998 hlo_sharding.cc:1024] Check failed: !IsManual() 
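To make the argument difference concrete, here is a minimal pure-Python sketch of the size and stride that each snippet would describe for a contiguous tensor. The shape `(2, 3, 4)` and the helper names `contiguous_strides`, `select_last`, and `slice_last` are assumptions for illustration only; they are not part of torch_xla.

```python
def contiguous_strides(shape):
    """Row-major strides (in elements) for a contiguous tensor of `shape`."""
    strides = []
    running = 1
    for extent in reversed(shape):
        strides.append(running)
        running *= extent
    return tuple(reversed(strides))


def select_last(shape):
    """Size/stride of v[..., 0]: the last dimension is dropped (snippet a)."""
    strides = contiguous_strides(shape)
    return tuple(shape[:-1]), strides[:-1]


def slice_last(shape):
    """Size/stride of slice(v, -1, 0, 1): the last dimension is kept with extent 1 (snippet b)."""
    strides = contiguous_strides(shape)
    return tuple(shape[:-1]) + (1,), strides


shape = (2, 3, 4)  # hypothetical shape, chosen only for illustration
size_a, stride_a = select_last(shape)
size_b, stride_b = slice_last(shape)
print("a:", size_a, stride_a)
print("b:", size_b, stride_b)
```

Under these assumptions, snippet a yields a size and stride that are each one element shorter than snippet b's, while both address the same elements of the underlying storage.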

I plan to check in test_as_stride_use_slice.py in this PR.
Notes:
1. The failing test test_scan_layer_aot is not enabled until #8742 is resolved.
2. The failing test test_scan_weight_layer_aot is not enabled until #8753 is resolved.

@zpcore changed the title from "slice lower" to "Lower as_strided_copy use fast path with slice" Feb 23, 2025
@zpcore marked this pull request as ready for review February 23, 2025 08:17
@zpcore requested a review from tengyifei February 23, 2025 08:19
@zpcore requested a review from bhavya01 February 24, 2025 17:37
Comment thread torch_xla/csrc/aten_xla_type.cpp Outdated
if (stride_mul != stride[j]) {
if (skip_dim == -1) {
skip_dim = i;
K = stride[j] / stride_mul;
Collaborator

Should we check that stride[j] can be evenly divided by stride_mul and exit if the remainder is not 0?

@zpcore (Member, Author) Feb 26, 2025

stride[j] doesn't need to be evenly divisible by stride_mul, as long as every entry of stride before index j matches the cumulative product of the tensor dims.
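The check under discussion can be sketched in plain Python as follows. This is a minimal illustration, not the code merged in this PR: the function name `find_stride_jump` is hypothetical, and the divisibility test is this sketch's own conservative choice (the point the reviewer raised), not a statement of what the C++ lowering does.

```python
from typing import Optional, Sequence, Tuple


def find_stride_jump(size: Sequence[int],
                     stride: Sequence[int]) -> Optional[Tuple[int, int]]:
    """Walk dims from innermost to outermost, tracking the stride a
    contiguous layout would have (stride_mul). Return (dim, K) for the
    single dim whose stride is K times larger than expected, or None if
    there is no such dim, more than one, or K is not an integer."""
    stride_mul = 1
    found = None
    for j in range(len(size) - 1, -1, -1):
        if stride[j] != stride_mul:
            # Conservative choice of this sketch: give up on a second
            # mismatch or on a stride that is not a multiple of stride_mul.
            if found is not None or stride[j] % stride_mul != 0:
                return None
            found = (j, stride[j] // stride_mul)
            stride_mul = stride[j]
        stride_mul *= size[j]
    return found


# v[..., 0] on a hypothetical contiguous (2, 3, 4) tensor.
print(find_stride_jump((2, 3), (12, 4)))
# slice(v, -1, 0, 1) on the same tensor.
print(find_stride_jump((2, 3, 1), (12, 4, 1)))
```

In this sketch both forms of the flash attention snippet report the same jump, which is the information a slice-based fast path would need; a fully contiguous layout or a layout with two jumps returns None and would take the general path.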

3 participants