🐛 Describe the bug
I'm getting an out-of-memory error when trying to use an exported .so file. The .so is 16 GB, and my GPU has 24 GB of memory.
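The .so size alone may not be the full GPU footprint. If the KV cache is also allocated on the device alongside the weights, its size grows with the maximum sequence length fixed at export time. The sketch below estimates weights plus KV cache. It assumes that `llama3.1` here is Llama 3.1 8B in bf16 with its public config (32 layers, 8 KV heads, head dimension 128), and that the .so is almost entirely weights. The export-time sequence length is not stated in this report, so several values are tried.

```python
# Rough GPU-memory budget for a Llama-style decoder: weights + KV cache.
# ASSUMPTION: Llama 3.1 8B config (32 layers, 8 KV heads, head_dim 128), bf16.
# ASSUMPTION: the ~16 GB .so is dominated by the model weights.

GIB = 1024 ** 3


def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int,
                   max_seq_len: int, batch_size: int = 1,
                   bytes_per_elem: int = 2) -> int:
    """Bytes needed for the key and value caches across all layers."""
    # Leading factor of 2 accounts for the separate K and V tensors.
    return (2 * n_layers * n_kv_heads * head_dim
            * max_seq_len * batch_size * bytes_per_elem)


def estimate_gib(weight_bytes: int, max_seq_len: int) -> float:
    """Weights plus KV cache, in GiB, for the assumed Llama 3.1 8B shape."""
    cache = kv_cache_bytes(n_layers=32, n_kv_heads=8, head_dim=128,
                           max_seq_len=max_seq_len)
    return (weight_bytes + cache) / GIB


if __name__ == "__main__":
    weights = 16 * 10**9  # the reported .so size, treated as weights only
    for seq_len in (2048, 8192, 131072):
        print(f"max_seq_len={seq_len:>6}: ~{estimate_gib(weights, seq_len):.1f} GiB")
```

Under these assumptions, a 131072-token cache would be 16 GiB by itself and would push the total to roughly 31 GiB, well past the RTX 4090's 24 GB, whereas an 8192-token cache would fit. This is a hypothesis worth checking against the export settings, not a confirmed diagnosis.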
(.venv) warden@Vikander:~/source/torchchat$ python3 torchchat.py generate llama3.1 --dso-path exportedModels/llama3.1.so --prompt "Hello my name is" --device cuda
Warning: checkpoint path ignored because an exported DSO or PTE path specified
Warning: checkpoint path ignored because an exported DSO or PTE path specified
Using device=cuda NVIDIA GeForce RTX 4090
Loading model...
Time to load model: 1.85 seconds
Error: CUDA error: out of memory
Traceback (most recent call last):
File "/home/warden/source/torchchat/torchchat/cli/builder.py", line 536, in _initialize_model
model.forward = torch._export.aot_load(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/warden/source/torchchat/.venv/lib/python3.11/site-packages/torch/_export/__init__.py", line 320, in aot_load
runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: create_func_( &container_handle_, num_models, device_str.c_str(), cubin_dir.empty() ? nullptr : cubin_dir.c_str()) API call failed at ../torch/csrc/inductor/aoti_runner/model_container_runner.cpp, line 85
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/warden/source/torchchat/torchchat.py", line 88, in <module>
generate_main(args)
File "/home/warden/source/torchchat/torchchat/generate.py", line 1215, in main
gen = Generator(
^^^^^^^^^^
File "/home/warden/source/torchchat/torchchat/generate.py", line 290, in __init__
self.model = _initialize_model(self.builder_args, self.quantize, self.tokenizer)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/warden/source/torchchat/torchchat/cli/builder.py", line 540, in _initialize_model
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
RuntimeError: Failed to load AOTI compiled exportedModels/llama3.1.so
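To separate torchchat's builder from the AOTI runtime, a minimal standalone repro can make the same `torch._export.aot_load` call that `builder.py` makes (line 536 in the traceback) and print free GPU memory from `torch.cuda.mem_get_info()` right before loading. This is a sketch. `aot_load` is a private API that may change between nightlies, and the script's argument handling and the `budget_line` helper are illustrative, not part of torchchat.

```python
"""Hypothetical standalone repro: load the AOTI .so without torchchat.

Usage (assumed): python repro_aot_load.py exportedModels/llama3.1.so
ASSUMPTION: a CUDA-enabled torch build matching the one used for export.
"""
import os
import sys

GIB = 1024 ** 3


def budget_line(so_bytes: int, free_bytes: int, total_bytes: int) -> str:
    """Summarize the .so size against the GPU's free/total memory, in GiB."""
    return (f".so: {so_bytes / GIB:.2f} GiB | GPU free: {free_bytes / GIB:.2f} GiB"
            f" of {total_bytes / GIB:.2f} GiB")


def main(so_path: str) -> None:
    # Imported lazily so budget_line() stays usable without torch installed.
    import torch

    so_bytes = os.path.getsize(so_path)
    # mem_get_info() returns (free, total) in bytes for the current CUDA device;
    # memory held by other processes shows up as reduced "free".
    free, total = torch.cuda.mem_get_info()
    print("before load: " + budget_line(so_bytes, free, total))

    # Same private API torchchat's builder.py calls; it raises RuntimeError
    # when the AOTI model container cannot be created.
    torch._export.aot_load(so_path, device="cuda")

    free, total = torch.cuda.mem_get_info()
    print("after load:  " + budget_line(so_bytes, free, total))


if __name__ == "__main__" and len(sys.argv) > 1:
    main(sys.argv[1])
```

If the "before load" line already shows noticeably less than 24 GB free, another process is holding GPU memory; if the load still fails with ample free memory, the footprint of the loaded model is larger than the .so size suggests.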
Versions
Operating System Information
Linux Vikander 6.8.0-45-generic #45-Ubuntu SMP PREEMPT_DYNAMIC Fri Aug 30 12:02:04 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux
PRETTY_NAME="Ubuntu 24.04.1 LTS"
NAME="Ubuntu"
VERSION_ID="24.04"
VERSION="24.04.1 LTS (Noble Numbat)"
VERSION_CODENAME=noble
ID=ubuntu
ID_LIKE=debian
HOME_URL="https://www.ubuntu.com/"
SUPPORT_URL="https://help.ubuntu.com/"
BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
UBUNTU_CODENAME=noble
LOGO=ubuntu-logo
Python Version
Python 3.11.10
PIP Version
pip 24.0 from /home/warden/source/torchchat/.venv/lib/python3.11/site-packages/pip (python 3.11)
Installed Packages
absl-py==2.1.0
accelerate==1.0.1
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aiosignal==1.3.1
altair==5.4.1
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.6.2.post1
attrs==24.2.0
blinker==1.8.2
blobfile==3.0.0
cachetools==5.5.0
certifi==2024.8.30
chardet==5.2.0
charset-normalizer==3.4.0
click==8.1.7
cmake==3.30.4
colorama==0.4.6
DataProperty==1.0.1
datasets==3.0.1
dill==0.3.8
distro==1.9.0
evaluate==0.4.3
filelock==3.16.1
Flask==3.0.3
frozenlist==1.4.1
fsspec==2024.6.1
gguf==0.10.0
gitdb==4.0.11
GitPython==3.1.43
h11==0.14.0
httpcore==1.0.6
httpx==0.27.2
huggingface-hub==0.25.2
idna==3.10
itsdangerous==2.2.0
Jinja2==3.1.4
jiter==0.6.1
joblib==1.4.2
jsonlines==4.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
lm_eval==0.4.2
lxml==5.3.0
markdown-it-py==3.0.0
MarkupSafe==3.0.1
mbstrdecoder==1.1.3
mdurl==0.1.2
more-itertools==10.5.0
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
narwhals==1.9.3
networkx==3.4.1
ninja==1.11.1.1
nltk==3.9.1
numexpr==2.10.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.6.77
nvidia-nvtx-cu12==12.1.105
omegaconf==2.3.0
openai==1.51.2
packaging==24.1
pandas==2.2.3
pathvalidate==3.2.1
peft==0.13.2
pillow==10.4.0
portalocker==2.10.1
propcache==0.2.0
protobuf==5.28.2
psutil==6.0.0
pyarrow==17.0.0
pybind11==2.13.6
pycryptodomex==3.21.0
pydantic==2.9.2
pydantic_core==2.23.4
pydeck==0.9.1
Pygments==2.18.0
pytablewriter==1.2.0
python-dateutil==2.9.0.post0
pytorch-triton==3.1.0+cf34004b8a
pytz==2024.2
PyYAML==6.0.2
referencing==0.35.1
regex==2024.9.11
requests==2.32.3
rich==13.9.2
rouge-score==0.1.2
rpds-py==0.20.0
sacrebleu==2.4.3
safetensors==0.4.5
scikit-learn==1.5.2
scipy==1.14.1
sentencepiece==0.2.0
six==1.16.0
smmap==5.0.1
snakeviz==2.2.0
sniffio==1.3.1
sqlitedict==2.1.0
streamlit==1.39.0
sympy==1.13.1
tabledata==1.3.3
tabulate==0.9.0
tcolorpy==0.1.6
tenacity==9.0.0
threadpoolctl==3.5.0
tiktoken==0.8.0
tokenizers==0.20.1
toml==0.10.2
torch==2.6.0.dev20241002+cu121
torchao==0.5.0
torchtune==0.3.0.dev20240928+cu121
torchvision==0.20.0.dev20241002+cu121
tornado==6.4.1
tqdm==4.66.5
tqdm-multiprocess==0.0.11
transformers==4.45.2
typepy==1.3.2
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.2.3
watchdog==5.0.3
Werkzeug==3.0.4
word2number==1.1
xxhash==3.5.0
yarl==1.15.2
zstandard==0.23.0
zstd==1.5.5.1
PyTorch Version
2.6.0.dev20241002+cu121