
Support Unbounded Dynamism for torch.export #6393

@lsy323


🚀 Feature

Unbounded dynamic shapes need to be propagated through ops.

Scope:

  • Only the torch.export use case.

Example:

opcode         name           target                   args                     kwargs
-------------  -------------  -----------------------  -----------------------  --------
placeholder    arg0_1         arg0_1                   ()                       {}
placeholder    l_embeddings_  l_embeddings_            ()                       {}
call_function  sym_size_int   aten.sym_size.int        (l_embeddings_, 0)       {}
call_function  mul            <built-in function mul>  (sym_size_int, 2)        {}
call_function  expand         aten.expand.default      (arg0_1, [mul, -1, -1])  {}
output         output         output                   ((expand,),)             {}

In LazyIR, we need to capture aten.sym_size.int and the subsequent arithmetic operations on the resulting SymInt, so that their semantics can be lowered.
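A minimal pure-Python sketch of what capturing SymInt arithmetic could look like. `SymNode`, `sym_size`, and `to_expr` are hypothetical names for illustration only, not the LTC or PyTorch/XLA API. Each arithmetic operation on a symbolic size returns a new node that records the op and its inputs instead of a concrete integer, which is the information a lowering would need.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class SymNode:
    """Hypothetical IR node that stands in for a traced SymInt."""

    op: str                     # e.g. "sym_size", "mul", "add"
    inputs: Tuple[object, ...]  # child SymNodes, ints, or a tensor name

    def __mul__(self, other: object) -> "SymNode":
        # Record the multiplication instead of evaluating it.
        return SymNode("mul", (self, other))

    def __rmul__(self, other: object) -> "SymNode":
        # 2 * node is recorded as mul(node, 2).
        return SymNode("mul", (self, other))

    def __add__(self, other: object) -> "SymNode":
        return SymNode("add", (self, other))

    def to_expr(self) -> str:
        # Render the recorded computation as a nested expression string.
        args = ", ".join(
            a.to_expr() if isinstance(a, SymNode) else str(a) for a in self.inputs
        )
        return f"{self.op}({args})"


def sym_size(tensor_name: str, dim: int) -> SymNode:
    # Leaf node: a dimension whose extent is unknown at trace time.
    return SymNode("sym_size", (tensor_name, dim))


# Mirrors the FX graph above: mul = sym_size(l_embeddings_, 0) * 2
batch = sym_size("l_embeddings_", 0)
scaled = batch * 2
print(scaled.to_expr())  # mul(sym_size(l_embeddings_, 0), 2)
```

In a real LTC integration the node would hold a LazyIR value rather than strings; the point of the sketch is only that the multiplication survives tracing instead of being folded into a constant.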

In the lowered HLO graph, we should get something like:

main(%arg0: tensor<?x3x224x224xf32>, %arg1: tensor<1x1x768xf32>) -> tensor<?x1x768xf32> {
    %1 = get_dimension_size(%arg0, dim = 0)
    %2 = multiply(%1, 2)
    %3 = expand(%arg1, %2, dim = 0)
    return %3 : tensor<?x1x768xf32>
}
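As a rough illustration of the lowering, the sketch below walks the call_function nodes of the FX graph above and emits HLO-like text in SSA form. The op names are informal placeholders rather than real StableHLO ops, the built-in `mul` target is written as the string "mul", and `lower_graph`, `FX_NODES`, and `OP_MAP` are hypothetical names.

```python
from typing import Dict, List, Tuple

# (name, target, args) for each call_function node in the FX graph above.
# String args name an earlier node or placeholder; other args are literals.
FX_NODES: List[Tuple[str, str, tuple]] = [
    ("sym_size_int", "aten.sym_size.int", ("l_embeddings_", 0)),
    ("mul", "mul", ("sym_size_int", 2)),
    ("expand", "aten.expand.default", ("arg0_1", ["mul", -1, -1])),
]

# Informal FX-target -> HLO-like op names (placeholders, not real StableHLO).
OP_MAP: Dict[str, str] = {
    "aten.sym_size.int": "get_dimension_size",
    "mul": "multiply",
    "aten.expand.default": "expand",
}


def _ref(arg: object, env: Dict[str, str]) -> str:
    # Node names become SSA references; lists are rendered element-wise.
    if isinstance(arg, str):
        return env[arg]
    if isinstance(arg, list):
        return "[" + ", ".join(_ref(a, env) for a in arg) + "]"
    return str(arg)


def lower_graph(nodes: List[Tuple[str, str, tuple]],
                placeholders: Dict[str, str]) -> List[str]:
    """Emit one SSA line per node; `placeholders` maps FX inputs to %argN."""
    env = dict(placeholders)
    lines = []
    for i, (name, target, args) in enumerate(nodes, start=1):
        ssa = f"%{i}"
        operands = ", ".join(_ref(a, env) for a in args)
        lines.append(f"{ssa} = {OP_MAP[target]}({operands})")
        env[name] = ssa  # later nodes refer to this result by its SSA name
    return lines


for line in lower_graph(FX_NODES, {"l_embeddings_": "%arg0", "arg0_1": "%arg1"}):
    print(line)
```

Because the SymInt arithmetic was kept as a node, the multiply by 2 appears between `get_dimension_size` and `expand` instead of being lost.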

Rough Plan

  • We need to trace and lower, in LTC, torch ops that produce a SymInt output (e.g. aten.sym_size.int).
  • The arithmetic on the SymInt needs to be traced. This should not be hard once a corresponding LazyIR node can be created for the SymInt.
  • When the SymInt overload of an op (e.g. aten.expand) is lowered, it needs to retrieve the underlying LazyIR node of its SymInt argument.
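For the third step, the lowering of a SymInt overload has to tell static sizes apart from symbolic ones. A minimal sketch, assuming a hypothetical `SymIntRef` wrapper that carries the IR value of a traced SymInt and a hypothetical `lower_expand_sizes` helper: plain ints are treated as constants (with -1 meaning "keep the input dimension", as in `torch.Tensor.expand`), while SymInt arguments are resolved to their underlying IR value.

```python
from dataclasses import dataclass
from typing import List, Sequence, Union


@dataclass(frozen=True)
class SymIntRef:
    """Hypothetical handle to a traced SymInt; `ir_value` names its IR node."""

    ir_value: str


def lower_expand_sizes(sizes: Sequence[Union[int, SymIntRef]],
                       input_dims: Sequence[int]) -> List[str]:
    """Resolve each requested expand size to an operand for the lowered op.

    Assumes the expanded tensor has static dims and that `sizes` adds no new
    leading dims (both hold for arg0_1 in the FX graph above).
    """
    if len(sizes) != len(input_dims):
        raise ValueError("sketch only handles len(sizes) == input rank")
    operands = []
    for size, in_dim in zip(sizes, input_dims):
        if isinstance(size, SymIntRef):
            # SymInt argument: reuse the IR value that computed it.
            operands.append(size.ir_value)
        elif size == -1:
            # -1 keeps the input's extent, as in torch.Tensor.expand.
            operands.append(str(in_dim))
        else:
            # Plain int: a static size, emitted as a constant.
            operands.append(str(size))
    return operands


print(lower_expand_sizes([SymIntRef("%2"), -1, -1], [1, 1, 768]))
# ['%2', '1', '768']
```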

Open questions

  1. Would it make sense to handle both bounded and unbounded dynamism under the same workflow/infra? For unbounded dynamism the source of the dynamic dim needs to be traced, but for bounded dynamism it does not.

Example

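import torch
input = torch.randn(4, 3, 224, 224)  # dim 0 plays the role of the dynamic dim
other = torch.randn(1, 1, 768)       # the tensor being expanded, like arg0_1 above
dynamic_dim = input.shape[0]
dynamic_dim = dynamic_dim * 2
expanded = other.expand([dynamic_dim, -1, -1])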

Suppose input.shape[0] is bounded by 5 (<= 5). With bounded dynamism, knowing the upper bound in each op is enough (here, the expanded dim is at most 10). With unbounded dynamism, the arithmetic on the SymInt needs to be traced and lowered in LTC.
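The difference can be sketched in plain Python: bounded dynamism only needs to propagate an upper bound through the arithmetic, while unbounded dynamism has to keep the expression so the size can be recomputed at run time. `BoundedDim`, `SymbolicDim`, and the symbol name `s0` are hypothetical, and the symbolic side only models multiplication by a constant.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BoundedDim:
    """Bounded dynamism: only the upper bound is tracked."""

    upper: int

    def __mul__(self, k: int) -> "BoundedDim":
        # For a non-negative factor, the product's bound is upper * k.
        return BoundedDim(self.upper * k)


@dataclass(frozen=True)
class SymbolicDim:
    """Unbounded dynamism: keep the expression (here, symbol * scale)."""

    symbol: str
    scale: int = 1

    def __mul__(self, k: int) -> "SymbolicDim":
        return SymbolicDim(self.symbol, self.scale * k)

    def evaluate(self, sizes: dict) -> int:
        # The real size is only known at run time.
        return sizes[self.symbol] * self.scale

    def __str__(self) -> str:
        return f"{self.scale}*{self.symbol}"


bounded = BoundedDim(upper=5) * 2
print(bounded.upper)                  # 10: the bound is all the op needs
unbounded = SymbolicDim("s0") * 2
print(unbounded)                      # 2*s0: must be traced and lowered
print(unbounded.evaluate({"s0": 3}))  # 6
```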

  2. It is unclear whether there is an API to create an unbounded dynamic tensor so that the graph can be traced with it.

Labels: dynamism (Dynamic Shape Features)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions