Description
Describe the bug
A clear and concise description of what the bug is.
The following snippet with enable_cuda_graph=True breaks, as the entire pipeline can't be traced under a single CUDA graph:
import os

import torch
import diffusers
import deepspeed

hf_auth_key = os.getenv("HF_AUTH_KEY")
if not hf_auth_key:
    raise ValueError("HF_AUTH_KEY is not set")

pipe = diffusers.StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    use_auth_token=hf_auth_key,
    torch_dtype=torch.float16,
    revision="fp16")
print(pipe)

pipe = deepspeed.init_inference(pipe.to("cuda"), dtype=torch.float16, enable_cuda_graph=True)

However, introducing a global CUDA graph in addition to the local ones created through the policies (DSUnet, etc.) does work, and yields an extra 100-150 ms:
import time

import torch
from deepspeed.inference.engine import InferenceEngine


class InferenceEngine(InferenceEngine):
    def __init__(self, *args, enable_cuda_graph_global: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self.enable_cuda_graph_global = enable_cuda_graph_global

    def forward(self, *inputs, **kwargs):
        """Execute forward propagation.

        Arguments:
            *inputs: Variable length input list
            **kwargs: variable length keyword arguments
        """
        start = None
        if self.model_profile_enabled and self.enable_cuda_graph_global:
            torch.cuda.synchronize()
            start = time.time()

        if self.enable_cuda_graph_global:
            # Capture the whole module into a CUDA graph on the first call,
            # then replay it on subsequent calls.
            if self.cuda_graph_created:
                outputs = self._graph_replay(*inputs, **kwargs)
            else:
                self._create_cuda_graph(*inputs, **kwargs)
                outputs = self._graph_replay(*inputs, **kwargs)
        else:
            outputs = self.module(*inputs, **kwargs)

        if self.model_profile_enabled and self.enable_cuda_graph_global:
            torch.cuda.synchronize()
            duration = time.time() - start
            self._model_times.append(duration)

        return outputs

You can find the code here: https://github.com/Lightning-AI/stablediffusion/blob/lit/ldm/deepspeed_replace.py#L34
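For reference, the capture-and-replay pattern that _create_cuda_graph/_graph_replay rely on is the standard PyTorch CUDA graph recipe. Below is a minimal sketch of that pattern using plain PyTorch; the Linear module and tensor shapes are placeholders for illustration only, not the actual Stable Diffusion components:

import torch

# Placeholder stand-in for a traceable submodule (e.g. a UNet block).
module = torch.nn.Linear(64, 64).half().cuda().eval()
static_input = torch.randn(8, 64, device="cuda", dtype=torch.float16)

with torch.no_grad():
    # Warm up on a side stream before capture, as recommended by the PyTorch docs.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            module(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # Capture one forward pass into a CUDA graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = module(static_input)

    # Replay: copy new data into the captured input tensor, then replay the graph.
    new_input = torch.randn(8, 64, device="cuda", dtype=torch.float16)
    static_input.copy_(new_input)
    graph.replay()
    result = static_output.clone()

This only works for submodules whose control flow and tensor shapes are static across calls, which is presumably why graphing the injected submodules (and a global graph around the whole forward) succeeds while tracing the full pipeline under one CUDA graph does not.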
To Reproduce
Steps to reproduce the behavior:
- Simple inference script to reproduce
- What packages are required and their versions
- How to run the script
- ...
Expected behavior
A clear and concise description of what you expected to happen.
ds_report output
Please run ds_report to give us details about your setup.
Screenshots
If applicable, add screenshots to help explain your problem.
System info (please complete the following information):
- OS: [e.g. Ubuntu 18.04]
- GPU count and types [e.g. two machines with x8 A100s each]
- (if applicable) what DeepSpeed-MII version are you using
- (if applicable) Hugging Face Transformers/Accelerate/etc. versions
- Python version
- Any other relevant info about your setup
Docker context
Are you using a specific docker image that you can share?
Additional context
Add any other context about the problem here.