Remove compile bottlenecks from ZImage pipeline#13461
sayakpaul merged 6 commits into huggingface:main from
Conversation
Thanks for your PR! Can we eliminate all the …
[core] Replace boolean mask indexing with torch.where in ZImage transformer

Boolean mask indexing (`tensor[mask] = val`) implicitly calls `nonzero()`, which triggers a DtoH sync that stalls the CPU while the GPU queue drains. Replacing it with `torch.where` eliminates these syncs from the transformer's pad-token assignment.

Profiling (4-step turbo, fix_2 vs fix_1):
- Eager: `nonzero` CPU time drops from ~2091 ms to <1 ms; `index_put` eliminated
- Compile: `nonzero` CPU time drops from ~3057 ms to <1 ms; `index_put` eliminated
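The replacement pattern can be illustrated with a small standalone sketch (tensor names here are illustrative, not the actual transformer code):

```python
import torch

def pad_assign_masked(hidden_states, pad_mask, pad_embed):
    # Boolean mask assignment lowers to index_put_, which calls nonzero()
    # on the mask and copies the indices to the CPU (a DtoH sync).
    hidden_states = hidden_states.clone()
    hidden_states[pad_mask] = pad_embed  # implicit nonzero() -> DtoH sync
    return hidden_states

def pad_assign_where(hidden_states, pad_mask, pad_embed):
    # torch.where builds the same result with a pure elementwise select,
    # so no host synchronization is needed and torch.compile can fuse it.
    return torch.where(pad_mask.unsqueeze(-1), pad_embed, hidden_states)

x = torch.randn(2, 4, 8)
mask = torch.tensor([[True, False, False, True], [False, True, False, False]])
pad = torch.zeros(8)
assert torch.equal(pad_assign_masked(x, mask, pad), pad_assign_where(x, mask, pad))
```

Both functions produce identical tensors; only the second avoids the hidden host round-trip.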
Here are some comparison stats between commit_1 and commit_2
All the trace files can be accessed here. cc: @sayakpaul
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Merging as the outputs with and without this PR are the same:

```python
"""
Minimal script to verify PR #13461 does not change ZImagePipeline outputs.

Compares latent outputs between the current branch (hitchhiker3010-main, with PR changes)
and the main branch (without PR changes) using a fixed seed.

The PR makes two changes:
1. Precompute cfg_truncation t_norms outside the loop (avoids DtoH sync)
2. Use torch.where instead of boolean mask indexing in the transformer

Both are pure optimizations — outputs should be identical.

Usage:
    # On the current branch (with PR changes):
    python test_zimage_pr13461.py --save current_branch.pt

    # On the main branch (without PR changes):
    git checkout main
    python test_zimage_pr13461.py --save main_branch.pt

    # Compare:
    python test_zimage_pr13461.py --compare current_branch.pt main_branch.pt
"""
import argparse

import torch


def run_pipeline():
    from diffusers import ZImagePipeline

    pipe = ZImagePipeline.from_pretrained(
        "Tongyi-MAI/Z-Image-Turbo",
        torch_dtype=torch.bfloat16,
    )
    pipe.to("cuda")
    generator = torch.Generator(device="cuda").manual_seed(0)
    # Use guidance_scale > 1 and cfg_truncation to exercise both changed code paths.
    # Small resolution + few steps + latent output for speed.
    output = pipe(
        prompt="a cat",
        height=256,
        width=256,
        num_inference_steps=2,
        guidance_scale=3.5,
        cfg_truncation=0.5,
        output_type="latent",
        generator=generator,
    )
    return output.images


def save(latents, path):
    torch.save(latents.cpu(), path)
    print(f"Saved latents with shape {latents.shape} and dtype {latents.dtype} to {path}")


def compare(path_a, path_b):
    a = torch.load(path_a, weights_only=True)
    b = torch.load(path_b, weights_only=True)
    print(f"Tensor A: shape={a.shape}, dtype={a.dtype}")
    print(f"Tensor B: shape={b.shape}, dtype={b.dtype}")
    if a.shape != b.shape:
        print("FAIL: shapes differ")
        return
    exact_match = torch.equal(a, b)
    max_diff = (a.float() - b.float()).abs().max().item()
    print(f"Exact match: {exact_match}")
    print(f"Max absolute difference: {max_diff}")
    if exact_match:
        print("PASS: outputs are identical")
    elif max_diff < 1e-3:
        print(f"PASS: outputs differ by at most {max_diff} (within tolerance)")
    else:
        print(f"FAIL: outputs differ by {max_diff}")


def main():
    parser = argparse.ArgumentParser(description="ZImage PR #13461 output comparison")
    parser.add_argument("--save", type=str, help="Run pipeline and save latents to this path")
    parser.add_argument("--compare", nargs=2, metavar=("A", "B"), help="Compare two saved latent files")
    args = parser.parse_args()
    if args.save:
        latents = run_pipeline()
        save(latents, args.save)
    elif args.compare:
        compare(args.compare[0], args.compare[1])
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
```
Hey @sayakpaul @dg845, thanks for the contributing opportunity. There are likely similar issues in the other Z-Image pipelines (controlnet, controlnet_inpaint, img2img, inpaint, and omni), along with some pipeline-specific ones. Can I look into those, or do you suggest other pipelines that might be higher priority?
* [core] Remove DtoH syncs from ZImage pipeline denoising loop
* [core] Replace boolean mask indexing with torch.where in ZImage transformer

  Boolean mask indexing (`tensor[mask] = val`) implicitly calls `nonzero()`, which triggers a DtoH sync that stalls the CPU while the GPU queue drains. Replacing it with `torch.where` eliminates these syncs from the transformer's pad-token assignment.

  Profiling (4-step turbo, fix_2 vs fix_1):
  - Eager: `nonzero` CPU time drops from ~2091 ms to <1 ms; `index_put` eliminated
  - Compile: `nonzero` CPU time drops from ~3057 ms to <1 ms; `index_put` eliminated

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
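As an aside, syncs like the ones these commits remove can also be flagged at runtime with `torch.cuda.set_sync_debug_mode` rather than by trace inspection. A minimal sketch (the guard keeps it a no-op on CPU-only machines):

```python
import torch

# set_sync_debug_mode flags synchronizing CUDA calls as they happen:
# "warn" emits a warning, "error" raises, "default" disables the check.
if torch.cuda.is_available():
    torch.cuda.set_sync_debug_mode("warn")
    x = torch.randn(1024, device="cuda")
    mask = x > 0
    x[mask] = 0.0       # warns: boolean indexing calls nonzero() -> DtoH sync
    _ = x.sum().item()  # warns: .item() forces a DtoH copy
    torch.cuda.set_sync_debug_mode("default")
```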
What does this PR do?
Fixes performance issues identified by profiling `ZImagePipeline` with `torch.profiler` as part of #13401.

Profiled `ZImagePipeline` (using `Tongyi-MAI/Z-Image-Turbo`) in both eager and `torch.compile` modes, following the profiling guide. The Chrome traces revealed two device-to-host (DtoH) synchronization points that break asynchronous GPU execution and prevent `torch.compile` from yielding its full speedup.

Pipeline denoising loop:
- `t_norm = timestep[0].item()` DtoH sync: `timestep[0].item()` triggers a GPU→CPU sync every step to read `t_norm` for the CFG truncation logic. Since the full timestep schedule is known before the loop begins, we precompute all `t_norm` values into a plain Python list before entering the loop and index into it with `i`.
- Call `scheduler.set_begin_index(0)` upfront to avoid the DtoH sync in `_init_step_index` (same pattern as "Avoid DtoH sync from access of nonzero() item in scheduler" #11696).

Profiling ZImagePipeline
GPU - L4
num_inference_steps - 4
guidance_scale - 0.0 (guidance should be 0 for the Turbo models)
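The denoising-loop change, precomputing the normalized timesteps before the loop, can be sketched as follows (variable names follow the PR description, not the actual pipeline code):

```python
import torch

timesteps = torch.linspace(1000, 0, 4)  # the schedule is fixed before the loop
num_train_timesteps = 1000

# Before: one DtoH sync per step.
# for i, t in enumerate(timesteps):
#     t_norm = t.item() / num_train_timesteps  # GPU -> CPU sync every iteration

# After: read the whole schedule back once, outside the loop.
t_norms = (timesteps.float() / num_train_timesteps).tolist()  # single DtoH transfer
for i, t in enumerate(timesteps):
    t_norm = t_norms[i]  # plain Python float, no sync inside the loop
```

With the `.item()` call hoisted out, the loop body issues only asynchronous GPU work, so the CPU can keep the GPU queue full.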
Before
[profiler trace screenshot: before]
The first `scheduler_step` took 657.8 µs.
Number of `cudaStreamSynchronize` blocks: 19
After
[profiler trace screenshot: after]
The first `scheduler_step` took 15.49 µs after this fix.
Number of `cudaStreamSynchronize` blocks: 13
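Traces like the ones above can be produced with a small `torch.profiler` harness. This is a sketch, not the exact script used for the screenshots; `profile_call` is a hypothetical helper, and with the real pipeline the workload would be something like `lambda: pipe("a cat", num_inference_steps=4, guidance_scale=0.0)`:

```python
import os
import torch
from torch.profiler import profile, ProfilerActivity

def profile_call(fn, trace_path="zimage_trace.json"):
    """Run `fn` under torch.profiler and export a Chrome trace."""
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities, record_shapes=True) as prof:
        fn()
    # Open the trace in chrome://tracing or https://ui.perfetto.dev
    prof.export_chrome_trace(trace_path)
    print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
    return prof

# Stand-in workload so the sketch runs anywhere.
prof = profile_call(lambda: torch.randn(256, 256) @ torch.randn(256, 256))
```

In the exported trace, `cudaStreamSynchronize` blocks on the CPU timeline mark the DtoH sync points this PR removes.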
Part of #13401.
Before submitting
documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@sayakpaul @dg845