Skip to content

Commit 94f18c0

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo][index] Speedup index method for constant data structures (#173612)
Saves 0.5 seconds on 15.3 seconds of compile time of a model Pull Request resolved: #173612 Approved by: https://github.com/Lucaskabela ghstack dependencies: #173582
1 parent 5a48148 commit 94f18c0

1 file changed

Lines changed: 20 additions & 6 deletions

File tree

torch/_dynamo/variables/lists.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class that handles its unique behaviors while integrating with Dynamo's
4444
range_iterator,
4545
set_example_value,
4646
)
47-
from .base import ValueMutationNew, VariableTracker
47+
from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker
4848
from .constant import ConstantVariable
4949
from .functions import UserFunctionVariable
5050
from .iter import IteratorVariable
@@ -238,11 +238,25 @@ def call_method(
238238
f"{len(args)} args and {len(kwargs)} kwargs",
239239
)
240240

241-
return tx.inline_user_function_return(
242-
VariableTracker.build(tx, polyfills.index),
243-
[self] + list(args),
244-
kwargs,
245-
)
241+
try:
242+
# Speedup trace times for constant data structures
243+
items = [item.as_python_constant() for item in self.items]
244+
const_args = [arg.as_python_constant() for arg in args]
245+
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
246+
try:
247+
return ConstantVariable.create(
248+
items.index(*const_args, **const_kwargs)
249+
)
250+
except ValueError:
251+
raise_observed_exception(
252+
ValueError, tx, args=[ConstantVariable.create("tuple.index()")]
253+
)
254+
except AsPythonConstantNotImplementedError:
255+
return tx.inline_user_function_return(
256+
VariableTracker.build(tx, polyfills.index),
257+
[self] + list(args),
258+
kwargs,
259+
)
246260
elif name == "count":
247261
if len(args) != 1:
248262
raise_args_mismatch(

0 commit comments

Comments
 (0)