PyTorch's packaged libgomp causes significant performance penalties on CPU when used together with other Python packages #98836

@mergian

Description

🐛 Describe the bug

PyTorch's PyPI packages ship their own libgomp-SOMEHASH.so. Other packages, like scikit-learn, do the same. The problem is that, depending on the order in which the Python modules are loaded, one of the bundled OpenMP runtimes (e.g., PyTorch's) may be initialized with only a single thread.

This can easily be seen by running the following (all unrelated output removed):

# python3 -m threadpoolctl -i torch sklearn
[
  {
    "user_api": "openmp",
    "internal_api": "openmp",
    "prefix": "libgomp",
    "filepath": "/.../python3.8/site-packages/torch/lib/libgomp-a34b3233.so.1",
    "version": null,
    "num_threads": 12 # PyTorch 12 Threads
  },
  {
    "user_api": "openmp",
    "internal_api": "openmp",
    "prefix": "libgomp",
    "filepath": "/.../python3.8/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0",
    "version": null,
    "num_threads": 1 # SKlearn 1 Thread
  }
]

and:

# python3 -m threadpoolctl -i sklearn torch
[
  {
    "user_api": "openmp",
    "internal_api": "openmp",
    "prefix": "libgomp",
    "filepath": "/.../python3.8/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0",
    "version": null,
    "num_threads": 24 # SKlearn 24 Threads
  },
  {
    "user_api": "openmp",
    "internal_api": "openmp",
    "prefix": "libgomp",
    "filepath": "/.../python3.8/site-packages/torch/lib/libgomp-a34b3233.so.1",
    "version": null,
    "num_threads": 1 # PyTorch 1 Thread
  }
]

In the first case, PyTorch gets all threads; in the second case, scikit-learn gets all threads.
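The same check can be scripted. The sketch below is a small helper of my own (the function name is not an existing API) that flags OpenMP runtimes pinned to a single thread; the sample data mirrors the output above, and in a live session you would feed it the result of threadpoolctl.threadpool_info() instead:

```python
# Sketch: flag OpenMP runtimes that were initialized with only one
# thread, given threadpoolctl-style records. In a real session, replace
# the sample data with: from threadpoolctl import threadpool_info;
# records = threadpool_info()

def starved_openmp_runtimes(records):
    """Return filepaths of OpenMP runtimes pinned to a single thread."""
    return [
        r["filepath"]
        for r in records
        if r.get("user_api") == "openmp" and r.get("num_threads") == 1
    ]

# Illustrative records mirroring the threadpoolctl output above
sample = [
    {"user_api": "openmp",
     "filepath": ".../torch/lib/libgomp-a34b3233.so.1",
     "num_threads": 12},
    {"user_api": "openmp",
     "filepath": ".../scikit_learn.libs/libgomp-a34b3233.so.1.0.0",
     "num_threads": 1},
]

print(starved_openmp_runtimes(sample))
```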

This minimal example shows the effect on the performance:

import sklearn # remove or swap with 2nd line
import torch
import torchvision
from time import perf_counter_ns as timer

model = torchvision.models.resnet50()
model.eval()

data = torch.rand(64, 3, 224, 224)

start = timer()
with torch.no_grad():
    for i in range(5):
        model(data)
end = timer()
print(f'Total: {(end-start)/1000000.0}ms')

Result without import sklearn, or with the two import lines swapped: Total: 5020.870435ms
And with import sklearn first: Total: 27399.992653ms

Even if we set the number of threads correctly by hand, there would still be a performance penalty when switching between PyTorch and scikit-learn, as the thread pools need to be swapped.
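For completeness, setting the thread count by hand can be done by exporting OMP_NUM_THREADS before the first OpenMP-linked import (OpenMP runtimes read it when they initialize), or via torch.set_num_threads at runtime. A minimal sketch of the environment-variable route, with the caveat that it only restores the thread count and does not remove the pool-switching overhead:

```python
import os

# This must run before importing torch, sklearn, or anything else
# linked against libgomp: each OpenMP runtime reads OMP_NUM_THREADS
# when it initializes, so both bundled copies at least get a sane
# thread count instead of one of them falling back to a single thread.
os.environ.setdefault("OMP_NUM_THREADS", str(os.cpu_count() or 1))

print(os.environ["OMP_NUM_THREADS"])
```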

My current workaround is to remove all libgomp-*.so files within my Python user site and replace them with symlinks to the system's libgomp.so. This makes scikit-learn and PyTorch use the same thread pool, which in my opinion is the desired behavior. Another solution would be to compile PyTorch from source.
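The workaround can be scripted; this is a sketch of what I do, where relink_libgomp is just a helper name of mine, not an existing tool:

```python
from pathlib import Path

def relink_libgomp(site_packages, system_libgomp):
    """Replace each bundled libgomp-*.so* below site_packages with a
    symlink to the system libgomp, so every package shares one OpenMP
    runtime. Returns the paths that were relinked."""
    relinked = []
    # Materialize the listing first, since we modify files as we go
    for so in list(Path(site_packages).rglob("libgomp-*.so*")):
        if so.is_symlink():
            continue  # already relinked, leave it alone
        so.unlink()
        so.symlink_to(system_libgomp)
        relinked.append(str(so))
    return relinked
```

Called for example as relink_libgomp('/home/me/.local/lib/python3.8/site-packages', '/usr/lib64/libgomp.so.1'), where both paths are examples for my setup. Note that pip restores the bundled copies whenever the packages are reinstalled or upgraded, so this has to be repeated.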

I'm not sure why PyTorch ships its own libgomp. I'm guessing it's for compatibility with older systems that don't have libgomp, or that have an outdated/incompatible version. However, the current approach has significant downsides when using PyTorch alongside other packages or user applications that are linked against the system's libgomp. So far I have identified onnxruntime-openmp and scikit-learn doing the same, but I assume there are many more.

I came up with multiple solutions:

  1. A hacky solution would be to ensure that all packages use an identical libgomp-SOMEHASH.so.SO_VERSION; e.g., scikit-learn and onnxruntime use libgomp-a34b3233.so.1.0.0 while PyTorch uses libgomp-a34b3233.so.1. This works because libdl only checks the file name. But it does not solve the fundamental problem of shipping your own libgomp, and the problem would remain whenever the user includes their own libraries linked against the system libgomp.

  2. A proper solution would be something like the intel-openmp package, which provides a centralized way of accessing the libraries and can easily be adopted by multiple Python packages without conflicts. Here, PyTorch, scikit-learn, etc. could share this package as a common requirement and all load the same library.

As this is a cross-project issue, I'm not sure what the best way is to coordinate with the other projects.

This issue is related to: #44282, #19764

Versions

Collecting environment information...
PyTorch version: 2.0.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: CentOS Linux 7 (Core) (x86_64)
GCC version: (GCC) 10.3.0
Clang version: Could not collect
CMake version: version 3.26.3
Libc version: glibc-2.17

Python version: 3.8.16 (default, Mar 17 2023, 07:42:34)  [GCC 10.2.1 20210130 (Red Hat 10.2.1-11)] (64-bit runtime)
Python platform: Linux-3.10.0-1160.76.1.el7.x86_64-x86_64-with-glibc2.17
Is CUDA available: False
CUDA runtime version: 11.4.120
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] pytorch-lightning==2.0.1
[pip3] torch==2.0.0
[pip3] torchmetrics==0.11.4
[pip3] torchvision==0.15.1
[conda] Could not collect

cc @malfet @seemethere

Metadata

    Labels

    module: build (Build system issues), module: multithreading (Related to issues that occur when running on multiple CPU threads), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Projects

    Status

    In Progress
