[dynamo] Add FakeProcessGroup support for fx_graph_runnable with distributed collectives#157162

Closed
skarjala wants to merge 14 commits into gh/skarjala/10/base from gh/skarjala/10/head

Conversation

@skarjala
Contributor

skarjala commented Jun 27, 2025

Stack from ghstack (oldest at bottom):

Summary:

  • Modified generate_compiler_repro_string() to automatically detect distributed operations and inject FakeProcessGroup setup code
  • Added distributed collective tests in test/dynamo/test_fx_graph_runnable.py that use the FakeProcessGroup API to exercise distributed collective operations
  • Generated fx_graph_runnable code now runs standalone when it contains distributed operations

Example generated fx_graph_runnable file:
os.environ['TORCHINDUCTOR_CACHE_DIR'] = '/var/folders/fd/kcv8m1kn0lqgxz42wvgr46sc0000gn/T/torchinductor_skarjala'

import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
import torch._inductor.inductor_prims
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
import torch.fx.experimental._config


torch._functorch.config.functionalize_rng_ops = False
torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = True
torch._functorch.config.unlift_effect_tokens = True



isolate_fails_code_str = None




# torch version: 2.9.0a0+gitf23d314
# torch cuda version: None
# torch git version: f23d31463ca452918e23063409a2bdc55efc0d46


# torch.cuda.is_available()==False, no GPU info collected

from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    
    
    def forward(self, arg0_1):
        all_reduce = torch.ops._c10d_functional.all_reduce.default(arg0_1, 'sum', '0')
        wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_reduce);  all_reduce = None
        mul = torch.ops.aten.mul.Tensor(wait_tensor, 2)
        copy_ = torch.ops.aten.copy_.default(arg0_1, wait_tensor);  arg0_1 = wait_tensor = copy_ = None
        return (mul,)
        
def load_args(reader):
    buf0 = reader.storage(None, 64)
    reader.tensor(buf0, (4, 4), is_leaf=True)  # arg0_1
load_args._version = 0
mod = Repro()
if __name__ == '__main__':
    from torch._dynamo.repro.after_aot import run_repro
    # Initialize FakeProcessGroup for distributed operations
    store = FakeStore()
    dist.init_process_group(
        backend="fake",
        rank=0,
        world_size=2,
        store=store,
    )
    with torch.no_grad():
        run_repro(mod, load_args, accuracy=False, command='run', save_dir=None, tracing_mode='real', check_str=None)
        # To run it separately, do
        # mod, args = run_repro(mod, load_args, accuracy=False, command='get_args', save_dir=None, tracing_mode='real', check_str=None)
        # mod(*args)
    dist.destroy_process_group()
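The detect-and-inject behavior described in the summary can be sketched as follows. This is an illustrative outline only, not the PR's actual implementation; all names here (`has_distributed_ops`, `DISTRIBUTED_OP_MARKERS`, `maybe_inject_fake_pg`) are hypothetical.

```python
# Illustrative sketch (not the PR's actual code): detect whether a generated
# repro source string references distributed collectives, and if so prepend
# the FakeProcessGroup setup. All names are hypothetical.
DISTRIBUTED_OP_MARKERS = (
    "torch.ops._c10d_functional",  # functional collectives (all_reduce, ...)
    "torch.distributed",
)

def has_distributed_ops(repro_src: str) -> bool:
    # A plain substring scan over the generated source suffices here,
    # since the codegen emits fully qualified op names.
    return any(marker in repro_src for marker in DISTRIBUTED_OP_MARKERS)

FAKE_PG_SETUP = (
    "store = FakeStore()\n"
    "dist.init_process_group(backend='fake', rank=0, world_size=2, store=store)\n"
)

def maybe_inject_fake_pg(repro_src: str) -> str:
    # Only repros that actually use collectives pay the setup cost.
    if has_distributed_ops(repro_src):
        return FAKE_PG_SETUP + repro_src
    return repro_src
```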

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

@pytorch-bot

pytorch-bot bot commented Jun 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157162

Note: Links to docs will display an error until the docs builds have been completed.

❌ 5 Cancelled Jobs

As of commit 544ce40 with merge base 178fe7a:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

skarjala added a commit that referenced this pull request Jun 27, 2025
@skarjala skarjala changed the title Add FakeProcessGroup support for fx_graph_runnable with distributed collectives [dynamo] Add FakeProcessGroup support for fx_graph_runnable with distributed collectives Jun 27, 2025
@skarjala skarjala requested review from StrongerXi, bdhirsh and xmfan June 27, 2025 21:43
skarjala added a commit that referenced this pull request Jul 1, 2025
@skarjala skarjala requested a review from bdhirsh July 1, 2025 22:32
skarjala added a commit that referenced this pull request Jul 2, 2025
skarjala added a commit that referenced this pull request Jul 3, 2025
@skarjala skarjala marked this pull request as draft July 3, 2025 18:58
skarjala added a commit that referenced this pull request Jul 3, 2025
@skarjala skarjala marked this pull request as ready for review July 3, 2025 21:40
@skarjala skarjala marked this pull request as draft July 8, 2025 22:10
@skarjala skarjala marked this pull request as ready for review July 8, 2025 22:26
Comment on lines +13 to +15
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE
from torch.testing._internal.distributed.fake_pg import FakeStore
Member:

There will probably be some failed tests; the distributed imports other than `torch.distributed` itself must be gated behind an availability check:

if torch.distributed.is_available():
  from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
  from torch.testing._internal.distributed.fake_pg import FakeStore
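The gating the reviewer suggests pairs the conditional import with skipping dependent tests. Here is a stdlib-only sketch: `HAS_DISTRIBUTED` is a hard-coded stand-in for `torch.distributed.is_available()` so the example runs anywhere, and the class name is hypothetical.

```python
# Stdlib-only sketch of the suggested gating: distributed-dependent tests are
# skipped, not failed, when the feature is unavailable. HAS_DISTRIBUTED is a
# stand-in for torch.distributed.is_available(), hard-coded False here.
import unittest

HAS_DISTRIBUTED = False  # stand-in for torch.distributed.is_available()

if HAS_DISTRIBUTED:
    # Gated import: only resolved when distributed support exists.
    from torch.testing._internal.distributed.fake_pg import FakeStore

@unittest.skipIf(not HAS_DISTRIBUTED, "requires torch.distributed")
class DistributedCollectiveTests(unittest.TestCase):
    def test_fake_pg_roundtrip(self):
        # Would construct a FakeStore and run collectives here.
        self.assertTrue(HAS_DISTRIBUTED)

# Running the suite records one skip and zero failures.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(DistributedCollectiveTests)
result = unittest.TestResult()
suite.run(result)
```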

fd.write(
    "import torch.distributed as dist\n"
    "from torch.testing._internal.distributed.fake_pg import FakeStore\n"
)
Member:

Could you include a generated fx_graph_runnable in the PR summary? Let's make sure the imports still stay at the top of the file, with the others.

Contributor Author:

Fixed


# Add distributed cleanup after run_repro
if has_distributed_ops:
    fd.write("dist.destroy_process_group()\n")
Member:

Also, I'd like to double-check against the generated fx_graph_runnable file: is there no indentation for this line?

Contributor Author:

Fixed.

# Add distributed cleanup after run_repro if needed
if has_distributed_ops:
fd.write("dist.destroy_process_group()\n")
fd.write(" \n dist.destroy_process_group()\n")
Member:

nit: take a look at the codegen'd fx graph runnable:

with torch.no_grad():
        run_repro(mod, load_args, accuracy=False, command='run', save_dir=None, tracing_mode='real', check_str=None)
    
    dist.destroy_process_group()
        # To run it separately, do 
        # mod, args = run_repro(mod, load_args, accuracy=False, command='get_args', save_dir=None, tracing_mode='real', check_str=None)
        # mod(*args)

See how the comment below the run_repro is indented? the original intent is likely for people to uncomment those lines if they need them. But now with your added destroy process group, uncommenting those lines would error. Those lines need to run under the no_grad context, so I'd recommend you to move the destroy process group after those comment lines
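The placement the reviewer asks for can be sketched as a small codegen helper (the function name is hypothetical): emit `dist.destroy_process_group()` after the commented-out "run it separately" lines and at the `with` statement's own indent, so uncommenting those lines keeps them inside the `no_grad()` block.

```python
# Hypothetical sketch of the requested fix: cleanup is emitted last, outside
# the commented block, at the outer indent level.
def emit_repro_tail(has_distributed_ops: bool) -> str:
    lines = [
        "    with torch.no_grad():",
        "        run_repro(mod, load_args, accuracy=False, command='run', save_dir=None, tracing_mode='real', check_str=None)",
        "        # To run it separately, do",
        "        # mod, args = run_repro(mod, load_args, accuracy=False, command='get_args', save_dir=None, tracing_mode='real', check_str=None)",
        "        # mod(*args)",
    ]
    if has_distributed_ops:
        # Cleanup goes after the comment lines, at the `with` indent.
        lines.append("    dist.destroy_process_group()")
    return "\n".join(lines) + "\n"
```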

from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.testing._internal.distributed.fake_pg import FakeStore
else:
# Define dummy classes if distributed is not available
Member:

Tests using these classes should be skipped when distributed is not available.

@skarjala skarjala mentioned this pull request Jul 9, 2025
skarjala added 4 commits July 9, 2025 10:02
Comment on lines +448 to +450
)

fd.write(
Member:

This seems unnecessary.

@skarjala
Contributor Author

@pytorchbot merge -i

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 10, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 5 checks: pull / linux-jammy-py3.9-clang12 / build, pull / linux-jammy-cuda12.8-py3.10-gcc11-build-distributed / build, pull / linux-jammy-py3-clang12-executorch / build, pull / before-test / target-determination, inductor-rocm / rocm-py3.10-inductor / test (inductor, 1, 2, linux.rocm.gpu.2)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

pytorchmergebot pushed a commit that referenced this pull request Jul 22, 2025
@github-actions github-actions bot deleted the gh/skarjala/10/head branch August 10, 2025 02:20

Labels

ciflow/inductor, ciflow/trunk, Merged, module: dynamo, release notes: fx

4 participants