
Enable automatic_dynamic_shapes by default#103623

Closed
ezyang wants to merge 20 commits into gh/ezyang/2191/base from gh/ezyang/2191/head

Conversation

ezyang (Contributor) commented Jun 14, 2023

Stack from ghstack (oldest at bottom):

Some notes:

  • I now manually stop the `_generate` jobs from running with cudagraphs, as it is unrealistic to expect to cudagraph autoregressive generation up to the max sequence length: that would imply compiling the entire unrolled sequence generation. Concretely, cm3leon_generate was timing out after this change, likely due to the compile-time slowdown of dynamic shapes ON TOP OF accidentally unrolling all the loops.
  • A few torch._dynamo.reset calls are tactically inserted to force recompiles in tests that expect them.
  • expectedFailureAutomaticDynamic now flips into patching automatic_dynamic_shapes=False. (A sketch of what the flag being enabled here does follows these notes.)
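
For readers unfamiliar with the flag: with automatic_dynamic_shapes enabled, Dynamo compiles the first call with static shapes, and when a shape change forces a recompile it marks the changed dimension dynamic, so later shape changes reuse the same graph. A minimal sketch of that behavior (the function and backend choice are illustrative assumptions, not code from this PR):

    import torch

    @torch.compile(backend="eager")  # any backend works; "eager" keeps the sketch cheap
    def f(x):
        return x * 2

    f(torch.randn(4))   # first call: specialized on size 4
    f(torch.randn(8))   # size changed: recompile, this dimension is now dynamic
    f(torch.randn(16))  # expected to reuse the dynamic graph, no third recompile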

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @anijain2305 @msaroufim

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
@pytorch-bot

pytorch-bot Bot commented Jun 14, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/103623

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 53e4893:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ezyang added a commit that referenced this pull request Jun 15, 2023
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 91748e4
Pull Request resolved: #103623
@ezyang ezyang changed the title from "[TESTING] Enable automatic_dynamic_shapes by default" to "Enable automatic_dynamic_shapes by default" on Jun 16, 2023

For safety, we don't turn this on in fbcode; the goal is to flush out bugs in OSS nightlies first. (The notes above otherwise still apply.)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jun 18, 2023
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 15a6604
Pull Request resolved: #103623
@ezyang ezyang requested a review from eellison June 18, 2023 22:41
      assert not self.outputs_weakrefs
      for out, static_output_tensor in zip(outputs, self.static_output_tensors):
-         if out is None or static_output_tensor is not None:
+         if not isinstance(out, torch.Tensor) or static_output_tensor is not None:
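
A hedged reading of why this hunk changed: with dynamic shapes enabled, a compiled graph can return non-Tensor outputs such as SymInts, which the old `out is None` check did not skip. A minimal illustration (hypothetical function, mirroring the test quoted further down):

    def fn(x, y):
        r = x + y
        return r, r.size(0)  # with a dynamic dim, the second output is a
                             # SymInt, not a torch.Tensor, yet it is not None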

eellison (Contributor) commented:
Maybe add a test case for when we would return a dynamic-shape output in cudagraph trees?

I don't think we're hitting the third invocation, the reconstruct-outputs path:

    def reconstruct_outputs(self):
        "Reconstruct output tensors according to their saved metadata and alias information"

        # Cached tensors will not yet be set on the first execution
        # They are also cleared in checkpointing, so if we checkpoint this node
        # and then execute it again we will need to repopulate cached tensors
        if not self.cached_tensor_outputs:
            self._initialize_cached_tensors()

        outputs = []
        for i, (storage_info, metadata) in enumerate(
            zip(self.output_storage_alias, self.outputs_metadata)
        ):
            if metadata is None:
                outputs.append(None)
                continue

            cached_t = self.cached_tensor_outputs[i]
            if cached_t is not None:
                # No need to update weakrefs, already correctly initialized
                outputs.append(cached_t)
                continue

            static_t = self.static_output_tensors[i]
            if static_t is not None:
                assert self.outputs_weakrefs[i] is None
                outputs.append(static_t)
                continue

            storage = self.prepare_alias_info_for_tensor_construction(
                storage_info, metadata
            )

            if isinstance(storage, UntypedStorage) or storage is None:
                out = self._reconstruct_from_tensor_metadata(metadata, storage)
            else:
                assert isinstance(storage, int)
                out = self._reconstruct_from_tensor_metadata(
                    metadata, outputs[storage].untyped_storage()
                )

            outputs.append(out)
            self.outputs_weakrefs[i].swap_weakref(out.untyped_storage()._weak_ref())

        return outputs

ezyang (Contributor, Author) replied:

Unfortunately, I already have this test case 🤔

    @config.patch({"triton.cudagraphs": True})
    @dynamo_config.patch(
        automatic_dynamic_shapes=True,
        assume_static_by_default=False,
    )               
    def test_dynamic_to_static_cudagraphs(self):
        for b in [False, True]:
            with config.patch({"triton.cudagraph_trees": b}):

                @torch._dynamo.optimize("inductor")
                def fn(x, y):
                    r = x + y
                    return r, r.size(0)

                inputs = (
                    torch.randn((5, 5), device="cuda"),
                    torch.randn((5, 5), device="cuda"),
                )
                self.assertTrue(same(fn(*inputs), (inputs[0] + inputs[1], 5)))
        
                inputs = (
                    torch.randn((6, 6), device="cuda"),
                    torch.randn((6, 6), device="cuda"),
                )
                self.assertTrue(same(fn(*inputs), (inputs[0] + inputs[1], 6)))

So I don't understand why this is not exercising these cases...

ezyang (Contributor, Author) commented:

I spent a while trying to understand the num_fixed setting in inductor and... I give up. I'm going to argue that yolov3 is good enough test coverage, and also that these changes are "obviously OK": we can easily keep point-fixing sites like this until we've got them all.

If you want to argue that we need clearer invariants for these data structures, I'm fine with that, but we should use mypy to work them out, and then apply uniform annotations throughout cudagraph_trees.py.
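
For concreteness, a sketch of the kind of uniform annotations meant here; the attribute names follow the code quoted above, but the exact types are assumptions, not the real cudagraph_trees.py signatures:

    from typing import Any, Dict, List, Optional

    import torch

    class CUDAGraphNode:
        # Invariant: one entry per graph output in each list, and for a given
        # index at most one of cached_tensor_outputs / static_output_tensors
        # is non-None.
        outputs_metadata: List[Optional[Dict[str, Any]]]
        cached_tensor_outputs: List[Optional[torch.Tensor]]
        static_output_tensors: List[Optional[torch.Tensor]]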

eellison (Contributor) commented Jun 23, 2023:

I wonder if it's because we're adding the symints at the front of the calling convention, and that is messing up the num_fixed logic: https://github.com/pytorch/pytorch/blob/main/torch/_functorch/aot_autograd.py#L2990-L2992. Do we have a test case which exercises the saved symints?

I think it's unlikely mypy is going to fix our bug here, although I can add more annotations to cudagraph_trees.
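
To make that hypothesis concrete, a hedged sketch of the suspected failure mode; the variable names are illustrative stand-ins, not the actual aot_autograd code:

    # AOTAutograd flattens the calling convention with saved symints first:
    #   flat_args = [*symints, *fixed_params, *user_inputs]
    # If num_fixed means "the first num_fixed args are static parameters",
    # prepending symints shifts which positions are treated as fixed.
    symints = [5]                      # e.g. a dynamic batch size s0
    fixed_params = ["w", "b"]          # stand-ins for parameter tensors
    user_inputs = ["x"]
    flat_args = symints + fixed_params + user_inputs
    num_fixed = len(fixed_params)      # 2
    assert flat_args[:num_fixed] == [5, "w"]  # a symint now "counts" as fixed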

ezyang (Contributor, Author) replied:

@eellison I also thought about this, but it doesn't seem like that's it: input SymInts are passed in as example values, so they should not "count" as fixed.

ezyang (Contributor, Author) commented Jul 5, 2023:

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

wonjoo-wj (Collaborator) commented:

Hi @ezyang, it seems like this PR is causing an XLA dynamo test test_simple_model_with_different_input_shape to fail at https://github.com/pytorch/xla/blob/master/test/dynamo/test_dynamo.py#L72:

======================================================================
ERROR: test_simple_model_with_different_input_shape (__main__.DynamoInferenceBasicTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/dynamo/test_dynamo.py", line 84, in test_simple_model_with_different_input_shape
    res_xla_dynamo_3 = self.fn_simple_dynamo(xla_z, xla_z)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 294, in _fn
    return fn(*args, **kwargs)
  File "test/dynamo/test_dynamo.py", line 47, in fn_simple_dynamo
    @torch.compile(backend='torchxla_trace_once')
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 294, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/backends/torchxla.py", line 24, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/dynamo_bridge.py", line 398, in extract_compiled_graph
    xla_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
TypeError: _check_tensor_need_materialization(): incompatible function arguments. The following argument types are supported:
    1. (arg0: List[torch.Tensor]) -> List[bool]

Invoked with: (10, 10, tensor([[-1.3407e+00, -5.8537e-01,  5.3619e-01,  5.2462e-01,  1.1412e+00,
          5.1644e-02,  7.2811e-01, -7.1064e-01, -1.0495e+00,  6.0390e-01],
        [-1.7223e+00, -8.2777e-01,  1.3347e+00,  4.8354e-01, -1.9756e-01,
          1.2683e+00,  7.8459e-01,  2.8647e-02,  6.4076e-01,  5.8325e-01],
        [ 1.0669e+00, -4.5015e-01, -6.7875e-01,  5.7432e-01,  4.0476e-01,
          1.7847e-01,  2.6491e-01,  1.2732e+00, -1.3109e-03, -3.0360e-01],
        [-9.8644e-01,  1.2330e-01, -5.9915e-01,  4.7706e-01,  7.2618e-01,
          9.1152e-02, -3.8907e-01,  5.2792e-01,  1.0311e+00, -7.0477e-01],
        [ 1.3254e-01,  7.6424e-01,  1.0950e+00,  3.3989e-01,  7.1997e-01,
          4.1141e-01, -5.7332e-01,  5.0686e-01, -1.4364e+00, -1.1299e+00],
        [-1.3603e-01,  1.6354e+00,  6.5474e-01,  5.7600e-01, -3.6091e-01,
         -6.0590e-02, -1.8058e+00,  9.2543e-01, -3.7534e-01,  1.0331e+00],
        [-6.8665e-01,  6.3681e-01,  2.1755e-01, -4.6655e-02,  1.6192e+00,
          1.4506e+00,  2.6948e-01, -2.1038e-01, -7.3280e-01,  1.0430e-01],
        [ 1.0414e+00, -3.9973e-01, -4.6569e-01,  1.6048e+00, -2.4801e+00,
         -4.1754e-01, -1.1955e+00,  8.1234e-01, -3.0628e-01, -3.3016e-01],
        [ 2.4859e-02, -3.4595e-01,  2.8683e-01, -7.3084e-01, -1.1360e+00,
         -5.2260e-01,  7.1654e-01,  1.5335e+00, -1.4510e+00, -7.8614e-01],
        [-9.5632e-01, -1.2476e+00,  7.0427e-01,  7.0988e-01, -1.5326e+00,
         -7.2513e-01,  4.6640e-01,  6.6672e-01, -4.3871e-02,  2.3681e-01]],
       device='xla:0'))

----------------------------------------------------------------------
Ran 1 test in 3.403s

FAILED (errors=1)

The purpose of this unit test was to ensure that Dynamo in XLA recompiles when the input shape changes. However, it now seems that Dynamo is just passing the new input shape to XLA without triggering recompilation. And from the error message, we can see that Dynamo is now passing scalars to XLA as xla_args, while XLA currently expects all xla_args to be tensors on the XLA device.

While we investigate more and try to put out a fix, would it be okay to revert this PR to prevent more related changes from coming in?

cc @JackCaoG

JackCaoG (Collaborator) commented Jul 5, 2023:

For context: PyTorch/XLA expects Dynamo to pass us a list of xla_args which are all tensors on the XLA device. After this change, we started to see scalars in this list, hence the error. I think this test is not included in PyTorch CI since the XLA CI pin is old; I am working with @kit1980 to update it. For now, we have disabled this unit test in pytorch/xla@b799e20 to keep the XLA-side CI green.

ezyang (Contributor, Author) commented Jul 6, 2023:

Sorry about the delay. For XLA I would advise you to change this configuration variable to False when XLA Dynamo is being used. Are you able to easily do so?
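
For reference, a minimal sketch of that mitigation; the flag is the one this PR enables, though where exactly XLA should set it is an assumption:

    import torch._dynamo.config as dynamo_config

    # Opt back out of the new default before compiling with an XLA backend.
    dynamo_config.automatic_dynamic_shapes = False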

wonjoo-wj (Collaborator) replied:

Thanks for the advice, Ed. We've temporarily disabled torch._dynamo.config.automatic_dynamic_shapes in XLA for now, as per pytorch/xla#5285, while we come up with a long-term fix.
