🐛 Bug
When using torch_xla version 2.6.0 on TPU v5e, the function torch.nn.functional.scaled_dot_product_attention is VERY slow for some reason (comparison against previous versions below).
To Reproduce
I used this simple script:
import math
from functools import partial, wraps
import os

# XLA_DISABLE_FUNCTIONALIZATION is read when torch_xla initializes,
# so it must be set before torch_xla is imported to take effect.
os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"

import torch
import timeit
import statistics
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast
def wrap_with_mark_step(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        res = func(*args, **kwargs)
        xm.mark_step()
        return res
    return wrapper

@wrap_with_mark_step
def sdpa(query, key, value):
    return torch.nn.functional.scaled_dot_product_attention(query, key, value)

@wrap_with_mark_step
def standard_attention(query, key, value):
    # the reference implementation from the SDPA docs (it's pretty straightforward)
    scale_factor = 1 / math.sqrt(query.size(-1))
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return attn_weight @ value

def time_attention(attention_fn, query, key, value):
    # Warm-up so compilation is excluded from the timings
    for _ in range(100):
        attention_fn(query, key, value)
    # Run timings with repeats
    repeats = 10
    number = 100
    times = timeit.repeat(partial(attention_fn, query, key, value), repeat=repeats, number=number)
    mean_time = statistics.mean(times) / number * 1000  # ms per call
    std_dev = statistics.stdev(times) / number * 1000  # ms per call
    print(f" {attention_fn.__name__} average time: {mean_time:.3f} ms per call, dev: {std_dev:.3f} ms per call")
if __name__ == '__main__':
    batch_size = 32
    seq_len = 256
    num_heads = 16
    head_dim = 64
    device = xm.xla_device()

    # Generate input tensors with the correct shape: (batch_size, num_heads, seq_len, head_dim)
    query = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
    key = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
    value = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)

    assert torch.allclose(sdpa(query, key, value), standard_attention(query, key, value),
                          rtol=1e-02, atol=1e-02)
    print("Both attention functions produce about the same output, test time performance now")

    print("No mixed precision:")
    time_attention(sdpa, query, key, value)
    time_attention(standard_attention, query, key, value)

    print("With mixed precision:")
    with autocast(device, enabled=True, dtype=torch.bfloat16):
        time_attention(sdpa, query, key, value)
        time_attention(standard_attention, query, key, value)

    print(f"torch version: {torch.__version__}, torch_xla version: {torch_xla.__version__}")
Steps to reproduce the behavior:
- I tested it on GKE (with my code mounted from a ConfigMap); this is the YAML I used:
apiVersion: batch/v1
kind: Job
metadata:
  name: test-sdpa-260
spec:
  backoffLimit: 0
  completionMode: Indexed
  completions: 2
  parallelism: 2
  template:
    spec:
      containers:
      - image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_libtpu_3.10_tpuvm
        name: test
        command:
        - bash
        args:
        - -c
        - python -m my_code.test_sdpa
        env:
        - name: PYTHONUNBUFFERED
          value: "1"
        - name: PJRT_DEVICE
          value: "TPU"
        ports:
        - containerPort: 12355
        - containerPort: 8080
        - containerPort: 8431
        - containerPort: 8471
        - containerPort: 8476
        - containerPort: 8477
        - containerPort: 8478
        - containerPort: 8479
        resources:
          limits:
            google.com/tpu: '4'
          requests:
            google.com/tpu: '4'
        securityContext:
          privileged: true
        volumeMounts:
        - name: configmap-volume
          mountPath: /my_code
        - mountPath: /dev/shm
          name: shm
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
        cloud.google.com/gke-tpu-topology: 2x4
      restartPolicy: Never
      subdomain: headless-svc
      tolerations:
      - effect: NoSchedule
        key: google.com/tpu
        operator: Exists
      volumes:
      - emptyDir:
          medium: Memory
        name: shm
      - name: configmap-volume
        configMap:
          name: my-configmap
- Then I tested the same script with the other official Docker images (us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_libtpu_3.10_tpuvm and us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_libtpu_3.10_tpuvm).
- The outputs I got are:
Both attention functions produce about the same output, test time performance now
No mixed precision:
 sdpa average time: 0.714 ms per call, dev: 0.000 ms per call
 standard_attention average time: 0.726 ms per call, dev: 0.001 ms per call
With mixed precision:
 sdpa average time: 0.251 ms per call, dev: 0.000 ms per call
 standard_attention average time: 0.251 ms per call, dev: 0.000 ms per call
torch version: 2.4.0+libtpu, torch_xla version: 2.4.0+libtpu

Both attention functions produce about the same output, test time performance now
No mixed precision:
 sdpa average time: 0.722 ms per call, dev: 0.001 ms per call
 standard_attention average time: 0.726 ms per call, dev: 0.000 ms per call
With mixed precision:
 sdpa average time: 0.253 ms per call, dev: 0.003 ms per call
 standard_attention average time: 0.239 ms per call, dev: 0.009 ms per call
torch version: 2.5.1+libtpu, torch_xla version: 2.5.1+libtpu

Both attention functions produce about the same output, test time performance now
No mixed precision:
 sdpa average time: 116.679 ms per call, dev: 0.447 ms per call
 standard_attention average time: 0.726 ms per call, dev: 0.000 ms per call
With mixed precision:
 sdpa average time: 106.377 ms per call, dev: 0.511 ms per call
 standard_attention average time: 0.347 ms per call, dev: 0.003 ms per call
torch version: 2.6.0+libtpu, torch_xla version: 2.6.0+libtpu
SDPA is more than 100x slower in version 2.6.0 than in previous versions (116.7 ms vs ~0.72 ms, roughly 160x), and far slower than the straightforward implementation, which itself also regressed under mixed precision (0.347 ms vs ~0.24 ms in earlier releases)!
This is very concerning, considering attention is the fundamental operator of most modern architectures.
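In the meantime, the timings above suggest the manual formulation is a workable stopgap on 2.6.0. A sketch of a drop-in replacement (hypothetical helper name; no attn_mask, dropout, or is_causal handling):

import math
import torch

def sdpa_stopgap(query, key, value):
    # Same math as F.scaled_dot_product_attention in the default case;
    # per the timings above it still compiles to a fast XLA graph on 2.6.0.
    scale_factor = 1 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1)
    return attn_weight @ value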
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: TPU v5e
- torch_xla version: 2.4.0, 2.5.1, 2.6.0 (the regression appears only in 2.6.0)