JIT: DifferentiableGraph/Requires grad handled badly by ProfilingExecutor and fuser fallbacks #49299

@t-vi

Description

🐛 Bug

To me, the interplay between requiring gradients and the profiling executor/fuser fallbacks seems not ideal:

  1. With fuser1 (TensorExpr) and fuser2 (CudaFuser), if I run a few grad-requiring tensors through my model and then run non-grad-requiring ones, execution still goes through the DifferentiableSubgraph and gives me a grad-requiring output even though none of the inputs requires gradient. (It also runs a kernel producing intermediates we don't need, so there is a perf loss as well.)
  2. With fuser2 (CudaFuser), the CudaFusionGuard doesn't check requires_grad, so if I run a few non-grad-requiring tensors first and then feed a grad-requiring input, it will drop the gradient. (This differs from fuser1, where the TypeCheck includes requires_grad and drops into the fallback.)

To me it seems that if we want to pull off the fallback mechanism, we should properly guard the DifferentiableGraph on which inputs require grad, and guard the fusion guards on no input requiring grad (there may also be cases where which gradients are required varies between runs).
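To sketch what such guarding could look like in plain Python (ProfiledSpec and guard_matches are hypothetical illustrations of the idea, not the actual C++ guard code):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ProfiledSpec:
    # Properties recorded during profiling runs (hypothetical illustration).
    shape: tuple
    dtype: str
    requires_grad: bool


def guard_matches(spec: ProfiledSpec, shape, dtype, requires_grad) -> bool:
    # A fallback guard should compare requires_grad in addition to shape
    # and dtype; per this issue, CudaFusionGuard omits that comparison,
    # so a grad-requiring input can silently take the no-grad fused path.
    return (tuple(shape) == spec.shape
            and dtype == spec.dtype
            and requires_grad == spec.requires_grad)


# Suppose the graph was profiled with non-grad-requiring inputs:
spec = ProfiledSpec(shape=(8, 100, 1000), dtype="float32", requires_grad=False)

# A later grad-requiring input must then fail the guard and take the fallback.
assert not guard_matches(spec, (8, 100, 1000), "float32", requires_grad=True)
assert guard_matches(spec, (8, 100, 1000), "float32", requires_grad=False)
```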

To Reproduce

import torch

for fuser in ["fuser1", "fuser2"]:
    for rq in [True, False]:
        c = torch.jit.fuser(fuser)
        c.__enter__()  # enter the fuser context and keep it active for this iteration

        def ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2):
            xi = torch.max(x1, x2)                                  # Intersection left
            yi = torch.max(y1, y2)                                  # Intersection top
            wi = torch.clamp(torch.min(x1+w1, x2+w2) - xi, min=0.)  # Intersection width
            hi = torch.clamp(torch.min(y1+h1, y2+h2) - yi, min=0.)  # Intersection height
            area_i = wi * hi                                        # Area Intersection
            area_u = w1 * h1 + w2 * h2 - wi * hi                    # Area Union
            return area_i / torch.clamp(area_u, min=1e-5)           # Intersection over Union

        ratio_iou_scripted = torch.jit.script(ratio_iou)

        x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda', requires_grad=not rq).exp()

        for i in range(10):
            ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2)
        #print(ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2))

        x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda', requires_grad=rq).exp()
        print(fuser, x1.requires_grad, ratio_iou_scripted(x1, y1, w1, h1, x2, y2, w2, h2).requires_grad)

Expected behavior

The output's requires_grad should match that of the inputs.
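Spelled out as a check (a minimal sketch; check_invariant is just an illustration, assuming grad mode is enabled and the op is differentiable):

```python
def check_invariant(input_requires_grads, output_requires_grad):
    # Eager semantics: the output of a differentiable op requires grad
    # iff at least one input requires grad.
    return output_requires_grad == any(input_requires_grads)


# Failure mode 1 (run through the DifferentiableSubgraph with no-grad inputs):
assert not check_invariant([False] * 8, True)
# Failure mode 2 (fuser2 dropping the gradient of a grad-requiring input):
assert not check_invariant([True] * 8, False)
```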

Environment

PyTorch master.

Additional context

This used to be handled by the ArgSpec mechanism.

cc @gmagogsfm


Labels

oncall: jit
