import logging
import os
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Dict, Optional

import torch
import torch.distributed as dist
from packaging.version import Version

import ray
from ray._common.network_utils import build_address
from ray._private import ray_constants
from ray.air._internal.device_manager import register_custom_torch_dist_backend
from ray.exceptions import GetTimeoutError
from ray.train._internal.base_worker_group import BaseWorkerGroup
from ray.train._internal.utils import get_address_and_port
from ray.train.backend import Backend, BackendConfig
from ray.train.constants import (
    DEFAULT_TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
    TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
)
from ray.train.v2._internal.util import TrainingFramework
from ray.util import PublicAPI

logger = logging.getLogger(__name__)


class TorchConfigContextManager:
    def __enter__(self):
        # Set default cuda device
        if torch.cuda.is_available():
            device = ray.train.torch.get_device()
            if device.type == "cuda":
                torch.cuda.set_device(device)

    def __exit__(self, type, value, traceback):
        # Propagate exceptions if any
        return False
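
# Note: this context manager is what ``TorchConfig.train_func_context`` (below)
# returns; Ray Train enters it around the user's training function so that each
# worker's default CUDA device is set before user code runs.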


@PublicAPI(stability="stable")
@dataclass
class TorchConfig(BackendConfig):
    """Configuration for torch process group setup.

    See https://pytorch.org/docs/stable/distributed.html for more info.

    Args:
        backend: The backend to use for training.
            See ``torch.distributed.init_process_group`` for more info and
            valid values.
            If set to None, nccl will be used if GPUs are requested, else gloo
            will be used.
        init_method: The initialization method to use. Either "env"
            for environment variable initialization or "tcp" for TCP
            initialization. Defaults to "env".
        timeout_s: Seconds for process group operations to timeout.
    """

    backend: Optional[str] = None
    init_method: str = "env"
    timeout_s: int = 1800

    @property
    def backend_cls(self):
        return _TorchBackend

    @property
    def train_func_context(self):
        return TorchConfigContextManager

    @property
    def framework(self):
        return TrainingFramework.TORCH

    def to_dict(self) -> Dict[str, Any]:
        config_dict = {
            "backend": self.backend,
            "init_method": self.init_method,
            "timeout_s": self.timeout_s,
        }
        return config_dict
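
# Usage sketch (illustrative only, not executed as part of this module). TorchConfig
# is typically passed to a ``TorchTrainer`` via its ``torch_config`` argument; the
# concrete values below are assumptions for the example.
#
#     from ray.train import ScalingConfig
#     from ray.train.torch import TorchConfig, TorchTrainer
#
#     trainer = TorchTrainer(
#         train_loop_per_worker=train_func,
#         scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
#         # Override the default backend selection (nccl when GPUs are requested).
#         torch_config=TorchConfig(backend="gloo", timeout_s=600),
#     )
#     result = trainer.fit()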


def _is_backend_nccl(backend: str) -> bool:
    # Check containment because comma-separated lists of backends like
    # "cpu:gloo,cuda:nccl" are supported.
    return backend == "nccl" or any(
        item.split(":")[1] == "nccl"
        for item in backend.split(",")
        if item.startswith("cuda:")
    )
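
# Expected behavior of the helper above (input/output sketch):
#
#     _is_backend_nccl("nccl")                 # True
#     _is_backend_nccl("cuda:nccl")            # True
#     _is_backend_nccl("cpu:gloo,cuda:nccl")   # True
#     _is_backend_nccl("gloo")                 # False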


def _setup_torch_process_group(
    backend: str,
    world_rank: int,
    world_size: int,
    init_method: str,
    timeout_s: int = 1800,
):
    """Connects the distributed PyTorch backend.

    Args:
        backend: The backend (nccl, gloo, etc.) to use for training.
        world_rank: Rank of the current worker.
        world_size: Number of workers participating in the job.
        init_method: URL specifying how to initialize the process group.
        timeout_s: Seconds for process group operations to timeout.
    """
    if world_rank == 0:
        logger.info(
            f"Setting up process group for: {init_method} [rank={world_rank}, "
            f"world_size={world_size}]"
        )
    else:
        logger.debug(
            f"Setting up process group for: {init_method} [rank={world_rank}, "
            f"world_size={world_size}]"
        )
    logger.debug(f"using {backend}")

    if _is_backend_nccl(backend):
        # See https://github.com/pytorch/pytorch/blob/c263bd43e8e8502d4726643bc6fd046f0130ac0e/torch/distributed/distributed_c10d.py#L803-L823  # noqa: E501
        # We do not use TORCH_NCCL_BLOCKING_WAIT due to performance overhead.
        if Version(torch.__version__) < Version("2.2.0"):
            TORCH_NCCL_ASYNC_ERROR_HANDLING_ENV_VAR = "NCCL_ASYNC_ERROR_HANDLING"
            TORCH_NCCL_BLOCKING_WAIT_ENV_VAR = "NCCL_BLOCKING_WAIT"
        else:
            TORCH_NCCL_ASYNC_ERROR_HANDLING_ENV_VAR = "TORCH_NCCL_ASYNC_ERROR_HANDLING"
            TORCH_NCCL_BLOCKING_WAIT_ENV_VAR = "TORCH_NCCL_BLOCKING_WAIT"

        if (
            TORCH_NCCL_ASYNC_ERROR_HANDLING_ENV_VAR not in os.environ
            and TORCH_NCCL_BLOCKING_WAIT_ENV_VAR not in os.environ
        ):
            logger.debug(
                f"Setting {TORCH_NCCL_ASYNC_ERROR_HANDLING_ENV_VAR}=1 to fail if NCCL collective communication operations are timing out. "  # noqa: E501
                f"To override this behavior, you can set {TORCH_NCCL_ASYNC_ERROR_HANDLING_ENV_VAR}=0."  # noqa: E501
            )
            os.environ[TORCH_NCCL_ASYNC_ERROR_HANDLING_ENV_VAR] = "1"
    elif backend == "hccl":
        register_custom_torch_dist_backend(backend)

    dist.init_process_group(
        backend=backend,
        init_method=init_method,
        rank=world_rank,
        world_size=world_size,
        timeout=timedelta(seconds=timeout_s),
    )
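
# For reference, a direct call to the helper above would look like the sketch below
# (hypothetical values; with init_method="env://" the MASTER_ADDR/MASTER_PORT env vars
# must already be set on the worker):
#
#     _setup_torch_process_group(
#         backend="gloo",
#         world_rank=0,
#         world_size=2,
#         init_method="env://",  # or "tcp://<master_addr>:<master_port>"
#         timeout_s=1800,
#     )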


def _shutdown_torch(destroy_process_group=False):
    from ray.air._internal.torch_utils import get_devices

    devices = get_devices()
    if destroy_process_group and dist.is_initialized():
        dist.destroy_process_group()
    if torch.cuda.is_available():
        for device in devices:
            if device.type == "cuda":
                with torch.cuda.device(device):
                    torch.cuda.empty_cache()


def _set_torch_distributed_env_vars():
    # Same env vars as in
    # https://pytorch.org/docs/stable/elastic/run.html#environment-variables
    from ray.train.torch import get_device

    context = ray.train.get_context()

    os.environ["LOCAL_RANK"] = str(context.get_local_rank())
    os.environ["LOCAL_WORLD_SIZE"] = str(context.get_local_world_size())
    os.environ["NODE_RANK"] = str(context.get_node_rank())
    os.environ["RANK"] = str(context.get_world_rank())
    os.environ["WORLD_SIZE"] = str(context.get_world_size())

    # Makes sure Hugging Face Accelerate uses the correct device
    device = get_device()
    os.environ["ACCELERATE_TORCH_DEVICE"] = str(device)


class _TorchBackend(Backend):
    share_cuda_visible_devices: bool = True

    def on_start(self, worker_group: BaseWorkerGroup, backend_config: TorchConfig):
        if dist.is_available():
            # Set the appropriate training backend.
            if backend_config.backend is None:
                resources = worker_group.get_resources_per_worker()
                num_gpus_per_worker = resources.get("GPU", 0)
                if num_gpus_per_worker > 0:
                    backend = "nccl"
                else:
                    backend = "gloo"
            else:
                backend = backend_config.backend

            master_addr, master_port = worker_group.execute_single(
                0, get_address_and_port
            )
            if backend_config.init_method == "env":

                def set_env_vars(addr, port):
                    os.environ["MASTER_ADDR"] = addr
                    os.environ["MASTER_PORT"] = str(port)

                worker_group.execute(set_env_vars, addr=master_addr, port=master_port)
                url = "env://"
            elif backend_config.init_method == "tcp":
                url = f"tcp://{build_address(master_addr, master_port)}"
            else:
                raise ValueError(
                    f"The provided init_method ("
                    f"{backend_config.init_method}) is not supported. Must "
                    f"be either 'env' or 'tcp'."
                )

            setup_futures = []
            for i in range(len(worker_group)):
                setup_futures.append(
                    worker_group.execute_single_async(
                        i,
                        _setup_torch_process_group,
                        backend=backend,
                        world_rank=i,
                        world_size=len(worker_group),
                        init_method=url,
                        timeout_s=backend_config.timeout_s,
                    )
                )
            ray.get(setup_futures)
        else:
            raise RuntimeError("Distributed torch is not available.")
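
    # Illustration of the two init_method modes handled in ``on_start`` above
    # (addresses and port are hypothetical):
    #   * init_method="env": MASTER_ADDR=10.0.0.1 and MASTER_PORT=29500 are exported
    #     on every worker and the process group is created with url "env://".
    #   * init_method="tcp": no env vars are set; the url looks like
    #     "tcp://10.0.0.1:29500".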

    def on_shutdown(self, worker_group: BaseWorkerGroup, backend_config):
        futures = worker_group.execute_async(
            _shutdown_torch,
            destroy_process_group=len(worker_group) > 1,
        )
        timeout_s = ray_constants.env_integer(
            TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
            DEFAULT_TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
        )
        try:
            ray.get(futures, timeout=timeout_s)
        except GetTimeoutError:
            logger.warning(
                f"Torch process group shutdown timed out after {timeout_s} seconds"
            )

    def on_training_start(
        self, worker_group: BaseWorkerGroup, backend_config: BackendConfig
    ):
        worker_group.execute(_set_torch_distributed_env_vars)
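
# Lifecycle summary: Ray Train calls ``on_start`` to create the torch process group on
# every worker, ``on_training_start`` to export the torchrun-style env vars before the
# user's training function runs, and ``on_shutdown`` to tear the process group down,
# bounded by TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S.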