Commit 17ae642

Update on "Add missing types to inductor IR assert"
It is unclear whether there is a more efficient way to define the allowed types for IR (or whether we even need this; perhaps we should just ditch the assert?), but Inductor experts can determine whether these added ops are appropriate and, if so, whether they fix the reported issue.

Fixes #96204

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
2 parents fc4ac5a + f0039d6 commit 17ae642
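For context, the check the commit title refers to is an allowlist assert over Inductor IR node/op types. The snippet below is only a hypothetical illustration of that pattern, not Inductor's actual code; the class names in the allowlist are made up.

# Hypothetical illustration only -- not Inductor's actual assert.
# It shows the general pattern: validate that every op/node reaching a
# lowering step is an instance of an explicit allowlist of IR types.

class Pointwise: ...      # stand-ins for IR node classes (made up here)
class Reduction: ...
class ExternKernel: ...

ALLOWED_IR_TYPES = (Pointwise, Reduction, ExternKernel)

def check_ir_node(node):
    # The idea in the commit message: extend the allowlist with the missing
    # types instead of failing on valid-but-unlisted ops (or drop the assert).
    assert isinstance(node, ALLOWED_IR_TYPES), (
        f"unexpected IR node type: {type(node).__name__}"
    )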

24 files changed

Lines changed: 237 additions & 637 deletions

.github/scripts/generate_ci_workflows.py

Lines changed: 0 additions & 12 deletions
@@ -271,18 +271,6 @@ class OperatingSystem:
             isolated_workflow=True,
         ),
     ),
-    BinaryBuildWorkflow(
-        os=OperatingSystem.MACOS,
-        package_type="libtorch",
-        abi_version=generate_binary_build_matrix.PRE_CXX11_ABI,
-        build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
-            OperatingSystem.MACOS, generate_binary_build_matrix.PRE_CXX11_ABI
-        ),
-        ciflow_config=CIFlowConfig(
-            labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_LIBTORCH},
-            isolated_workflow=True,
-        ),
-    ),
     BinaryBuildWorkflow(
         os=OperatingSystem.MACOS_ARM64,
         package_type="wheel",

.github/workflows/generated-macos-binary-libtorch-pre-cxx11-nightly.yml

Lines changed: 0 additions & 502 deletions
This file was deleted.

.lintrunner.toml

Lines changed: 8 additions & 0 deletions
@@ -920,3 +920,11 @@ init_command = [
     '--output-name=bazel',
 ]
 is_formatter = true
+
+[[linter]]
+code = 'LINTRUNNER_VERSION'
+include_patterns = ['**']
+command = [
+    'python3',
+    'tools/linter/adapters/lintrunner_version_linter.py'
+]

aten/src/ATen/native/native_functions.yaml

Lines changed: 11 additions & 11 deletions
@@ -13896,7 +13896,7 @@
     CUDA, NestedTensorCUDA: native_multi_head_attention_cuda
   autogen: _native_multi_head_attention.out

-- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False) -> Tensor
+- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
   python_module: nn
   variants: function
   autogen: scaled_dot_product_attention.out
@@ -13908,55 +13908,55 @@
   autogen: _scaled_dot_product_attention.out

 # This aten function is kept so that we can test the choice function from Python
-- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False) -> int
+- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> int
   dispatch:
     Meta: _fused_sdp_choice_meta
     CPU, NestedTensorCPU: _fused_sdp_choice_cpp
     CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda

-- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None) -> (Tensor, Tensor)
+- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor)
   variants: function

-- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False) -> (Tensor ouput, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, int philox_seed, int philox_offset, Tensor debug_attn_mask)
+- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor ouput, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, int philox_seed, int philox_offset, Tensor debug_attn_mask)
   dispatch:
     CUDA: _scaled_dot_product_flash_attention_cuda
     NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda

-- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, int philox_seed, int philox_offset) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
+- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, int philox_seed, int philox_offse, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
   variants: function
   dispatch:
     CUDA: _scaled_dot_product_flash_attention_backward_cuda

-- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False) -> (Tensor, Tensor)
+- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor)
   dispatch:
     CUDA: _scaled_dot_product_efficient_attention_cuda
     NestedTensorCUDA: _scaled_dot_product_efficient_attention_nestedtensor_cuda

-- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False, bool chunk_grad_outputs=False) -> (Tensor, Tensor, Tensor)
+- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False, bool chunk_grad_outputs=False, *, float? scale=None) -> (Tensor, Tensor, Tensor)
   dispatch:
     CUDA: _scaled_dot_product_efficient_attention_backward_cuda

 - func: _chunk_grad_outputs_efficient_attention(Tensor query, Tensor key, Tensor value, bool is_causal=False) -> bool
   dispatch:
     CUDA: _chunk_grad_outputs_efficient_attention

-- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, bool return_debug_mask) -> (Tensor output, Tensor softmax_logsumexp, int philox_seed, int philox_offset, Tensor debug_attn_mask)
+- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, int philox_seed, int philox_offset, Tensor debug_attn_mask)
   variants: function
   dispatch:
     CUDA: _flash_attention_forward

 # Returns ouput, logsumexp if compute_logsumexp
-- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor)
+- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False, *, float? scale=None) -> (Tensor, Tensor)
   variants: function
   dispatch:
     CUDA: _efficient_attention_forward

-- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False, bool chunk_grad_outputs=False) -> (Tensor, Tensor, Tensor)
+- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False, bool chunk_grad_outputs=False, *, float? scale=None) -> (Tensor, Tensor, Tensor)
   variants: function
   dispatch:
     CUDA: _efficient_attention_backward
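The schema changes above all append a trailing keyword-only `*, float? scale=None` argument. At the Python surface this corresponds to an optional scale keyword on torch.nn.functional.scaled_dot_product_attention; a minimal sketch of how a caller would exercise it, assuming a build in which the Python functional API forwards the new argument (as it does in later releases):

import math
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64)   # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# Default behaviour: scale=None means 1/sqrt(head_dim), as before this change.
out_default = F.scaled_dot_product_attention(q, k, v)

# New keyword-only argument: override the softmax scaling explicitly.
out_custom = F.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(64))

# With the default value spelled out explicitly, the two calls should agree.
torch.testing.assert_close(out_default, out_custom)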

aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp

Lines changed: 8 additions & 4 deletions
@@ -689,7 +689,8 @@ _scaled_dot_product_flash_attention_nestedtensor_cuda(
     const Tensor& value,
     double dropout_p,
     bool is_causal,
-    bool return_debug_mask) {
+    bool return_debug_mask,
+    c10::optional<double> scale) {
   Tensor query_buffer_reshaped, key_buffer_reshaped, value_buffer_reshaped,
       cumulative_sequence_length_q, cumulative_sequence_length_kv, output_shape;
   int64_t max_seqlen_batch_q{0}, max_seqlen_batch_kv{0};
@@ -716,7 +717,8 @@ _scaled_dot_product_flash_attention_nestedtensor_cuda(
       max_seqlen_batch_kv,
       dropout_p,
       is_causal,
-      return_debug_mask);
+      return_debug_mask,
+      scale);
   // Reshape output to convert nnz to batch_size and seq_len
   attention = wrap_buffer(attention.view(-1), output_shape).transpose(1, 2);
   return std::make_tuple(
@@ -737,7 +739,8 @@ _scaled_dot_product_efficient_attention_nestedtensor_cuda(
     const Tensor& key,
     const Tensor& value,
     bool compute_log_sumexp,
-    bool is_causal) {
+    bool is_causal,
+    c10::optional<double> scale) {
   Tensor query_buffer_reshaped, key_buffer_reshaped, value_buffer_reshaped,
       cumulative_sequence_length_q, cumulative_sequence_length_kv, output_shape;
   int64_t max_seqlen_batch_q{0};
@@ -760,7 +763,8 @@ _scaled_dot_product_efficient_attention_nestedtensor_cuda(
       cumulative_sequence_length_kv,
       max_seqlen_batch_q,
       compute_log_sumexp,
-      is_causal);
+      is_causal,
+      scale);
   // Reshape output to convert nnz to batch_size and seq_len
   Tensor attention = std::get<0>(attention_and_logsumexp);
   attention = wrap_buffer(attention.view(-1), output_shape).transpose(1, 2);

aten/src/ATen/native/transformers/attention.cpp

Lines changed: 23 additions & 18 deletions
@@ -663,7 +663,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> native_decoder_only_multi_head_attent
 }

 int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value,
-        const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal){
+        const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, c10::optional<double> scale){
   return static_cast<int64_t>(sdp::SDPBackend::math);
 }

@@ -673,7 +673,8 @@ int64_t _fused_sdp_choice_meta(
     const Tensor& value,
     const c10::optional<Tensor>& attn_mask_,
     double dropout_p,
-    bool is_causal) {
+    bool is_causal,
+    c10::optional<double> scale) {
   auto query_key_set = query_.key_set();
   bool has_cuda = query_key_set.has(c10::DispatchKey::CUDA);
   if (has_cuda) {
@@ -684,7 +685,8 @@ int64_t _fused_sdp_choice_meta(
         value,
         attn_mask_,
         dropout_p,
-        is_causal);
+        is_causal,
+        scale);
     return choice_int;
   }
   return static_cast<int64_t>(sdp::SDPBackend::math);
@@ -703,11 +705,11 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention(
   if (!need_attn_weights) {
     return std::make_tuple(
         at::scaled_dot_product_attention(
-            query_, key, value, attn_mask_, dropout_p, is_causal),
+            query_, key, value, attn_mask_, dropout_p, is_causal, c10::nullopt),
         Tensor());
   }
   return at::_scaled_dot_product_attention_math(
-      query_, key, value, attn_mask_, dropout_p, is_causal);
+      query_, key, value, attn_mask_, dropout_p, is_causal, c10::nullopt);
 }

 inline void validate_sdpa_input(
@@ -716,7 +718,8 @@ inline void validate_sdpa_input(
     const Tensor& value,
     const c10::optional<Tensor>& attn_mask_,
     double dropout_p,
-    bool is_causal) {
+    bool is_causal,
+    c10::optional<double> scale) {
   TORCH_CHECK(
       query_.dtype() == key.dtype() && query_.dtype() == value.dtype(),
       "Expected query, key, and value to have the same dtype, but got query.dtype: ",
@@ -771,26 +774,27 @@ Tensor scaled_dot_product_attention(
     const Tensor& value,
     const c10::optional<Tensor>& attn_mask_,
     double dropout_p,
-    bool is_causal) {
-  validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal);
+    bool is_causal,
+    c10::optional<double> scale) {
+  validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal, scale);
   int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
   if (query_.device().type() == DeviceType::CUDA){
     choice_int = _fused_sdp_choice_stub(query_.device().type(),
-        query_, key, value, attn_mask_, dropout_p, is_causal);
+        query_, key, value, attn_mask_, dropout_p, is_causal, scale);
   }
   sdp::SDPBackend backend = static_cast<sdp::SDPBackend>(choice_int);
   switch (backend) {
     case sdp::SDPBackend::flash_attention: {
       auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
-          query_, key, value, dropout_p, is_causal);
+          query_, key, value, dropout_p, is_causal, false /*return_debug_mask*/, scale);
       return std::get<0>(out_lse_softmax);
     }
     case sdp::SDPBackend::efficient_attention: {
       bool compute_logsumexp =
          (query_.requires_grad() || key.requires_grad() ||
           value.requires_grad());
       auto out_and_lse = at::_scaled_dot_product_efficient_attention(
-          query_, key, value, compute_logsumexp, is_causal);
+          query_, key, value, compute_logsumexp, is_causal, scale);
       return std::get<0>(out_and_lse);
     }
     case sdp::SDPBackend::math:
@@ -800,7 +804,9 @@ Tensor scaled_dot_product_attention(
           value,
           attn_mask_,
           dropout_p,
-          is_causal));
+          is_causal,
+          c10::nullopt, /*dropout_mask*/
+          scale));
     default:
       TORCH_CHECK(
           false,
@@ -812,7 +818,7 @@ Tensor scaled_dot_product_attention(
 std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
     const Tensor& query_, const Tensor& key, const Tensor& value,
     const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal,
-    const c10::optional<Tensor>& dropout_mask) {
+    const c10::optional<Tensor>& dropout_mask, c10::optional<double> scale) {
   C10_LOG_API_USAGE_ONCE("torch.sdpa.math_fallback");
   if (query_.is_nested() || key.is_nested() || value.is_nested()) {
     TORCH_CHECK(
@@ -823,10 +829,9 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
   auto attn_mask = attn_mask_;
   // Naive, composite implementation defined here.

-  // Scale q,k before matmul for stability see https://tinyurl.com/sudb9s96 for math
-  const auto embed_size = SymFloat(query_.sym_size(-1));
-  const auto scaling_factor = embed_size.sqrt().sqrt();
-  const auto query = query_ / scaling_factor;
+  // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
+  const auto scaling_factor = sdp::calculate_scale(query_, scale).sqrt();
+  const auto query = query_ * scaling_factor;
   if (is_causal) {
     TORCH_CHECK(!attn_mask.has_value(),
         "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
@@ -849,7 +854,7 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
     }
     // Otherwise, attn_mask represents an additive attention tensor
   }
-  auto attn = at::matmul(query, key.transpose(-2, -1)/scaling_factor);
+  auto attn = at::matmul(query, key.transpose(-2, -1)*scaling_factor);
   if (attn_mask.has_value()) {
     attn.add_(*attn_mask);
   }
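In the math fallback above, the old code divided query and key each by embed_size ** (1/4), so q @ k.T was effectively divided by sqrt(embed_size); the new code multiplies each by sqrt(sdp::calculate_scale(query_, scale)), so q @ k.T is multiplied by the full scale, which presumably defaults to 1 / sqrt(embed_size) when scale is unset (calculate_scale itself is not shown in this diff). A small numeric sketch of that equivalence under the assumed default:

import math
import torch

E = 64                       # head dim
q = torch.randn(4, E)
k = torch.randn(4, E)
scale = 1.0 / math.sqrt(E)   # assumed default when scale=None

# Old formulation: divide q and k each by E ** 0.25 before the matmul.
old = (q / E ** 0.25) @ (k / E ** 0.25).T

# New formulation: multiply q and k each by sqrt(scale) before the matmul.
s = math.sqrt(scale)
new = (q * s) @ (k * s).T

# Both apply the same overall factor to q @ k.T.
torch.testing.assert_close(old, new)
torch.testing.assert_close(new, (q @ k.T) * scale)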

aten/src/ATen/native/transformers/attention.h

Lines changed: 2 additions & 1 deletion
@@ -3,12 +3,13 @@
 #include <c10/macros/Export.h>
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/transformers/attention.h>
+#include <c10/util/Optional.h>

 namespace at {
 namespace native {

 using fused_sdp_choice_fn = int64_t (*)(const Tensor& query_, const Tensor& key, const Tensor& value,
-        const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal);
+        const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, c10::optional<double> scale);

 DECLARE_DISPATCH(fused_sdp_choice_fn, _fused_sdp_choice_stub);
aten/src/ATen/native/transformers/cuda/attention.cu

Lines changed: 17 additions & 9 deletions
@@ -24,6 +24,7 @@
 #include <ATen/native/nested/NestedTensorTransformerFunctions.h>
 #include <ATen/native/nested/NestedTensorUtils.h>
 #include <ATen/native/transformers/cuda/sdp_utils.h>
+#include <ATen/native/transformers/sdp_utils_cpp.h>

 #ifdef USE_FLASH_ATTENTION
 #include <ATen/native/transformers/cuda/flash_attn/fmha_api.h>
@@ -585,7 +586,7 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
           .transpose(1, 2);

   auto y = at::scaled_dot_product_attention(
-      chunks[0], chunks[1], chunks[2], mask, 0.0, false);
+      chunks[0], chunks[1], chunks[2], mask, 0.0, false, c10::nullopt);

   auto past_sdp = y.transpose(1, 2).reshape({x_size_0, -1, embed_dim});
   return std::make_tuple(
@@ -689,7 +690,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t, int64_t, int64_t, int64_t, T
     const Tensor& value,
     double dropout_p,
     bool is_causal,
-    bool return_debug_mask) {
+    bool return_debug_mask,
+    c10::optional<double> scale) {
   // Used for tracking usage statistics
   C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention");
   // Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
@@ -746,7 +748,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t, int64_t, int64_t, int64_t, T
       max_seqlen_batch_k,
       dropout_p,
       is_causal,
-      return_debug_mask);
+      return_debug_mask,
+      scale);
   // Reshape output to convert nnz to batch_size and seq_len
   attention =
       attention.view({batch_size, max_seqlen_batch_q, num_heads, head_dim}).transpose(1,2);
@@ -759,7 +762,8 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(
     const Tensor& key,
     const Tensor& value,
     bool compute_log_sumexp,
-    bool is_causal) {
+    bool is_causal,
+    c10::optional<double> scale) {
   // Used for tracking usage statistics
   C10_LOG_API_USAGE_ONCE("torch.sdpa.mem_efficient_attention");
   // Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head)
@@ -778,13 +782,14 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(
       c10::nullopt,
       c10::nullopt,
       compute_log_sumexp,
-      is_causal);
+      is_causal,
+      scale);
   attention = attention.transpose(1,2);
   return std::make_tuple(std::move(attention), std::move(log_sumexp));
 }

 int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value,
-        const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal){
+        const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, c10::optional<double> scale){
   sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, is_causal};
   auto backend = select_sdp_backend(kernel_params);
   if (backend == sdp::SDPBackend::error) {
@@ -823,7 +828,8 @@ std::tuple<Tensor, Tensor, int64_t, int64_t, Tensor> _flash_attention_forward(
     const int64_t max_seqlen_batch_k,
     double dropout_p,
     bool is_causal,
-    bool return_debug_mask) {
+    bool return_debug_mask,
+    c10::optional<double> scale) {
 #if defined(USE_FLASH_ATTENTION)
   /*
   num_splits determines how much to parallelize over the seqlen_q dimension
@@ -832,7 +838,7 @@ std::tuple<Tensor, Tensor, int64_t, int64_t, Tensor> _flash_attention_forward(
   benchmarking. We will hard code it to 0 for now
   */
   constexpr int num_splits{0};
-  auto softmax_scale = std::pow(query.size(-1), -0.5);
+  const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
   at::Tensor output = at::empty_like(query);

   Tensor logsumexp, debug_attn_mask;
@@ -877,7 +883,8 @@ std::tuple<at::Tensor, at::Tensor> _efficient_attention_forward(
     // (Mode 1MHK only) Maximum sequence length across batches
     const c10::optional<int64_t> max_seqlen_q_,
     bool compute_logsumexp,
-    bool causal) {
+    bool causal,
+    c10::optional<double> scale) {
 #if defined(USE_FLASH_ATTENTION)
   // TODO In theory it is possible to compile with _CUDA_ARCH < 5.0 and run on a
   // machine that is >= 5.0. In practice, this is not a problem but since
@@ -985,6 +992,7 @@ std::tuple<at::Tensor, at::Tensor> _efficient_attention_forward(
     TORCH_CHECK(B < std::numeric_limits<decltype(A)>::max(), #B " overflows"); \
   }

+  p.scale = sdp::calculate_scale(query, scale).as_float_unchecked();
   p.num_heads = num_heads;
   p.head_dim = query.size(3);
   p.head_dim_value = value.size(3);
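sdp::calculate_scale is referenced throughout these hunks but is declared in ATen/native/transformers/sdp_utils_cpp.h, which is not part of this excerpt. Presumably it returns the caller-supplied scale when one is given and 1 / sqrt(query.size(-1)) otherwise; a rough Python analog of that assumed behavior:

import math

def calculate_scale(head_dim, scale=None):
    # Python analog of what sdp::calculate_scale presumably computes
    # (the C++ helper is not shown in this diff): use the caller-provided
    # scale if given, otherwise fall back to 1/sqrt(head_dim).
    return scale if scale is not None else 1.0 / math.sqrt(head_dim)

softmax_scale = calculate_scale(64)             # default: 1/sqrt(64) = 0.125
softmax_scale_custom = calculate_scale(64, 0.2)  # explicit override wins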
