Fix global_device_count(), local_device_count() for single process on CUDA #6022
vanbasten23 merged 33 commits into master
Conversation
```cpp
std::optional<std::set<int>> allowed_devices;
if (global_world_size > 1) {
  allowed_devices =
      std::make_optional<std::set<int>>(std::set{local_process_rank});
```
nit: do you still need `make_optional` here, or can you directly assign? e.g. `allowed_devices = std::set{local_process_rank}`
Thanks for the review!
A few tests are failing:
Failed at The test is doing an all-reduce (
I think we can disable this test because it tests the class
```diff
- auto allowed_devices =
-     std::make_optional<std::set<int>>(std::set{local_process_rank});
+ std::optional<std::set<int>> allowed_devices;
+ if (global_world_size > 1) {
```
Synced offline - this is great for single-host single-process development, but in cases where there is a single process per host, this would break in a multihost environment. Outside of SPMD, I'm not aware of a use case for a multihost environment with a single process per host (cc @JackCaoG)
Since we don't officially support SPMD on GPU at the moment, this looks fine to me for now. Once we decide on the right entrypoint for SPMD, we'll need to revisit this.
(force-pushed from 54bad35 to d0faac5)
Hi @JackCaoG, this PR fixes
@vanbasten23 can you rebase and rerun the CI?
(force-pushed from 3292787 to fe03a80)
will-cromar left a comment:
LGTM once you fix the formatting. Thanks!
```python
@unittest.skipIf(xr.device_type() == 'CUDA',
                 'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.')
```
Not for this PR, but IMO we should just delete these tests. Do we support DataParallel anymore @JackCaoG?
(force-pushed from fe03a80 to b98ef93)
(force-pushed from cbeabb8 to 6c2b64f)
```cpp
kv_store = xla::GetDistributedKeyValueStore(distributed_client,
                                            /*key_prefix=*/"gpu:");
std::optional<std::set<int>> allowed_devices;
bool spmd = sys_util::GetEnvBool("XLA_USE_SPMD", false);
```
Conditioning on SPMD mode here could cause issues using xr.use_spmd() after the runtime has been initialized.
Is it correct to say that allowed_devices is only needed in the MP case? If so, can we invert the logic to check for MP using one of the env vars instead of checking for SPMD mode?
Is there ever a reason to call xr.use_spmd() after the runtime is initialized? In any case, I think we can also assume that if LOCAL_WORLD_SIZE=1, then we can use all of the devices (which should be compatible with SPMD)
AFAIK there's not a strong use case at the moment, but for example our unit tests will check xr.global_runtime_device_count() before calling xr.use_spmd(). Keeping the runtime independent of SPMD mode was something we wanted to maintain, cc @yeounoh
That's a good point!
> Conditioning on SPMD mode here could cause issues using xr.use_spmd() after the runtime has been initialized.

This seems to be a downside of using `xr.use_spmd()` as opposed to an env flag `XLA_USE_SPMD=1`. The latter is less flexible but less error-prone: it guarantees we use SPMD mode from the beginning. The former may also affect other SPMD special cases: the user runs some PyTorch ops, then calls `xr.use_spmd()`, then continues to do something else.

> can we invert the logic to check for MP using one of the env vars instead of checking for SPMD mode?

> I think we can also assume that if LOCAL_WORLD_SIZE=1, then we can use all of the devices

I'm thinking about the case where the user has 2 GPU machines, wants to use 1 GPU device on each machine, and wants to do multi-host training. In that case (multi-host, single-process), each process has access to all devices and I guess the user can still do multi-host training.
> I think we can also assume that if LOCAL_WORLD_SIZE=1, then we can use all of the devices

Perhaps we also need to check GLOBAL_WORLD_SIZE:

```
if LOCAL_WORLD_SIZE == 1:
    if GLOBAL_WORLD_SIZE > 1:  # multi-host, single process per host
        initialize coordinator service
    else:  # single-host, single-process
        do nothing
else:  # multi-process, single-host or multi-host
    initialize coordinator service
    allowed_devices = {current_device}
```
```cpp
std::unique_ptr<XlaCoordinator> SetKeyValueCallback(
    int global_process_rank, int global_world_size,
    std::unique_ptr<XlaCoordinator> coordinator,
```
Why do we need the coordinator as input here?
We need it to get the DistributedRuntimeClient and later create the kv_store below.
It looks like it's being recreated on L60 - should we just make this function return the new value?
It looks like a bunch of tests are failing with error
One of the examples is
The error doesn't exist on the current master branch (01/19, after the OpenXLA pin update). Also, the error doesn't exist on the feature branch before the pin update: #6346. Probably something happened during the pin update.
(force-pushed from 960d609 to aee08df)
jonb377 left a comment:
Looking good, thanks Xiongfei!
```python
# if self.n_devices>=4, mesh=(2, 2)
# if self.n_devices>=2, mesh=(2,1)
# if self.n_devices=1, mesh=(1,1)
```
Thanks for generalizing these tests!
Could we change these comments to e.g. `# if self.n_devices==4, mesh=(2, 2)`? Other device counts will have different meshes.
```cpp
if (local_world_size == 1) {
  if (global_world_size > 1) {
    coordinator = SetGpuClientKVCallBack(global_process_rank,
                                         global_world_size, kv_store);
  }
} else {
  allowed_devices = std::set{local_process_rank};
  coordinator = SetGpuClientKVCallBack(global_process_rank,
                                       global_world_size, kv_store);
}
```
```cpp
std::shared_ptr<xla::KeyValueStoreInterface> kv_store;
if (global_world_size > 1) {
  // Use the distributed key-value store from DistributedRuntimeClient.
  coordinator = std::make_unique<XlaCoordinator>(
      global_process_rank, global_world_size, master_addr, port);
  std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
      coordinator->GetClient();
  kv_store = xla::GetDistributedKeyValueStore(distributed_client,
                                              /*key_prefix=*/"gpu:");
}
TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
           << global_process_rank << ", num_nodes=" << global_world_size;
```
I think we could simplify the logic here somewhat. We want to restrict allowed_devices if local_world_size > 1 and create the coordinator if global_world_size > 1. I'm assuming local_world_size > 1 => global_world_size > 1; would this be equivalent?

```cpp
if (local_world_size > 1) {
  allowed_devices = std::set{local_process_rank};
}
if (global_world_size > 1) {
  // We can keep the old initialization block here and remove `SetGpuClientKVCallBack`
}
```
Thanks for the review!
```diff
  TF_VLOG(INFO) << "OpSharding (ShardingType: " << sharding_type << "):\n"
-               << sharding.DebugString();
+               << sharding.DebugString()
+               << ", sharding.type()=" << sharding.type();
```
DebugString should include the type?
Actually, it doesn't. The DebugString is empty in that case.
This PR fixes global_device_count() and local_device_count() for single process on CUDA so that they return all GPU devices on the current host. Before this PR, both APIs always returned 1, as reported in this issue. global_runtime_device_count is not fixed since it seems it's only used in the SPMD case, and it has been fixed in another PR.
Test:
Note: here is the behavior of torch.cuda.device_count() in the multi-host case