[DTensor] Fix mypy on register_op_strategy#167673
wconstab wants to merge 6 commits into gh/wconstab/455/base
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167673
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure) As of commit 16d4c15 with merge base 9760a63. FLAKY: the following job failed but was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc H-Huang awgu wanchaol fegin fduwjj wz337 d4l3k pragupta msaroufim dcci
```python
def register_op_strategy(
    op, schema_info=None
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    # pyre-fixme[2]: Parameter must be annotated.
```
Actually, wait: this is a decorator, so that would erase the typing info of the function it wraps.
```python
def register_op_strategy(
    op, schema_info=None
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    # pyre-fixme[2]: Parameter must be annotated.
```

Suggested change:

```python
    op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
    schema_info: Optional[RuntimeSchemaInfo] = None,
) -> Callable[[Callable[[OpSchema], StrategyType]], Callable[[OpSchema], StrategyType]]:
```
Wait, does it preserve args and return type? If so
Suggested change:

```python
_OpSchemaT = TypeVar("_OpSchemaT", bound=OpSchema)
_StrategyTypeT = TypeVar("_StrategyTypeT", bound=StrategyType)

def register_op_strategy(
    op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
    schema_info: Optional[RuntimeSchemaInfo] = None,
) -> Callable[
    [Callable[[_OpSchemaT], _StrategyTypeT]],
    Callable[[_OpSchemaT], _StrategyTypeT],
]:
```
This way the callable typing is fully preserved AND type checked. The wrapper would need to be updated too, of course. But I think the current typing in this PR might actually be as specific as one can get, unfortunately.
oh, is this because the original OpSchema passed to the callable could be a subclass of OpSchema, but the wrapping would force the typechecker to treat it specifically as a base OpSchema from the point of wrapping?
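A small self-contained sketch of that behavior (the subclass names are hypothetical, for illustration only): with bound TypeVars, the checker infers the decorated function as `(MyOpSchema) -> MyStrategyType`, whereas concrete `OpSchema`/`StrategyType` annotations would widen it to the base classes from the point of wrapping.

```python
from typing import Callable, TypeVar

class OpSchema: ...
class MyOpSchema(OpSchema): ...          # hypothetical subclass
class StrategyType: ...
class MyStrategyType(StrategyType): ...  # hypothetical subclass

_OpSchemaT = TypeVar("_OpSchemaT", bound=OpSchema)
_StrategyTypeT = TypeVar("_StrategyTypeT", bound=StrategyType)

def register_op_strategy(
    op: str,
) -> Callable[
    [Callable[[_OpSchemaT], _StrategyTypeT]],
    Callable[[_OpSchemaT], _StrategyTypeT],
]:
    def wrapper(
        impl: Callable[[_OpSchemaT], _StrategyTypeT],
    ) -> Callable[[_OpSchemaT], _StrategyTypeT]:
        return impl  # registration side effects omitted in this sketch
    return wrapper

@register_op_strategy("aten::mul.Tensor")  # op name is illustrative only
def my_strategy(schema: MyOpSchema) -> MyStrategyType:
    # The bound TypeVars solve to _OpSchemaT=MyOpSchema and
    # _StrategyTypeT=MyStrategyType, so callers keep the subclass types.
    return MyStrategyType()
```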
@Skylion007 if you don't mind taking another look, I think I understand this now and I updated my PR accordingly. Thanks for your pointers!
```python
# MyStrategyType(StrategyType).
_OpSchemaT = TypeVar("_OpSchemaT", bound=OpSchema)
_StrategyTypeT = TypeVar("_StrategyTypeT", bound=StrategyType)
_ShardingStrategyFuncT = Callable[[_OpSchemaT], _StrategyTypeT]
```
Slight nit: this is a bit of a weird type, and because it references TypeVars it is really a (generic) type alias. You may want to annotate it as `_ShardingStrategyFuncT: TypeAlias = ...` to keep IDEs happy/unconfused.
Starting merge as part of PR stack under #168113
Pull Request resolved: #168113 Approved by: https://github.com/mlazos, https://github.com/zpcore ghstack dependencies: #167673
ghstack-source-id: d9ee620 Pull Request resolved: pytorch/pytorch#167673
Stack from ghstack (oldest at bottom):