
Commit 97891b1

oulgen authored and pytorchmergebot committed
[Dynamo] Trace autograd.function in dynamo when inputs require grad (#116358)
For training graphs (when inputs require grad), we previously speculated the forward and backward graphs to determine whether there were any graph breaks, side effects, and so on, but we never actually used these speculated graphs. We would just insert a call_function node into the graph and later rely on autograd's tracing. This approach does not work for more general graphs, such as graphs that include user-defined Triton kernels, because autograd is not able to do the higher-order function conversion. This PR speculates the forward and backward functions and emits them in a HOF (higher-order function) that later gets used via the templating mechanism.

While working on this PR, I exposed some bugs in the existing tracing: trampoline functions were losing source information, which resulted in incorrect graphs being produced. I have fixed these source-information bugs and removed the trampolines.

Pull Request resolved: #116358
Approved by: https://github.com/jansel
1 parent c5d9173 commit 97891b1
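
In user-facing terms, the change lets a training-mode custom autograd.Function be captured end to end by Dynamo instead of being re-traced by autograd at backward time. A rough sketch of the pattern involved (illustrative only; Scale and f are made-up names, and the PR's actual coverage is the Triton-kernel tests below):

    import torch

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x * 2

        @staticmethod
        def backward(ctx, grad_output):
            # With this PR, the backward body is speculated by Dynamo too,
            # not merely recorded as an opaque call_function node.
            (x,) = ctx.saved_tensors
            return grad_output * 2

    @torch.compile(fullgraph=True)
    def f(x):
        return Scale.apply(x)

    x = torch.randn(4, requires_grad=True)  # requires_grad makes this a training graph
    f(x).sum().backward()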

8 files changed: 336 additions & 203 deletions

test/dynamo/test_autograd_function.py

Lines changed: 75 additions & 1 deletion
@@ -8,6 +8,12 @@
 import torch._dynamo.test_case
 import torch._dynamo.testing
 import torch._dynamo.utils
+from torch.testing._internal.common_utils import skipIfRocm
+from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda
+
+if HAS_CUDA:
+    import triton
+    from torch.testing._internal.triton_utils import add_kernel


 class CustomFunc1(torch.autograd.Function):
@@ -275,7 +281,7 @@ def test_stride_in_bwd(self):
         x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
         with self.assertRaisesRegex(
             torch._dynamo.exc.Unsupported,
-            "Illegal getattr invocation stride in strict mod",
+            ".*HigherOrderOperator body's output must consist of tensors only",
         ):
             opt_model(x)
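
The expected message changes because the speculated forward/backward now run as HigherOrderOperator bodies, which may only output tensors; the old strict-mode getattr error no longer fires. As a hypothetical sketch of the pattern this test guards against (the real model definition sits above the quoted hunk; StrideFn and opt_model here are reconstructions), compiling a function whose backward calls .stride() still raises torch._dynamo.exc.Unsupported, now with the HigherOrderOperator message:

    class StrideFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x.clone()

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            return grad_output * x.stride(0)  # .stride() inside the speculated bwd

    opt_model = torch.compile(StrideFn.apply, fullgraph=True)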

@@ -836,6 +842,74 @@ def foo(x):
         foo(torch.randn(2, requires_grad=True))
         self.assertEqual(cnts.frame_count, 1)

+    @requires_cuda()
+    @skipIfRocm
+    def test_triton_kernel_basic(self):
+        class Add(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x, y):
+                ctx.save_for_backward(x, y)
+                output = torch.zeros_like(x)
+                n_elements = output.numel()
+                grid = lambda meta: (  # noqa: E731
+                    triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
+                )
+                add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
+                return output
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                x, y = ctx.saved_tensors
+                return x * grad_output, y * grad_output
+
+        @torch.compile(fullgraph=True, backend="inductor")
+        def f(x, y):
+            z = Add.apply(x, y)
+            return z
+
+        x = torch.randn(10, device="cuda", requires_grad=True)
+        y = torch.randn(10, device="cuda", requires_grad=True)
+        z = f(x, y)
+        loss = z.sum()
+        loss.backward()
+        self.assertEqual(x + y, z)
+
+    @requires_cuda()
+    @skipIfRocm
+    def test_triton_kernel_multiple_out(self):
+        class Add(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x, y):
+                ctx.save_for_backward(x, y)
+                ctx.t1 = x
+                ctx.t2 = y
+                output = torch.zeros_like(x)
+                n_elements = output.numel()
+                grid = lambda meta: (  # noqa: E731
+                    triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
+                )
+                add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
+                return output, x
+
+            @staticmethod
+            def backward(ctx, grad_output, old_x):
+                x, y = ctx.saved_tensors
+                x1 = ctx.t1
+                y1 = ctx.t2
+                return old_x * x * x1 * grad_output, y * y1 * grad_output
+
+        @torch.compile(fullgraph=True, backend="inductor")
+        def f(x, y):
+            z = Add.apply(x, y)
+            return z
+
+        x = torch.randn(10, device="cuda", requires_grad=True)
+        y = torch.randn(10, device="cuda", requires_grad=True)
+        z, _ = f(x, y)
+        loss = z.sum()
+        loss.backward()
+        self.assertEqual(x + y, z)
+

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

torch/_dynamo/symbolic_convert.py

Lines changed: 1 addition & 7 deletions
@@ -2251,18 +2251,12 @@ def check_inlineable(func):

         result = skipfiles.check_verbose(func, is_inlined_call=True)
         if result.skipped:
-            from torch._dynamo.variables.misc import (
-                produce_trampoline_autograd_apply,
-                produce_trampoline_autograd_bwd,
-                produce_trampoline_autograd_fwd,
-            )
+            from torch._dynamo.variables.misc import produce_trampoline_autograd_apply

             # _origin marks this as coming from an internal dynamo known function that is safe to
             # trace through.
             if hasattr(func.fn, "_origin") and func.fn._origin in [
-                produce_trampoline_autograd_fwd,
                 produce_trampoline_autograd_apply,
-                produce_trampoline_autograd_bwd,
             ]:
                 # Known sound
                 return skipfiles.SkipResult(False, "allowlist in dynamo known function")
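
Only the apply trampoline remains on the allowlist: with the forward and backward now speculated and emitted directly, the fwd/bwd trampolines are dead code. For context, a trampoline here is a thin internal wrapper tagged with _origin; a sketch of the shape this check matches (an assumption about the internals, not a verbatim copy of torch/_dynamo/variables/misc.py):

    def produce_trampoline_autograd_apply(fn_cls):
        # Wrap fn_cls.apply and tag the wrapper with `_origin` so that
        # check_inlineable() can recognize it as a Dynamo-internal function.
        def trampoline_autograd_apply(*args, **kwargs):
            return fn_cls.apply(*args, **kwargs)

        trampoline_autograd_apply._origin = produce_trampoline_autograd_apply
        return trampoline_autograd_apply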

torch/_dynamo/variables/builder.py

Lines changed: 3 additions & 1 deletion
@@ -561,7 +561,9 @@ def index_source(key):
             # handle aliased autograd function `apply` calls
             self.install_guards(GuardBuilder.FUNCTION_MATCH)
             return GetAttrVariable(
-                AutogradFunctionVariable(value.__self__, source=self.source),
+                AutogradFunctionVariable(
+                    value.__self__, source=AttrSource(self.source, member="__self__")
+                ),
                 "apply",
             )
         elif np and isinstance(value, np.number):
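
The source fix matters when a user aliases apply: value here is the bound method, so the wrapped autograd.Function class is value.__self__, and its guard source must be spelled AttrSource(self.source, "__self__") rather than reusing the alias's own source. A small sketch of the aliasing pattern this branch handles (MyFn and f are illustrative names):

    import torch

    class MyFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x + 1

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    aliased_apply = MyFn.apply  # bound method; aliased_apply.__self__ is MyFn

    @torch.compile
    def f(x):
        return aliased_apply(x)  # the builder sees the alias, not MyFn directly

    f(torch.randn(3, requires_grad=True))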
