Conversation

@aphedges (Contributor):

Fixes #2010.

The first commit is the minimal change needed to get DeepSpeed installed. The second commit fully removes bf16_support from torch_info. If the second commit goes too far, it can be dropped before this PR is merged. I have tested both and confirmed that they both allow me to install DeepSpeed. See #2010 (comment) for more information about my approach.

@stas00 (Collaborator) commented Jun 27, 2022:

Looks great, @aphedges! Thank you for working on fixing this!

hip_version = ".".join(torch.version.hip.split('.')[:2])
torch_info = {
    "version": torch_version,
    "bf16_support": bf16_support,
Collaborator:

As long as there is nothing relying on this field, it should be safe to remove.

Contributor Author:

In this repository, it was only used by a test, and I've replaced that usage in the test with a runtime check.

The usage in the test was only introduced 3 weeks ago in 7fc3065, so it probably hasn't been used much by other programs.
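
For illustration, a minimal sketch of that kind of runtime check, assuming torch >= 1.10 so that torch.cuda.is_bf16_supported() exists (the test name and body are placeholders, not the actual DeepSpeed test):

import pytest
import torch


def test_requires_bf16():
    # Probe bf16 support when the test actually runs, instead of trusting
    # a value that setup.py recorded on whatever machine did the build.
    if not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
        pytest.skip("bf16 is not supported on this GPU")
    # ... the bf16-dependent assertions would go here ...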

Collaborator:

I totally agree, but let's see what the maintainers say.

@stas00 (Collaborator), Jun 28, 2022:

As you pointed out, the key here is that the hardware environment may change after installation, so a memorized hardware-specific bf16 state can easily be invalid at run time.

This would be especially true for cross-platform/cross-hardware prebuilds.

@stas00 (Collaborator), Jun 28, 2022:

@jeffra, there is also a new development with IPEX, which now supports BF16 on CPU! So I highly recommend splitting this into two separate state checks, have_bf16_cpu and have_bf16_gpu, as we just went through in transformers: huggingface/transformers#17738

Please feel free to reuse the helper utils.
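
A hand-rolled sketch of what two separate checks could look like; these helpers are illustrative, not the actual utilities from the transformers PR, and the CPU probe assumes torch >= 1.10 so that torch.cpu.amp.autocast exists:

import torch


def have_bf16_gpu() -> bool:
    # GPU-side bf16: only meaningful when a CUDA/ROCm device is usable,
    # and torch exposes the capability check directly.
    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()


def have_bf16_cpu() -> bool:
    # CPU-side bf16 (e.g. via IPEX or newer PyTorch CPU kernels): probe
    # CPU autocast, which defaults to bfloat16, rather than trusting a
    # flag cached at install time.
    try:
        with torch.cpu.amp.autocast():
            pass
        return True
    except (AttributeError, RuntimeError):
        return False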

@mrwyattii (Contributor) left a comment:

It is not necessary to have bf16_support defined in setup.py, but we cannot have it defined during unit testing either (see the error in torch-latest unit tests)

-    NCCL_MAJOR > 2 or
-    (NCCL_MAJOR == 2 and NCCL_MINOR >= 10)) and torch_info['bf16_support']:
+    NCCL_MAJOR > 2 or (NCCL_MAJOR == 2 and NCCL_MINOR >= 10)
+) and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
Contributor:

Having a call to torch.cuda.is_bf16_supported() causes a problem with our unit tests (specifically the distributed_test decorator) because this function call initializes CUDA

else:
    nccl_version = ".".join(map(str, torch.cuda.nccl.version()[:2]))
if hasattr(torch.cuda, 'is_bf16_supported'):
    bf16_support = torch.cuda.is_bf16_supported()
Contributor:

The bf16_support value is only used in unit testing - perhaps we could add a try...except block around this?
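
A minimal sketch of that suggestion in setup.py terms; the default value and the exception type caught here are assumptions, not the final code:

import torch

bf16_support = False
try:
    if hasattr(torch.cuda, 'is_bf16_supported'):
        # is_bf16_supported() can raise when CUDA cannot be initialized,
        # e.g. on a build machine with no GPU or driver.
        bf16_support = torch.cuda.is_bf16_supported()
except RuntimeError:
    bf16_support = False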

@aphedges (Contributor Author), Jun 28, 2022:

My first commit would be sufficient to prevent the crash because it checks whether CUDA is available.

However, that doesn't solve the problem of building and running on different hardware with different bf16 support.
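
For reference, a sketch of the shape of that availability guard (the exact diff in the first commit may differ):

import torch

# torch.cuda.is_available() returns False cleanly on a machine without
# GPUs instead of raising, so the bf16 query only runs where it can work.
bf16_support = (torch.cuda.is_available()
                and hasattr(torch.cuda, 'is_bf16_supported')
                and torch.cuda.is_bf16_supported())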

@mrwyattii (Contributor), Jun 29, 2022:

The bf16 support check is only utilized in unit tests - so it shouldn't affect users that build/run on different hardware

Contributor Author:

I know it's set up that way in CI, but if a developer runs the unit tests on a different machine than the build, it could still cause problems.

@aphedges (Contributor Author):

> It is not necessary to have bf16_support defined in setup.py, but we cannot have it defined during unit testing either (see the error in torch-latest unit tests)

Are you sure the CI failure is related to my changes? I can see the same error in other PRs, such as https://github.com/microsoft/DeepSpeed/runs/7096559886?check_suite_focus=true for #2064.

@mrwyattii (Contributor):

> It is not necessary to have bf16_support defined in setup.py, but we cannot have it defined during unit testing either (see the error in torch-latest unit tests)

> Are you sure the CI failure is related to my changes? I can see the same error in other PRs, such as https://github.com/microsoft/DeepSpeed/runs/7096559886?check_suite_focus=true for #2064.

You're right, the current failures we're seeing on torch-latest are due to torch 1.12 being released recently; it doesn't play well with our distributed_test decorator. We've disabled this workflow for now. However, please take a look at #1987 and #1990, where we moved bf16 support checking to setup.py because it was causing issues during unit tests. We also added the torch-latest runner in #1990, as the tests that utilize bf16 are skipped on other runners.

@aphedges (Contributor Author):

> You're right, the current failures we're seeing on torch-latest are due to torch 1.12 being released recently; it doesn't play well with our distributed_test decorator. We've disabled this workflow for now. However, please take a look at #1987 and #1990, where we moved bf16 support checking to setup.py because it was causing issues during unit tests. We also added the torch-latest runner in #1990, as the tests that utilize bf16 are skipped on other runners.

I hadn't read the explanation on #1990. Thanks for pointing it out!

As much as I disagree that setup.py is the right place for it, what I care about most is being able to compile DeepSpeed properly. Would you be fine with this PR if I drop my second commit, or do you want any other changes?

@mrwyattii (Contributor):

> As much as I disagree that setup.py is the right place for it, what I care about most is being able to compile DeepSpeed properly. Would you be fine with this PR if I drop my second commit, or do you want any other changes?

Yes, after dropping the second commit this should be good to go. And I agree that we should check for bf16 support (and perhaps the NCCL version too) at runtime, but we first need to resolve the issue we're having with the distributed_test decorator. Given that it's only used in unit tests, I think it's ok for now.

@aphedges force-pushed the 2010-fix-no-gpu-build branch from 6b141e1 to 08536ec on June 29, 2022 at 17:34.
@aphedges (Contributor Author):

@mrwyattii, I've removed the second commit and rebased my branch. I can confirm that I can install without errors on a machine without GPUs.

@aphedges (Contributor Author):

The failure in https://github.com/microsoft/DeepSpeed/runs/7118253963?check_suite_focus=true seems like a flaky performance test. DeepSpeed inference was around 10% slower than expected.

@aphedges (Contributor Author):

@mrwyattii, can you re-run the nv-inference unit-tests?

@aphedges force-pushed the 2010-fix-no-gpu-build branch from 166815c to 4b20a92 on June 30, 2022 at 17:01.
@aphedges (Contributor Author):

@mrwyattii, I force-pushed, and the checks passed this time. Can this PR be approved and merged?

@mrwyattii merged commit 3540ce7 into deepspeedai:master on Jul 6, 2022.
Successfully merging this pull request may close these issues.

[BUG] Impossible to prebuild w/o having at least one gpu
