Distributed prompting/inference utility #1410

Merged
muellerzr merged 27 commits into main from distributed-prompt
May 17, 2023

Conversation

@muellerzr
Contributor

muellerzr commented May 10, 2023

This PR introduces a new utility in Accelerator, AcceleratorState, and PartialState: Accelerator.split_between_processes.

When performing distributed inference in applications such as stable diffusion, it is often useful to send one prompt to GPU A, another prompt to GPU B, and so forth. This PR introduces a new context manager that lets the user pass some data in and have it split evenly across all processes for them to use. An example application might look like this:

import torch

from accelerate import PartialState
from diffusers import DiffusionPipeline

state = PartialState()
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.to(state.device)

with state.split_between_processes(["a dog", "a cat"]) as prompt:
    image = pipe(prompt).images[0]

On a two-process system, GPU A would receive "a dog" and GPU B would receive "a cat".

This is also especially useful for cases where using a DataLoader to perform the task would be too much code, and the user just wants to send in strings or already-preprocessed dictionaries and have them split.
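The splitting behavior described above can be sketched in plain Python. This is an illustrative toy, not the Accelerate implementation; the function name and signature are hypothetical:

```python
import math

def split_evenly(inputs, num_processes, process_index):
    # Each rank receives a contiguous chunk; when the split is uneven,
    # earlier ranks get the extra items and later ranks may get fewer.
    chunk = math.ceil(len(inputs) / num_processes)
    return inputs[process_index * chunk:(process_index + 1) * chunk]
```

With two processes and `["a dog", "a cat"]`, rank 0 would see `["a dog"]` and rank 1 `["a cat"]`, matching the example above.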

@muellerzr muellerzr added the enhancement New feature or request label May 10, 2023
@muellerzr muellerzr requested a review from sgugger May 10, 2023 22:50
@muellerzr muellerzr requested a review from pacman100 May 10, 2023 22:58
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented May 10, 2023

The documentation is not available anymore as the PR was closed or merged.

@muellerzr muellerzr force-pushed the distributed-prompt branch from 0c484a8 to 866fec0 Compare May 11, 2023 08:51
Collaborator

@sgugger sgugger left a comment

Thanks for the PR! I think this should contain an option to pad the splits by looping back to the beginning, since if users try to gather predictions that are not all of the same shape, they will get a hang.
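The "loop back to the beginning" padding suggested here can be sketched in plain Python. This is an illustration of the idea only, with hypothetical names, not the code that was ultimately merged:

```python
import math

def split_with_padding(inputs, num_processes, process_index):
    # Ceil-divide so every rank can get a chunk of the same length.
    chunk = math.ceil(len(inputs) / num_processes)
    # Pad by cycling back to the start of the inputs until the total
    # length is an exact multiple of the chunk size.
    padded = list(inputs)
    while len(padded) < chunk * num_processes:
        padded.append(inputs[len(padded) % len(inputs)])
    return padded[process_index * chunk:(process_index + 1) * chunk]
```

Because every rank now holds a chunk of identical length, a later gather of same-shaped predictions no longer hangs; duplicated results from the padding can be dropped afterwards.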

Comment thread docs/source/usage_guides/distributed_inference.mdx Outdated
Comment thread src/accelerate/state.py Outdated
@muellerzr muellerzr requested a review from sgugger May 17, 2023 17:29
Collaborator

@sgugger sgugger left a comment

Some comments on the doc but otherwise LGTM!

Comment thread docs/source/usage_guides/distributed_inference.mdx Outdated
Comment thread docs/source/usage_guides/distributed_inference.mdx Outdated
Comment thread docs/source/usage_guides/distributed_inference.mdx Outdated
Comment thread docs/source/usage_guides/distributed_inference.mdx Outdated
Comment thread docs/source/usage_guides/distributed_inference.mdx
muellerzr and others added 4 commits May 17, 2023 13:48
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@muellerzr muellerzr merged commit b93bfac into main May 17, 2023
@muellerzr muellerzr deleted the distributed-prompt branch May 17, 2023 18:41

3 participants