A HUGE performance degradation in SDPA (scaled_dot_product_attention) in torch_xla 2.6.0 on TPU #8746

@giladxw

Description

🐛 Bug

When using torch_xla 2.6.0 with torch.nn.functional.scaled_dot_product_attention on TPU v5e, it is VERY slow for some reason (comparison with previous versions below).

To Reproduce

I used this simple script:

import math
import os

# XLA flags should be set before torch_xla is imported so they reliably take effect.
os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"

import statistics
import timeit
from functools import partial, wraps

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast


def wrap_with_mark_step(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        res = func(*args, **kwargs)
        xm.mark_step()
        return res
    return wrapper


@wrap_with_mark_step
def sdpa(query, key, value):
    return torch.nn.functional.scaled_dot_product_attention(query, key, value)


@wrap_with_mark_step
def standard_attention(query, key, value):
    # reference implementation from scaled_dot_product_attention's docs (it's pretty straightforward)
    scale_factor = 1 / math.sqrt(query.size(-1))
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return attn_weight @ value


def time_attention(attention_fn, query, key, value):
    # Warm-up
    for _ in range(100):
        attention_fn(query, key, value)

    # Run timings with repeats
    repeats = 10
    number = 100
    times = timeit.repeat(partial(attention_fn, query, key, value), repeat=repeats, number=number)

    mean_time = statistics.mean(times) / number * 1000  # ms per call
    std_dev = statistics.stdev(times) / number * 1000  # ms per call

    print(f"    {attention_fn.__name__} average time: {mean_time:.3f} ms per call, dev: {std_dev:.3f} ms per call")


if __name__ == '__main__':
    batch_size = 32
    seq_len = 256
    num_heads = 16
    head_dim = 64
    device = xm.xla_device()

    # Generate input tensors with correct shape: (batch_size, num_heads, seq_len, head_dim)
    query = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
    key = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
    value = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)

    assert torch.allclose(sdpa(query, key, value), standard_attention(query, key, value),
                          rtol=1e-02, atol=1e-02)
    print("Both attention functions produce about the same output, test time performance now")

    print("No mixed precision:")
    time_attention(sdpa, query, key, value)
    time_attention(standard_attention, query, key, value)

    print("With mixed precision:")
    with autocast(xm.xla_device(), enabled=True, dtype=torch.bfloat16):
        time_attention(sdpa, query, key, value)
        time_attention(standard_attention, query, key, value)

    print(f"torch version: {torch.__version__}, torch_xla version: {torch_xla.__version__}")

Steps to reproduce the behavior:

  1. I tested it on GKE (with my code mounted from a ConfigMap); this is the YAML I used:
apiVersion: batch/v1
kind: Job
metadata:
  name: test-sdpa-260
spec:
  backoffLimit: 0
  completionMode: Indexed
  completions: 2
  parallelism: 2
  template:
    spec:
      containers:
      - image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_libtpu_3.10_tpuvm
        name: test
        command:
          - bash
        args:
          - -c
          - python -m my_code.test_sdpa
        env:
          - name: PYTHONUNBUFFERED
            value: "1"
          - name: PJRT_DEVICE
            value: "TPU"
        ports:
          - containerPort: 12355
          - containerPort: 8080
          - containerPort: 8431
          - containerPort: 8471
          - containerPort: 8476
          - containerPort: 8477
          - containerPort: 8478
          - containerPort: 8479
        resources:
          limits:
            google.com/tpu: '4'
          requests:
            google.com/tpu: '4'
        securityContext:
          privileged: true
        volumeMounts:
          - name: configmap-volume
            mountPath: /my_code
          - mountPath: /dev/shm
            name: shm
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
        cloud.google.com/gke-tpu-topology: 2x4
      restartPolicy: Never
      subdomain: headless-svc
      tolerations:
        - effect: NoSchedule
          key: google.com/tpu
          operator: Exists
      volumes:
        - emptyDir:
            medium: Memory
          name: shm
        - name: configmap-volume
          configMap:
            name: my-configmap
  2. Then I tested it with the other official Docker images (us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_libtpu_3.10_tpuvm and us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_libtpu_3.10_tpuvm).
  3. The outputs I got are:
Both attention functions produce about the same output, test time performance now
No mixed precision:
    sdpa average time: 0.714 ms per call, dev: 0.000 ms per call
    standard_attention average time: 0.726 ms per call, dev: 0.001 ms per call
With mixed precision:
    sdpa average time: 0.251 ms per call, dev: 0.000 ms per call
    standard_attention average time: 0.251 ms per call, dev: 0.000 ms per call
torch version: 2.4.0+libtpu, torch_xla version: 2.4.0+libtpu

Both attention functions produce about the same output, test time performance now
No mixed precision:
    sdpa average time: 0.722 ms per call, dev: 0.001 ms per call
    standard_attention average time: 0.726 ms per call, dev: 0.000 ms per call
With mixed precision:
    sdpa average time: 0.253 ms per call, dev: 0.003 ms per call
    standard_attention average time: 0.239 ms per call, dev: 0.009 ms per call
torch version: 2.5.1+libtpu, torch_xla version: 2.5.1+libtpu

Both attention functions produce about the same output, test time performance now
No mixed precision:
    sdpa average time: 116.679 ms per call, dev: 0.447 ms per call
    standard_attention average time: 0.726 ms per call, dev: 0.000 ms per call
With mixed precision:
    sdpa average time: 106.377 ms per call, dev: 0.511 ms per call
    standard_attention average time: 0.347 ms per call, dev: 0.003 ms per call
torch version: 2.6.0+libtpu, torch_xla version: 2.6.0+libtpu

More than 100x slower in version 2.6.0 compared with previous versions, and with the straightforward implementation (which itself also regressed slightly under mixed precision)!
This is very disturbing, considering attention is a core operator in most modern architectures.
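
As a temporary workaround until this is fixed, calls can be routed through the manual implementation from the repro script above. A minimal sketch (my own suggestion, not an official fix; it ignores attn_mask, dropout_p, and is_causal, so it only covers plain attention calls like the ones above):

import math
import torch
import torch.nn.functional as F

def _manual_sdpa(query, key, value, *args, **kwargs):
    # Same math as standard_attention above; mask/dropout/causal arguments
    # are ignored, so this only covers plain attention calls.
    scale_factor = 1 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1)
    return attn_weight @ value

# Monkey-patch as a stopgap for the 2.6.0 regression.
F.scaled_dot_product_attention = _manual_sdpa

Alternatively, torch_xla's Pallas flash-attention kernel (torch_xla.experimental.custom_kernel.flash_attention) bypasses the stock SDPA lowering entirely, though I have not benchmarked it here.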

Environment

  • Reproducible on XLA backend (CPU/TPU/CUDA): TPU v5e
  • torch_xla version: 2.4.0, 2.5.1, 2.6.0
