Skip to content

Commit b2e0f8d

Browse files
Yang Chen and pytorchmergebot
authored and committed
[mypy] added type annotations to codegen_nodes methods (#119080)
Added correct type annotations to the scheduler's and backends' codegen_nodes methods. Pull Request resolved: #119080. Approved by: https://github.com/eellison
1 parent 88e3466 commit b2e0f8d

4 files changed

Lines changed: 7 additions & 7 deletions

File tree

torch/_inductor/codegen/cpp.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3089,18 +3089,18 @@ def is_memory_copy_scheduler_node(node: SchedulerNode):
30893089
body: ir.LoopBody = node._body
30903090
_legalize_lowp_fp(body)
30913091

3092-
def codegen_nodes(self, nodes):
3092+
def codegen_nodes(self, nodes: List[SchedulerNode]):
30933093
# Legalize BF16 node by adding to_dtype explicitly
30943094
self.legalize_lowp_fp_dtype(nodes)
30953095
self.data_type_propagation(nodes)
30963096

30973097
assert len(nodes) >= 1
30983098
first_node = nodes[0]
30993099
vec_dtype = (
3100-
first_node._lowp_fp_type
3100+
first_node._lowp_fp_type # type: ignore[attr-defined]
31013101
if all(
31023102
hasattr(_node, "_lowp_fp_type")
3103-
and _node._lowp_fp_type == first_node._lowp_fp_type
3103+
and _node._lowp_fp_type == first_node._lowp_fp_type # type: ignore[attr-defined]
31043104
for _node in nodes
31053105
)
31063106
else torch.float
@@ -3318,7 +3318,7 @@ def can_fuse_horizontal(self, node1, node2):
33183318
def can_fuse_vertical(self, node1, node2):
33193319
return self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction()
33203320

3321-
def codegen_nodes(self, nodes):
3321+
def codegen_nodes(self, nodes: List[SchedulerNode]):
33223322
"""
33233323
Turn an set of pre-fused nodes into a C++ kernel.
33243324
"""

torch/_inductor/codegen/cuda_combined_scheduling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def codegen_template(
5959
template_node, epilogue_nodes
6060
)
6161

62-
def codegen_nodes(self, nodes: List[BaseSchedulerNode]):
62+
def codegen_nodes(self, nodes: List[SchedulerNode]):
6363
return self._triton_scheduling.codegen_nodes(nodes)
6464

6565
def codegen_sync(self):

torch/_inductor/codegen/triton.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3142,7 +3142,7 @@ def requires_closing_previous_reduction(node, node_schedule):
31423142

31433143
return node_schedule
31443144

3145-
def codegen_nodes(self, nodes):
3145+
def codegen_nodes(self, nodes: List[scheduler.SchedulerNode]):
31463146
"""
31473147
Given a set of pre-fused nodes, generate a Triton kernel.
31483148
"""

torch/_inductor/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2343,7 +2343,7 @@ def codegen_template(
23432343
"""
23442344
raise NotImplementedError()
23452345

2346-
def codegen_nodes(self, nodes: List[BaseSchedulerNode]):
2346+
def codegen_nodes(self, nodes: List[SchedulerNode]):
23472347
"""
23482348
Generate a kernel given a list of pre-fused nodes.
23492349
"""

0 commit comments

Comments
 (0)