Operator performance plummets on PyTorch 2.2.1 Windows platform #120788

@luow1028

Description

🐛 Describe the bug

I implemented a naive attention block with matmul and softmax, then compared its performance against torch.nn.functional.scaled_dot_product_attention. The results are very strange: on torch 2.2.1 the naive implementation is roughly 28× slower than on torch 2.1.2, while scaled_dot_product_attention gets slightly faster.

torch 2.1.2:
repeat:100 times
naive attn total time:9.351994514465332s, avg time: 0.09351994514465332s
scaled_dot_product_attention total time:10.712503671646118s, avg time: 0.10712503671646117s

torch 2.2.1:
repeat:100 times
naive attn total time:262.39105248451233s, avg time: 2.6239105248451233s
scaled_dot_product_attention fa total time:7.838253498077393s, avg time: 0.07838253498077392s

import torch
import torch.nn.functional as F
import math

def naive_sdp2(Q, K, V, mask=None):
    attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
    if mask is not None:
        attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
    attn_probs = torch.softmax(attn_scores, dim=-1)
    output = torch.matmul(attn_probs, V)
    return output

query = torch.rand(1, 32, 1024, 64, dtype=torch.bfloat16)
key = torch.rand(1, 32, 1024, 64, dtype=torch.bfloat16)
value = torch.rand(1, 32, 1024, 64, dtype=torch.bfloat16)

Q = query.view(query.size(0) * query.size(1), query.size(2), query.size(3))
K = key.view(key.size(0) * key.size(1), key.size(2), key.size(3))
V = value.view(value.size(0) * value.size(1), value.size(2), value.size(3))

import time
repeat_time = 100
total_time = 0
for i in range(repeat_time):
    start_time = time.time()
    naive_sdp2(Q, K, V)
    total_time += time.time() - start_time

print(f"naive attn total time:{total_time}s, avg time: {total_time / repeat_time}s")

total_time = 0
for i in range(repeat_time):
    start_time = time.time()
    F.scaled_dot_product_attention(query, key, value)
    total_time += time.time() - start_time

print(f"scaled_dot_product_attention fa total time:{total_time}s, avg time: {total_time / repeat_time}s")
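As an aside, wall-clock loops like the one above can be noisy on CPU (lazy initialization, thread-pool spin-up, frequency scaling), which can exaggerate or hide a regression. A minimal helper, sketched here using only the standard library (the `bench` name and its defaults are my own, not from the repro), adds a warmup phase and reports the median rather than the mean:

```python
import time
import statistics

def bench(fn, warmup=10, repeat=100):
    """Time fn() and return (median, total) seconds over the measured runs.

    The warmup runs absorb one-time costs (lazy init, thread-pool
    spin-up) so they do not skew the measured iterations; the median
    is less sensitive to outlier iterations than the mean.
    """
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeat):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples), sum(samples)
```

With the names from the script above, usage would look like `median, total = bench(lambda: naive_sdp2(Q, K, V))`; `torch.utils.benchmark.Timer` offers similar warmup/statistics handling out of the box.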

Versions

Collecting environment information...
PyTorch version: 2.1.2+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Pro
GCC version: Could not collect
Clang version: 17.0.6
CMake version: version 3.28.3
Libc version: N/A

Python version: 3.9.18 | packaged by conda-forge | (main, Dec 23 2023, 16:29:04) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.22621-SP0
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture=9
CurrentClockSpeed=4001
DeviceID=CPU0
Family=107
L2CacheSize=8192
L2CacheSpeed=
Manufacturer=AuthenticAMD
MaxClockSpeed=4001
Name=AMD Ryzen 9 7940HS w/ Radeon 780M Graphics
ProcessorType=3
Revision=29697

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] onnx==1.15.0
[pip3] onnxruntime==1.17.0
[pip3] ryzenai_torch_cpp==0.0.0
[pip3] torch==2.1.2
[pip3] torchaudio==2.1.2
[pip3] torchmetrics==1.3.1
[pip3] torchvision==0.16.2
[conda] libblas 3.9.0 21_win64_mkl conda-forge
[conda] libcblas 3.9.0 21_win64_mkl conda-forge
[conda] liblapack 3.9.0 21_win64_mkl conda-forge
[conda] mkl 2024.0.0 h66d3029_49657 conda-forge
[conda] numpy 1.26.4 py39hddb5d58_0 conda-forge
[conda] ryzenai-torch-cpp 0.0.0 pypi_0 pypi
[conda] torch 2.1.2 pypi_0 pypi
[conda] torchaudio 2.1.2 pypi_0 pypi
[conda] torchmetrics 1.3.1 pypi_0 pypi
[conda] torchvision 0.16.2 pypi_0 pypi

cc @ezyang @gchanan @zou3519 @kadeng @peterjc123 @mszhanyi @skyline75489 @nbcsm @vladimir-aubrecht @iremyux @Blackhex @cristianPanaite @frank-wei @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

Labels

high priority
module: intel (Specific to x86 architecture)
module: performance (Issues related to performance, either of kernel code or framework glue)
module: windows (Windows support for PyTorch)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Status: Done