Add support for dynamic shape in dynamo #7676
Conversation
With the current changes, the following code generates correct results without recompiling the graph. As for next steps, I'll clean up some code and add some unit tests.
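For reference, a minimal sketch of the kind of snippet being described (the original code block is not preserved in this thread, so the model, backend string, and shapes below are assumptions, not the author's exact repro):

```python
# Hypothetical repro sketch: one compiled function fed two different input
# shapes; with dynamic=True the second shape should reuse the cached graph
# instead of triggering a full recompile.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(10, 5).to(device)
compiled = torch.compile(model, backend='openxla', dynamic=True)

out_a = compiled(torch.randn(8, 10, device=device))   # first shape compiles
out_b = compiled(torch.randn(16, 10, device=device))  # new batch size, graph reused
```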
force-pushed from 630bba3 to 0ef3768
Seems like a bunch of tests failed, and a lot of them are real failures. @wonjoolee95, let me know if you need help debugging them.
| # self.assertTrue(
| #     torch.allclose(output_cpu_new_shape, output_new_shape.cpu(), rtol=1e-05, atol=1e-05))
This part is odd. When I run these tests, the allclose fails because in some iteration of the data loader with this new_shape, the differences are as big as 0.2.
This is fixed with the explicit mark_step call within the else statement under torch._dynamo.config.assume_static_by_default.
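A rough illustration of the pieces mentioned above (the exact placement in the PR's diff may differ; this only shows the gating flag plus the explicit sync):

```python
# Sketch only: assume_static_by_default gates the static vs. dynamic path, and
# the explicit mark_step() on the dynamic path flushes pending XLA work so the
# later CPU comparisons see fully materialized values.
import torch._dynamo
import torch_xla.core.xla_model as xm

if torch._dynamo.config.assume_static_by_default:
    pass  # static-shape path: no extra sync needed here
else:
    xm.mark_step()  # this explicit sync is what removed the ~0.2 divergence
```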
force-pushed from e8334d4 to fcd08bb
force-pushed from 858359e to 8b8897c
| for data, _ in loader_new_shape:
|     output_new_shape = dynamo_resnet18(data)
|     output_cpu_new_shape = resnet18(data.cpu())
|     # # TPU has some precision issues, skipping allclose check
|         output_new_shape.cpu(),
|         rtol=1e-05,
|         atol=1e-05))
maybe also check the CompileTime and ExecuteTime here
Also, can you make another test for the case of
fn(shape_a)
fn(shape_b)
fn(shape_c)
fn(shape_a)
We want to make sure we don't forget the old shapes that are cached.
| # Values: tuple of (xla_args_sharding_spec, args_and_out, graph_hash,
| # arg_index_to_need_update_index, none_remover, graph_input_matcher,
| # dumb_return_handler, xla_args_need_update).
| input_shape_mappings: dict[tuple[int, ...], tuple[object, ...]] = {}
Use typing.Dict and typing.Tuple, otherwise the Python 3.8 CI in upstream will fail.
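For example, a 3.8-compatible version of the annotation quoted above (built-in generics like dict[...] in annotations need Python 3.9+ unless `from __future__ import annotations` is used):

```python
from typing import Dict, Tuple

input_shape_mappings: Dict[Tuple[int, ...], Tuple[object, ...]] = {}
```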
| input_shape_mappings[arg_input_shapes] = (
|     xla_args_sharding_spec, args_and_out, graph_hash,
|     arg_index_to_need_update_index, none_remover, graph_input_matcher,
|     dumb_return_handler, xla_args_need_update)
I think you don't need this here
IIUC, we actually need this here, and we actually don't need this same logic in extract_internal above (removed this in the newest commit). The reason is that when dynamic=True, only optimized_mod is called; other functions (including extract_internal) are not called.
OK, then you will run into the same old problem, right? The first time, the flow is:
extract_graph_helper -> optimized_mod
In this case you do the compile, but you do not cache into input_shape_mappings, so when optimized_mod is called the first time you will need to call extract_graph_helper again, which is wasteful.
You should just do the caching (input_shape_mappings[arg_input_shapes] = ...) inside extract_graph_helper.
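A minimal, self-contained sketch of that suggestion (function names mirror the discussion, but the payload and signatures here are simplified stand-ins, not the PR's actual code):

```python
from typing import Any, Dict, Tuple

input_shape_mappings: Dict[Tuple[Tuple[int, ...], ...], Any] = {}

def extract_graph_helper(arg_input_shapes, do_compile):
    # The expensive graph extraction/compilation runs once per shape tuple...
    payload = do_compile(arg_input_shapes)
    # ...and is cached right here, so no caller ever has to redo it.
    input_shape_mappings[arg_input_shapes] = payload
    return payload

def optimized_mod(arg_input_shapes, do_compile):
    # Both the first call and later calls go through one cache lookup.
    cached = input_shape_mappings.get(arg_input_shapes)
    if cached is None:
        cached = extract_graph_helper(arg_input_shapes, do_compile)
    return cached
```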
| dynamo_extract_graph_helper_metric_count = metrics.counter_value(
|     'DynamoExtractCompiledGraph')
will run_node call extract_compiled_graph too?
It's hard to tell from the documentation. However, when I compare metrics before/after run_node, from what I can see it's not calling extract_compiled_graph.
OK, then I am confused about what this dynamo_extract_graph_helper_metric_count is doing here.
This code (run_node) is executed when we're fetching the fallback ops. And in this code below, we clear our metric counters via metrics.clear_counters(). So we need a way to restore this counter, so we can verify extract_compiled_graph only gets called once in our unit tests.
Ah, I see. I can fix it later. I think the right thing to do is to define a region where the counter does not get incremented.
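For context, a hedged sketch of the workaround described two comments up (not the PR's exact code; only counter_value, clear_counters, and the DynamoExtractCompiledGraph counter name come from the discussion):

```python
import torch_xla.debug.metrics as metrics

# Snapshot the counter before the fallback-op scan wipes all counters...
saved = metrics.counter_value('DynamoExtractCompiledGraph') or 0
metrics.clear_counters()  # run_node-based fallback detection clears counters
# ... run the fallback-op scan here ...
# ...then carry `saved` forward so unit tests can still assert the helper ran once.
```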
@ysiraichi FYI
The PR should be in a reasonable state; now just seeing 2 failures. For the first error, the stack trace points to xla/torch_xla/core/dynamo_bridge.py, lines 292 to 295 in 4ba63ff. It seems like we may want to do an additional ...
I will pick this up and try to fix the error today.
@alanwaketan There are a few places I want to fix, but maybe we should just merge this PR to unblock Woosuk now. I am also running some benchmarks.
alanwaketan left a comment:
Approved to unblock.
Co-authored-by: JackCaoG <jackcao@google.com>
Fixes #7614
TODO