-
-
Notifications
You must be signed in to change notification settings - Fork 6k
Expand file tree
/
Copy pathdevice_type.py
More file actions
149 lines (130 loc) · 5.26 KB
/
Copy pathdevice_type.py
File metadata and controls
149 lines (130 loc) · 5.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Names exported by a star-import of this module: the device-detection
# helpers plus the module-level constants computed from them at import time.
__all__ = [
    "is_hip",
    "get_device_type",
    "DEVICE_TYPE",
    "DEVICE_TYPE_TORCH",
    "DEVICE_COUNT",
    "ALLOW_PREQUANTIZED_MODELS",
    "ALLOW_BITSANDBYTES",
    "is_mlx_available",
]
import functools
import inspect
import os
from unsloth_zoo.utils import Version
def is_mlx_available():
    """Report whether the MLX backend is available.

    The answer is delegated to ``unsloth_zoo.mlx.is_mlx_available``. If that
    module cannot be imported, MLX is treated as unavailable.

    Returns:
        ``False`` when ``unsloth_zoo.mlx`` fails to import, otherwise whatever
        the delegated probe returns.
    """
    try:
        from unsloth_zoo.mlx import is_mlx_available as _probe
    except ImportError:
        # No MLX support module present -> MLX cannot be used.
        available = False
    else:
        available = _probe()
    return available
# Probe for MLX exactly once, at import time. The flag is read by `is_hip`
# and `get_device_type` in this module.
_IS_MLX = is_mlx_available()
# torch is only imported when MLX is not in use, so any code path that
# touches `torch` must first check `_IS_MLX`.
if not _IS_MLX:
    import torch
@functools.cache
def is_hip():
    """Return ``True`` when the imported torch build reports a HIP version.

    The check reads ``torch.version.hip`` defensively (either attribute may
    be missing) and coerces the result to ``bool``. When MLX is active,
    torch is not imported by this module, so the answer is always ``False``.
    The result is cached for the lifetime of the process.
    """
    if _IS_MLX:
        return False
    version_module = getattr(torch, "version", None)
    hip_version = getattr(version_module, "hip", None)
    return bool(hip_version)
@functools.cache
def get_device_type():
    """Detect which accelerator backend this process should use.

    Checks are made in this order:

    1. ``UNSLOTH_ALLOW_CPU=1`` in the environment -> ``"cuda"``.
    2. MLX detected at import time -> ``"mlx"``.
    3. ``torch.cuda`` available -> ``"hip"`` if ``is_hip()`` else ``"cuda"``.
    4. ``torch.xpu`` available -> ``"xpu"``.

    The result is cached, so the environment variable is only read on the
    first call.

    Returns:
        One of ``"cuda"``, ``"hip"``, ``"xpu"`` or ``"mlx"``.

    Raises:
        NotImplementedError: No supported accelerator was found.
        RuntimeError: ``torch.accelerator`` reports a cuda/xpu/hip device even
            though the backend-specific availability checks failed.
    """
    # Test-only CPU fallback: pretend to be "cuda" so that every
    # DEVICE_TYPE == "cuda" branch behaves identically.
    if os.environ.get("UNSLOTH_ALLOW_CPU", "0") == "1":
        return "cuda"

    if _IS_MLX:
        return "mlx"

    if hasattr(torch, "cuda") and torch.cuda.is_available():
        return "hip" if is_hip() else "cuda"

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"

    # Neither backend-specific probe succeeded; consult the generic
    # torch.accelerator API (when this torch build has it) to give a more
    # precise error.
    if hasattr(torch, "accelerator"):
        if not torch.accelerator.is_available():
            raise NotImplementedError("Unsloth cannot find any torch accelerator? You need a GPU.")
        backend_name = str(torch.accelerator.current_accelerator())
        if backend_name in ("cuda", "xpu", "hip"):
            # A supported backend is reported here but was missed above.
            raise RuntimeError(
                "Unsloth: Weirdly `torch.cuda.is_available()`, `torch.xpu.is_available()` and `is_hip` all failed.\n"
                f"But `torch.accelerator.current_accelerator()` works with it being = `{backend_name}`\n"
                "Please reinstall torch - it's most likely broken :("
            )

    raise NotImplementedError("Unsloth currently only works on NVIDIA, AMD and Intel GPUs.")
# Backend detected once at import time ("cuda", "hip", "xpu" or "mlx").
DEVICE_TYPE: str = get_device_type()
# Device string to hand to torch APIs. HIP fails for autocast and other torch
# functions, so "hip" is mapped to "cuda"; "mlx" is mapped to "mps". Every
# other value is passed through unchanged.
DEVICE_TYPE_TORCH = {"hip": "cuda", "mlx": "mps"}.get(DEVICE_TYPE, DEVICE_TYPE)
@functools.cache
def get_device_count():
    """Return how many devices the detected backend exposes.

    For ``"cuda"`` and ``"hip"`` this is ``torch.cuda.device_count()``; for
    ``"xpu"`` it is ``torch.xpu.device_count()``. Any other ``DEVICE_TYPE``
    yields ``1``. The result is cached for the lifetime of the process.
    """
    if DEVICE_TYPE == "xpu":
        return torch.xpu.device_count()
    if DEVICE_TYPE in ("cuda", "hip"):
        return torch.cuda.device_count()
    return 1
# Device count as reported by get_device_count(), evaluated once at import.
DEVICE_COUNT: int = get_device_count()
# 4-bit quantization requires a block size of 64.
# | Device Type     | Warp Size | Block Size |
# |-----------------|-----------|------------|
# | CUDA            |    32     |     32     |
# | Radeon (Navi)   |    32     |     32     |
# | Instinct (MI)   |    64     |     32     |
#
# Since bitsandbytes 0.49.0, pre-quantized models with a block size of 64
# work on Radeon GPUs, but not on Instinct GPUs such as the MI300X.
# See https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1748
#
# Since bitsandbytes 0.49.2, blocksize=64 4-bit quantization is supported on
# CDNA (MI Instinct / gfx9xx) GPUs as well.
# See https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1856
#
# Whether pre-quantized (blocksize 64) checkpoints may be loaded. Defaults to
# True; only the HIP-specific checks later in this module can set it to False.
ALLOW_PREQUANTIZED_MODELS: bool = True
# HSA_STATUS_ERROR_EXCEPTION checks - bitsandbytes sometimes fails on AMD.
# Whether bitsandbytes may be used at all. Defaults to True; only the
# HIP-specific checks later in this module can set it to False.
ALLOW_BITSANDBYTES: bool = True
# HIP (AMD) only: decide whether bitsandbytes, and pre-quantized 4-bit models
# with blocksize 64, can be used with the installed bitsandbytes build.
# Other backends keep both ALLOW_* flags at their defaults.
if DEVICE_TYPE == "hip":
    try:
        import bitsandbytes
    except Exception:
        # Intentionally broad: any failure while importing bitsandbytes means
        # it cannot be used. `Exception` (rather than a bare `except:`) lets
        # KeyboardInterrupt and SystemExit propagate instead of being
        # reported as a missing package.
        print(
            "Unsloth: `bitsandbytes` is not installed - 4bit QLoRA unallowed, but 16bit and full finetuning works."
        )
        ALLOW_PREQUANTIZED_MODELS = False
        ALLOW_BITSANDBYTES = False

    if ALLOW_BITSANDBYTES:
        # Parse the installed version once and reuse it for every comparison.
        _bnb_version = Version(bitsandbytes.__version__)

        # bitsandbytes is only allowed for versions newer than 0.48.2.dev0.
        ALLOW_BITSANDBYTES = _bnb_version > Version("0.48.2.dev0")

        if _bnb_version >= Version("0.49.2"):
            # 0.49.2+: blocksize=64 is supported on CDNA GPUs too, so
            # pre-quantized models stay allowed.
            pass
        elif _bnb_version >= Version("0.49.0"):
            try:
                # Pre-quantized bitsandbytes models use blocksize 64, so we need to check the GPU
                from bitsandbytes.cextension import ROCM_WARP_SIZE_64

                # Warp-size-64 GPUs cannot use blocksize 64 in this version range.
                ALLOW_PREQUANTIZED_MODELS = not ROCM_WARP_SIZE_64
            except Exception as e:
                print(
                    "Unsloth: Checking `from bitsandbytes.cextension import ROCM_WARP_SIZE_64` had error = \n"
                    f"{str(e)}\n"
                    "4bit QLoRA disabled for now, but 16bit and full finetuning works."
                )
                ALLOW_PREQUANTIZED_MODELS = False
                ALLOW_BITSANDBYTES = False
        elif ALLOW_BITSANDBYTES:
            # Older than 0.49.0 but still allowed: inspect the library source
            # to see whether HIP builds force blocksize 128, in which case
            # blocksize-64 pre-quantized models cannot be loaded.
            from bitsandbytes.nn.modules import Params4bit

            if "blocksize = 64 if not HIP_ENVIRONMENT else 128" in inspect.getsource(Params4bit):
                ALLOW_PREQUANTIZED_MODELS = False