
Add support for dynamic shape in dynamo #7676

Merged
JackCaoG merged 12 commits into master from wonjoo/dynamo-dynamic-shape
Jul 23, 2024

Conversation

@wonjoo-wj
Collaborator

Fixes #7614


TODO

  • Remove debugging code and add comments
  • Add unit tests
  • Handle error case when TorchDynamo passes us int types

@wonjoo-wj
Collaborator Author

With the current changes, the following code generates correct results without recompiling the graph:

    ###
    # torch.compile dynamic shape ON
    # (`fn` and `device` are assumed to be defined earlier in the test)
    import torch
    import torch_xla.core.xla_model as xm

    torch._dynamo.config.automatic_dynamic_shapes = True
    compiled_fn = torch.compile(fn, backend='openxla', dynamic=True)
    a = torch.randn(3, 4, device=device)
    b = torch.ones(4, device=device)
    ret = compiled_fn(a, b)
    xm.mark_step()
    print(f'[Testing] {ret=}')
    print(f'--------------------')

    c = torch.randn(4, 5, device=device)
    d = torch.ones(5, device=device)
    ret2 = compiled_fn(c, d)
    xm.mark_step()
    print(f'[Testing] {ret2=}')
    print(f'--------------------')

As for next steps, I'll clean up some code and add some unit tests.

@wonjoo-wj wonjoo-wj force-pushed the wonjoo/dynamo-dynamic-shape branch from 630bba3 to 0ef3768 Compare July 13, 2024 00:24
@wonjoo-wj wonjoo-wj changed the title [WIP] Add support for dynamic shape in dynamo Add support for dynamic shape in dynamo Jul 15, 2024
Comment thread test/dynamo/test_dynamo.py Outdated
@wonjoo-wj wonjoo-wj marked this pull request as ready for review July 15, 2024 21:23
@wonjoo-wj wonjoo-wj requested a review from JackCaoG July 16, 2024 17:07
@JackCaoG
Collaborator

Seems like a bunch of tests failed, and a lot of them are real failures. @wonjoolee95, let me know if you need help debugging them.

Comment thread test/dynamo/test_dynamo.py Outdated
Comment thread torch_xla/core/dynamo_bridge.py Outdated
Comment thread torch_xla/core/dynamo_bridge.py Outdated
Comment thread test/dynamo/test_dynamo.py Outdated
Comment on lines +415 to +416
    # self.assertTrue(
    #     torch.allclose(output_cpu_new_shape, output_new_shape.cpu(), rtol=1e-05, atol=1e-05))
Collaborator Author

This part is odd. When I run these tests, the allclose check fails because, in some iterations of the data loader with this new_shape, the differences are as large as 0.2.

Collaborator Author

This is fixed with the explicit mark_step call within the else statement under torch._dynamo.config.assume_static_by_default.

Comment thread torch_xla/core/dynamo_bridge.py
@wonjoo-wj wonjoo-wj force-pushed the wonjoo/dynamo-dynamic-shape branch from e8334d4 to fcd08bb Compare July 19, 2024 03:39
@wonjoo-wj wonjoo-wj force-pushed the wonjoo/dynamo-dynamic-shape branch from 858359e to 8b8897c Compare July 19, 2024 21:52
Comment thread test/dynamo/test_dynamo.py Outdated
    for data, _ in loader_new_shape:
      output_new_shape = dynamo_resnet18(data)
      output_cpu_new_shape = resnet18(data.cpu())
      # # TPU has some precision issues, skipping allclose check
Collaborator

remove one #

Collaborator Author

Updated

Comment thread test/dynamo/test_dynamo.py Outdated
    output_new_shape.cpu(),
    rtol=1e-05,
    atol=1e-05))

Collaborator

maybe also check the CompileTime and ExecuteTime here

Collaborator

also can you make another test to test the case of

    fn(shape_a)
    fn(shape_b)
    fn(shape_c)
    fn(shape_a)

Want to make sure we don't forget the old shapes that are cached.
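The shape-revisit behavior this test would check can be sketched in plain Python. All names below (compile_graph, shape_cache, run) are illustrative stand-ins, not the bridge's actual API: the key point is that revisiting shape_a must hit the cache rather than recompile.

```python
compile_count = 0
shape_cache = {}

def compile_graph(shape):
    # Stand-in for the expensive trace-and-compile step.
    global compile_count
    compile_count += 1
    return f"compiled_graph_for_{shape}"

def run(shape):
    if shape not in shape_cache:          # cache miss: (re)compile
        shape_cache[shape] = compile_graph(shape)
    return shape_cache[shape]             # cache hit: no recompilation

# shape_a, shape_b, shape_c, then shape_a again
for s in [(3, 4), (4, 5), (5, 6), (3, 4)]:
    run(s)

assert compile_count == 3  # shape_a was remembered; only 3 compiles
```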

Comment thread torch_xla/core/dynamo_bridge.py Outdated
    # Values: tuple of (xla_args_sharding_spec, args_and_out, graph_hash,
    # arg_index_to_need_update_index, none_remover, graph_input_matcher,
    # dumb_return_handler, xla_args_need_update).
    input_shape_mappings: dict[tuple[int, ...], tuple[object, ...]] = {}
Collaborator

Use typing.Dict and typing.Tuple, otherwise the Python 3.8 CI in upstream will fail.
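For illustration, a Python 3.8-compatible version of the annotation might look like the following; the key and value shown are placeholders, not the real cached tuple:

```python
from typing import Any, Dict, Tuple

# Python 3.8-compatible annotation: subscripting the dict / tuple builtins
# (dict[...], tuple[...]) in annotations requires Python 3.9+ (PEP 585).
input_shape_mappings: Dict[Tuple[int, ...], Tuple[Any, ...]] = {}

# Keyed by the concrete input shapes of a traced graph (placeholder value):
input_shape_mappings[(3, 4)] = ("graph_hash", "args_and_out")
```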

Collaborator Author

Updated

Comment thread torch_xla/core/dynamo_bridge.py
Comment on lines +499 to +502
    input_shape_mappings[arg_input_shapes] = (
        xla_args_sharding_spec, args_and_out, graph_hash,
        arg_index_to_need_update_index, none_remover, graph_input_matcher,
        dumb_return_handler, xla_args_need_update)
Collaborator

I think you don't need this here

Collaborator Author

IIUC, we actually need this here, and we actually don't need this same logic in extract_internal above (removed in the newest commit). The reason is that when dynamic=True, only optimized_mod is called; other functions (including extract_internal) are not called.

Collaborator

OK, then you will run into the same old problem, right?
On the first call:

extract_graph_helper -> optimized_mod

In this case you do the compile, but you do not cache into input_shape_mappings.

When optimized_mod is called the first time, you will need to call extract_graph_helper again, which is wasteful.

Collaborator

You should just do the caching (input_shape_mappings[arg_input_shapes] = ...) inside extract_graph_helper.
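The suggested move can be sketched in plain Python. These are hypothetical stand-ins for the bridge's functions, not the actual implementation: the cache write lives inside the helper, so whichever path triggers compilation also populates input_shape_mappings, and the first optimized_mod call doesn't recompile.

```python
input_shape_mappings = {}
helper_calls = 0

def extract_graph_helper(arg_input_shapes):
    # Stand-in for tracing/compiling; caches its own result before returning.
    global helper_calls
    helper_calls += 1
    compiled = ("graph_hash", "args_and_out")  # stand-in for the real tuple
    input_shape_mappings[arg_input_shapes] = compiled  # cache here, once
    return compiled

def optimized_mod(arg_input_shapes):
    cached = input_shape_mappings.get(arg_input_shapes)
    if cached is None:
        cached = extract_graph_helper(arg_input_shapes)  # populates the cache
    return cached

optimized_mod((3, 4))  # first call compiles and caches
optimized_mod((3, 4))  # second call is a pure cache hit
assert helper_calls == 1
```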

Collaborator

Let me fix this too.

Comment on lines +587 to +588
    dynamo_extract_graph_helper_metric_count = metrics.counter_value(
        'DynamoExtractCompiledGraph')
Collaborator

Will run_node call extract_compiled_graph too?

Collaborator Author

It's hard to tell from the documentation. However, when I compare metrics before/after run_node, from what I can see it's not calling extract_compiled_graph.

Collaborator

OK, then I am confused about what this dynamo_extract_graph_helper_metric_count is doing here.

Collaborator Author

This code (run_node) is executed when we're fetching the fallback ops. In the code below, we clear our metric counters via metrics.clear_counters(), so we need a way to restore this counter to verify that extract_compiled_graph only gets called once in our unit tests.

Collaborator

Ah I see, I can fix it later. I think the right thing to do is to define a region where the counter is not incremented.
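One way to define such a region, sketched here with a plain dict standing in for the metric store (the helper name preserve_counter is hypothetical, not a torch_xla API): snapshot the counter before running code that clears or bumps metrics, then restore it afterwards.

```python
from contextlib import contextmanager

# `counters` stands in for the real metric store.
counters = {"DynamoExtractCompiledGraph": 1}

@contextmanager
def preserve_counter(name):
    saved = counters.get(name, 0)
    try:
        yield
    finally:
        counters[name] = saved  # restore regardless of what ran inside

with preserve_counter("DynamoExtractCompiledGraph"):
    counters.clear()  # e.g. the run_node path clearing all counters

assert counters["DynamoExtractCompiledGraph"] == 1
```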

@JackCaoG
Collaborator

@ysiraichi FYI

@wonjoo-wj
Collaborator Author

The PR should be in a reasonable state; now just seeing 2 failures in the GPU tests that require torch CUDA:

#1: DynamoInferenceBasicTest.test_dynamic_shape_resnet180 (True):
Input tensor is not an XLA tensor: CUDAFloatType

#2: DynamoInferenceBasicTest.test_resnet180 (True)
  File "/__w/xla/xla/pytorch/xla/test/dynamo/test_dynamo.py", line 370, in test_resnet18
    self.assertEqual(met.metric_data('CompileTime')[0], 1)
TypeError: 'NoneType' object is not subscriptable

For the first error, the stack trace points to:

    pytree.tree_map_only(
        torch.Tensor,
        lambda xla_arg: torch_xla._XLAC._xla_get_tensor_id(xla_arg),
        xla_args))

It seems like we may want to do an additional isinstance(arg, torch.Tensor) check here.
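The proposed guard can be illustrated with a minimal stand-in (FakeTensor and get_tensor_id are hypothetical, not torch_xla APIs): only query tensor ids for arguments that really are tensors, letting the plain ints dynamo passes for dynamic dimensions fall through.

```python
class FakeTensor:
    """Stand-in for a tensor type that carries an id."""
    def __init__(self, tid):
        self.tid = tid

def get_tensor_id(t):
    return t.tid

def collect_tensor_ids(xla_args):
    # isinstance guard: skip non-tensor args such as symbolic dim sizes.
    return [get_tensor_id(a) for a in xla_args if isinstance(a, FakeTensor)]

ids = collect_tensor_ids([FakeTensor(7), 4, FakeTensor(9)])  # 4 is a dim size
assert ids == [7, 9]
```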

@JackCaoG
Collaborator

I will pick this up and try to fix the errors today.

@JackCaoG JackCaoG added the tpuci label Jul 22, 2024
@JackCaoG JackCaoG requested a review from alanwaketan July 22, 2024 23:20
@JackCaoG
Collaborator

@alanwaketan There are a few places I want to fix, but maybe we should just merge this PR to unblock Woosuk now. I am also running some benchmarks.

Collaborator

@alanwaketan alanwaketan left a comment

Approved to unblock.

@JackCaoG JackCaoG merged commit 2b6b461 into master Jul 23, 2024
@JackCaoG JackCaoG deleted the wonjoo/dynamo-dynamic-shape branch July 23, 2024 01:18
yitongh pushed a commit to AlibabaPAI/xla that referenced this pull request Oct 11, 2024
Co-authored-by: JackCaoG <jackcao@google.com>
yitongh pushed a commit to AlibabaPAI/xla that referenced this pull request Dec 11, 2024
Co-authored-by: JackCaoG <jackcao@google.com>

Labels

dynamism Dynamic Shape Features

Successfully merging this pull request may close these issues.

Dynamo persistent cache real-time look-up
