Skip to content

Commit fca9d7b

Browse files
committed
upgrade to fexp_u20 since torch has been updated to 2.9
1 parent e83283e commit fca9d7b

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

sgl-kernel/csrc/cpu/extend.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -238,7 +238,7 @@ void extend_attention_kernel_impl(
238238

239239
// s_delta <- exp(s_i - m_i)
240240
at::vec::map<float>(
241-
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
241+
[m_i](Vec x) { return (x - Vec(m_i)).fexp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
242242

243243
// s' <- s' * m_delta + sum(s_delta)
244244
s_prime[row] *= m_delta;
@@ -349,7 +349,7 @@ void extend_attention_kernel_impl(
349349

350350
// s_delta <- exp(s_i - m_i)
351351
at::vec::map<float>(
352-
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
352+
[m_i](Vec x) { return (x - Vec(m_i)).fexp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
353353

354354
// s' <- s' * m_delta + sum(s_delta)
355355
s_prime[row] *= m_delta;

0 commit comments

Comments
 (0)