2 changes: 2 additions & 0 deletions test/dynamo/test_logging.py
@@ -906,6 +906,8 @@ def bar():
"aot_graphs_effects",
"pre_grad_graphs",
"post_grad_graphs",
"ir_pre_fusion",
"ir_post_fusion",
"compiled_autograd",
"compiled_autograd_verbose",
"recompiles",
29 changes: 23 additions & 6 deletions test/inductor/test_debug_trace.py
@@ -12,6 +12,7 @@
from torch._inductor import config, test_operators
from torch._inductor.utils import fresh_inductor_cache
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.testing._internal.logging_utils import multiple_logs_to_string


try:
@@ -38,6 +39,10 @@ def fn(a, b):
a = test_operators.realize(a + 1) + 2
return torch.matmul(a, b)

(pre_fusion_stream, post_fusion_stream), ctx = multiple_logs_to_string(
"torch._inductor.debug", "ir_pre_fusion", "ir_post_fusion"
)

# TODO(aakhundov): make this work with fresh_inductor_cache
# instead of force_disable_caches. currently, with the latter
# enabled, we get `inductor [('fxgraph_cache_hit', 1)]` in
@@ -50,21 +55,30 @@ def fn(a, b):
):
with self.assertLogs(
logging.getLogger("torch._inductor.debug"), level=logging.WARNING
) as cm:
) as cm, ctx():
fn(torch.randn(16, 16), torch.randn(16, 16))

self.assertEqual(len(cm.output), 1)
m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0])
self.assertTrue(m)
m = None
for log_line in cm.output:
# Search for the warning message containing the debug trace file path.
m = re.match(r"WARNING.* debug trace: (.*)", log_line)
if m:
break
self.assertTrue(m, "debug trace file path not found in logs")
# For type checking, ensure it's not None.
assert m is not None
filename = Path(m.group(1))
self.assertTrue(filename.is_dir())
self.assertGreater(filesize(filename / "fx_graph_readable.py"), 512)
self.assertGreater(filesize(filename / "fx_graph_runnable.py"), 512)
self.assertGreater(filesize(filename / "fx_graph_transformed.py"), 512)
self.assertGreater(filesize(filename / "output_code.py"), 1024)

pre_fusion_logs = pre_fusion_stream.getvalue().strip()
self.assertExpectedInline(
open(filename / "ir_pre_fusion.txt").read().rstrip(),
pre_fusion_logs,
"""\
BEFORE FUSION
op0: SchedulerNode(ComputedBuffer)
op0.writes = [MemoryDep('buf0', c0, {c0: 256})]
op0.unmet_dependencies = []
@@ -130,9 +144,12 @@ def body(self, ops):
]
op2.node.kernel = extern_kernels.mm""",
)

post_fusion_logs = post_fusion_stream.getvalue().strip()
self.assertExpectedInline(
open(filename / "ir_post_fusion.txt").read().rstrip(),
post_fusion_logs,
"""\
AFTER FUSION
op0_op1: FusedSchedulerNode(SchedulerNode,SchedulerNode)
op0_op1.writes = [MemoryDep('buf0', c0, {c0: 256}), MemoryDep('buf1', c0, {c0: 256})]
op0_op1.unmet_dependencies = []
25 changes: 12 additions & 13 deletions torch/_inductor/debug.py
@@ -3,6 +3,7 @@
import copy
import dataclasses
import functools
import io
import itertools
import json
import logging
@@ -42,6 +43,8 @@

log = logging.getLogger(__name__)

ir_pre_fusion_log = torch._logging.getArtifactLogger(__name__, "ir_pre_fusion")
ir_post_fusion_log = torch._logging.getArtifactLogger(__name__, "ir_post_fusion")
SchedulerNodeList = list[Any]
BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]
@@ -522,21 +525,17 @@ def fx_graph_transformed(
fd.write(gm.print_readable(print_output=False))

def ir_pre_fusion(self, nodes: SchedulerNodeList) -> None:
self._write_ir("ir_pre_fusion.txt", nodes)
ir_pre_fusion_log.debug("BEFORE FUSION\n%s", self._write_ir(nodes))

def ir_post_fusion(self, nodes: SchedulerNodeList) -> None:
self._write_ir("ir_post_fusion.txt", nodes)

def _write_ir(
self,
filename: str,
nodes: SchedulerNodeList,
) -> None:
with self.fopen(filename) as fd:
log.info("Writing debug ir to %s", fd.name)
for node in nodes:
fd.write(node.debug_str())
fd.write("\n\n\n")
ir_post_fusion_log.debug("AFTER FUSION\n%s", self._write_ir(nodes))

def _write_ir(self, nodes: SchedulerNodeList) -> str:
buf = io.StringIO()
for node in nodes:
buf.write(node.debug_str())
buf.write("\n\n\n")
return buf.getvalue()

def graph_diagram(self, nodes: SchedulerNodeList) -> None:
draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))
10 changes: 10 additions & 0 deletions torch/_logging/_internal.py
@@ -235,6 +235,8 @@ def set_logs(
perf_hints: bool = False,
pre_grad_graphs: bool = False,
post_grad_graphs: bool = False,
ir_pre_fusion: bool = False,
ir_post_fusion: bool = False,
onnx_diagnostics: bool = False,
fusion: bool = False,
overlap: bool = False,
@@ -396,6 +398,12 @@ def set_logs(
post_grad_graphs (:class:`bool`):
Whether to emit the graphs generated by after post grad passes. Default: ``False``

ir_pre_fusion (:class:`bool`):
Whether to emit the IR before inductor fusion passes. Default: ``False``

ir_post_fusion (:class:`bool`):
Whether to emit the IR after inductor fusion passes. Default: ``False``

onnx_diagnostics (:class:`bool`):
Whether to emit the ONNX exporter diagnostics in logging. Default: ``False``

@@ -521,6 +529,8 @@ def _set_logs(**kwargs):
perf_hints=perf_hints,
pre_grad_graphs=pre_grad_graphs,
post_grad_graphs=post_grad_graphs,
ir_pre_fusion=ir_pre_fusion,
ir_post_fusion=ir_post_fusion,
onnx=onnx,
onnx_diagnostics=onnx_diagnostics,
fusion=fusion,
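
With the new keyword arguments in place, the artifacts can be enabled programmatically. A minimal sketch (the toy function and tensor shapes are illustrative, not part of this PR):

import torch

# Enable the new Inductor IR artifacts added by this PR.
torch._logging.set_logs(ir_pre_fusion=True, ir_post_fusion=True)

@torch.compile
def f(x):
    return (x + 1).relu()

# Compiling through Inductor emits the IR before and after fusion.
f(torch.randn(16, 16))

The same artifact names should also be reachable through the TORCH_LOGS environment variable (TORCH_LOGS="ir_pre_fusion,ir_post_fusion"), like the other registered artifacts.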
8 changes: 8 additions & 0 deletions torch/_logging/_registrations.py
@@ -106,6 +106,14 @@
"post_grad_graphs",
"Prints the FX graph generated by post grad passes. Useful to understand what's being given to Inductor after post grad passes",
)
register_artifact(
"ir_pre_fusion",
"Prints the IR before inductor fusion passes.",
)
register_artifact(
"ir_post_fusion",
"Prints the IR after inductor fusion passes.",
)
register_artifact(
"compiled_autograd",
"Prints various logs in compiled_autograd, including but not limited to the graphs. Useful for debugging compiled_autograd.",
29 changes: 29 additions & 0 deletions torch/testing/_internal/logging_utils.py
@@ -6,6 +6,7 @@
import contextlib
import torch._logging
import torch._logging._internal
from typing import Callable, ContextManager, List, Tuple
from torch._dynamo.utils import LazyString
from torch._inductor import config as inductor_config
import logging
@@ -211,3 +212,31 @@ def ctx_manager():
return exit_stack

return log_stream, ctx_manager


def multiple_logs_to_string(module: str, *log_options: str) -> Tuple[List[io.StringIO], Callable[[], ContextManager[None]]]:
"""Example:
multiple_logs_to_string("torch._inductor.compile_fx", "pre_grad_graphs", "post_grad_graphs")
captures the output of TORCH_LOGS="pre_grad_graphs, post_grad_graphs" from the
torch._inductor.compile_fx module.
"""
log_streams = [io.StringIO() for _ in range(len(log_options))]
handlers = [logging.StreamHandler(stream=log_stream) for log_stream in log_streams]

@contextlib.contextmanager
def tmp_redirect_logs():
loggers = [torch._logging.getArtifactLogger(module, option) for option in log_options]
try:
for logger, handler in zip(loggers, handlers):
logger.addHandler(handler)
yield
finally:
for logger, handler in zip(loggers, handlers):
logger.removeHandler(handler)

def ctx_manager() -> ContextManager[None]:
exit_stack = log_settings(", ".join(log_options))
exit_stack.enter_context(tmp_redirect_logs())
return exit_stack # type: ignore[return-value]

return log_streams, ctx_manager
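
For reference, a minimal usage sketch of the new helper, mirroring test_debug_trace.py above; the compiled function here is illustrative:

import torch
from torch.testing._internal.logging_utils import multiple_logs_to_string

(pre_stream, post_stream), ctx = multiple_logs_to_string(
    "torch._inductor.debug", "ir_pre_fusion", "ir_post_fusion"
)

compiled_fn = torch.compile(lambda a, b: torch.matmul(a + 1, b))

# ctx() enables the artifacts and attaches a StreamHandler per artifact logger.
with ctx():
    compiled_fn(torch.randn(16, 16), torch.randn(16, 16))

print(pre_stream.getvalue())   # begins with "BEFORE FUSION"
print(post_stream.getvalue())  # begins with "AFTER FUSION"

As the TODO in test_debug_trace.py notes, a compile-cache hit can skip compilation entirely, so caches may need to be disabled for these logs to appear.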