Add nvfuser support for prims.copy_to#84545

Draft
IvanYashchuk wants to merge 30 commits into pytorch:main from IvanYashchuk:nvfuser-copy-to

Conversation

@IvanYashchuk
Collaborator

@IvanYashchuk IvanYashchuk commented Sep 5, 2022

I use nvFuser's aliasOutputToInput here and since it implicitly adds outputs to the fusion, I need to drop those within Python.

Now we can lower the batch_norm implementation from torch._decomp to nvprims (see test_batch_norm_forward_nvprims).

cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel @kevinstephano @jjsjann123 @ezyang @mruberry @ngimel @lezcano @fdrocha @peterbell10

@facebook-github-bot
Contributor

facebook-github-bot commented Sep 5, 2022

✅ No Failures (0 Pending)

As of commit 8fe12d5 (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Sep 5, 2022
Collaborator

@jjsjann123 jjsjann123 left a comment


I think we need to discuss more on what to expect from copy_to in upstream.

We won't be able to generically support copy_to in its wildest form in an fx graph. IIRC, the last time we discussed this with upstream about their functionalization pass, we agreed that copy_to would be applied sparingly and carefully.

We really only care about the case where we are doing a spot update of the running stats for BN, and I think that's what we should focus on. Maybe not necessarily in this PR, but we definitely need more checks and to be more explicit about when we would take copy_to into an nvfuser graph.


//! Specialized Record Functor for recording the removal of outputs.

template <class OutputType>
Collaborator

Is the template class here necessary? We are only using removeOutputRecord with NvfTensorView in this PR.

Collaborator Author

Right, it's not needed.

Collaborator Author

Removing outputs is not needed in this PR anymore; I removed this code.


def _is_func_unsupported_nvfuser(torch_function_mode, func, args, kwargs):
-    with torch.overrides.enable_torch_function_mode(
+    with torch.no_grad(), torch.overrides.enable_torch_function_mode(
Collaborator

Is no_grad here needed only because we added copy_to and that somehow messed up gradients? Or is this just a patch for an existing bug?

Collaborator Author

It's a patch for an existing bug for which I need to file an issue.

Collaborator Author

The problem was that the ATen function updates the running mean and var in-place without the autograd graph connected, and it always expects "requires_grad" to be False for these arguments; the decomposition was not doing that. Fixed in ca2c176.
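To illustrate the constraint described above, here is a minimal sketch (a hypothetical helper, not the PR's actual decomposition) of how batch norm running stats are updated in-place with autograd disabled, matching ATen's expectation that they have requires_grad=False:

```python
import torch

# Minimal sketch, not the PR's actual decomposition: batch norm running
# stats are updated in-place under no_grad, so the update is not tracked
# by autograd and running_mean/running_var keep requires_grad=False.
def update_running_stats(running_mean, running_var, batch_mean, batch_var,
                         momentum=0.1):
    with torch.no_grad():
        running_mean.copy_((1 - momentum) * running_mean + momentum * batch_mean)
        running_var.copy_((1 - momentum) * running_var + momentum * batch_var)
    return running_mean, running_var
```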

dest = torch::jit::fuser::cuda::set(source);
// aliasOutputToInput implicitly adds the output to the fusion
// adding it in this path simplifies the logic upstream
fd.addOutput(source);
Collaborator

aliasOutputToInput adds dest to the outputs, while in the false branch we are adding source to the outputs. That doesn't sound consistent with the comment above.

NVM, I was dumb... it makes sense, we are adding source to the outputs in both cases.

auto source = fd.getFusionState(args.at(1))->as<NvfTensorView>();

if (dest->isFusionInput()) {
fd.fusionPtr()->aliasOutputToInput(source, dest);
Collaborator

I think this is the place where we actually need to perform a copy:

auto tmp = xxxx::set(source); 
fd.fusionPtr()->aliasOutputToInput(tmp, dest);

I saw something on the Python side doing this, but I'm pretty uncomfortable having this logic separated.

Collaborator Author

Yes, I also thought about this and tried that. Something was wrong; I don't remember exactly what, so I should try again.

This approach is also better because we wouldn't need to conditionally remove tensors from the fusion outputs and add them back later in the correct order, since source is now completely disconnected.

Collaborator Author

auto tmp = torch::jit::fuser::cuda::set(source); is now used. My initial mistake was to use tmp = set(source) only for the dest->isFusionInput() path; it becomes difficult to drop this output in the "else" branch.

# If source is a fusion input, we need to place an operation
# before the copy_to so that the copy is actually performed
if _is_node_in_input(gm, source_node):
source = fd.ops.set(source)
Collaborator

So this is the Python side where we are making the copy? The condition here looks suspicious to me: shouldn't we be checking whether the destination value is an input, instead of the source?

Collaborator Author

"set" is not necessary for the in-place update of destination; any operation on the source tensor is enough, see for example func2 in test_copy_to.

This code path is a workaround for the case when no operations were defined to produce source: for the in-place update to be realized we need to insert some operation, and set seems the most natural choice. The case when there are no producers of source is when source is an input to the fusion/graph.
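The rule described above can be sketched with a toy graph model (ToyFusion and its methods are illustrative stand-ins, not the real nvFuser FusionDefinition API): a copy only materializes if source has a producer op, so a raw fusion input gets an identity-like set op inserted first.

```python
# Toy sketch of the workaround above; these names are hypothetical
# stand-ins, not the real nvFuser FusionDefinition API.
class ToyFusion:
    def __init__(self):
        self.inputs = []   # raw fusion inputs (tensors with no producer op)
        self.ops = []      # recorded ops

    def define_input(self, name):
        self.inputs.append(name)
        return name

    def op_set(self, src):
        # identity-like op: gives `src` a producer so the copy is realized
        out = f"set({src})"
        self.ops.append(out)
        return out

def prepare_copy_source(fd, source):
    # Any producing op would do; `set` is the most natural choice.
    if source in fd.inputs:
        return fd.op_set(source)
    return source
```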

Collaborator Author

I removed this code, there's no special handling on the Python side anymore.

# (it was implicitly added by the copy_to op)
# it will be added back later using correct expected order
if _is_node_in_output(gm, source_node) and not _is_node_in_input(
gm, source_node
Collaborator

Same here, shouldn't the second source_node be destination node instead?

Collaborator Author

When _is_node_in_input(gm, source_node) == True we hit the other path where we create a new temporary source; since it's temporary, there's no way it could be marked as an output in the graph.
Here we only remove from the nvFusion outputs the source that was marked as an output inside aliasOutputToInput and that is then going to be marked as an output here:

out = FusionInterpreter(gm).run(*nv_args)
flat_out, unflatten_spec = tree_flatten(out)
for o in flat_out:
fd.add_output(o)

Collaborator Author

I removed this code, there's no special handling on the Python side anymore.

@pytorch-bot

pytorch-bot bot commented Sep 12, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/84545

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit dc529fb:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


 return tree_unflatten(
-    fusion.execute(concrete_fusion_inputs),  # type: ignore[has-type]
+    fusion.execute(concrete_fusion_inputs)[drop_output_count:],  # type: ignore[has-type]
Collaborator

Are the implicit extra outputs guaranteed to be at the beginning?

Collaborator Author

Yes, the order of adding outputs is respected and the real outputs we need to return are added last here:

for o in flat_out:
fd.add_output(o)
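In plain Python terms (a stand-in illustration, not the real executor code), this ordering invariant means the implicit outputs occupy the leading positions, so a slice recovers exactly the real outputs:

```python
# Stand-in illustration of the output-ordering invariant: outputs added
# implicitly by aliasOutputToInput come first, the real outputs are
# appended last, so dropping the first `drop_output_count` entries of
# the execution result leaves only the real outputs, in order.
def real_outputs(all_outputs, drop_output_count):
    return all_outputs[drop_output_count:]
```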

@IvanYashchuk
Collaborator Author

There would be a conflict with #84626: the tests implemented here would stop working.

@jjsjann123
Collaborator

There would be a conflict with #84626: the tests implemented here would stop working.

#84626 is merged now. I'll get a patch for this PR.

@jjsjann123
Collaborator

I'm seeing the failing test. Working on a quick patch.

b_sin = b.sin()
a1 = a.copy_(b_sin)
a1_sin = a1.sin()
a2 = a1.copy_(a1_sin)
Collaborator

This looks scary.

Currently codegen has no handling for this kind of race condition. Having a single buffer used for both reads and writes in different operations is not something we support. We should remove this test.

Collaborator

Removing the test is easy, but what happens here? Does codegen generate a kernel for this fusion, or does it fall back?

Collaborator Author

Yes, the codegen generates kernels for this fusion. There are no fallbacks at the nvFuser level; it's either more segments or an error. Since we're using executor="strictly_nvfuser" there's no fallback at the Python level either.

Codegen generates two segments in this case:
Segmented_Fusion Dump: -- fusion segments:
Segmented_Fusion{
groups:
g{0, 1}

g{2, 3}

edges:

group details:
g{(pointwise)
inputs:
T0_g[ iS0{i0}, iS1{i1} ] float
T1_g[ iS2{i3}, iS3{i4} ] float
outputs:
T2_g[ iS4{i3}, iS5{i4} ] float
T3_g[ iS6{i3}, iS7{i4} ] float


T2_g[ iS4{i3}, iS5{i4} ]
   = sinf(T1_g[ iS2{i3}, iS3{i4} ]);
T3_g[ iS6{i3}, iS7{i4} ]
   = T2_g[ iS4{i3}, iS5{i4} ];
}

g{(pointwise)
inputs:
T0_g[ iS0{i0}, iS1{i1} ] float
outputs:
T4_g[ iS8{i0}, iS9{i1} ] float
T5_g[ iS10{i0}, iS11{i1} ] float


T4_g[ iS8{i0}, iS9{i1} ]
   = sinf(T0_g[ iS0{i0}, iS1{i1} ]);
T5_g[ iS10{i0}, iS11{i1} ]
   = T4_g[ iS8{i0}, iS9{i1} ];
}

} //Segmented_Fusion
This corresponds to two kernels:
======= Codegen output for kernel: kernel1 =======

__global__ void kernel1(Tensor<float, 2> T1, Tensor<float, 2> T0, Tensor<float, 2> T7, Tensor<float, 2> T3) {
  int i71;
  i71 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
  if ((i71 < (T1.size[0] * T1.size[1]))) {
    float T6[1];
    T6[0] = 0;
    T6[0]
       = T1[i71];
    float T2[1];
    T2[0]
       = sinf(T6[0]);
    float T8[1];
    T8[0]
       = T2[0];
    T7[i71]
       = T8[0];
    float T9[1];
    T9[0]
       = T2[0];
    T3[i71]
       = T9[0];
  }
}

======================================


======= Codegen output for kernel: kernel2 =======

__global__ void kernel2(Tensor<float, 2> T0, Tensor<float, 2> T7, Tensor<float, 2> T5) {
  int i69;
  i69 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
  if ((i69 < (T0.size[0] * T0.size[1]))) {
    float T6[1];
    T6[0] = 0;
    T6[0]
       = T0[i69];
    float T4[1];
    T4[0]
       = sinf(T6[0]);
    float T8[1];
    T8[0]
       = T4[0];
    T7[i69]
       = T8[0];
    float T9[1];
    T9[0]
       = T4[0];
    T5[i69]
       = T9[0];
  }
}

======================================

self.assertEqual(out[0], a)

a = torch.empty(3, 3, device='cuda')
self.assertEqual(out[0], func(a, b)[0])
Collaborator

I also feel the test is too relaxed and sends a misleading message about what codegen aliasing supports.

We should be explicit that alias support right now is VERY limited. In the test example here, we should check that the fusion maintains consistent behavior: we should check matching results on all outputs, as well as identical aliases among outputs and inputs.

Collaborator

Didn't nvfuser give up on aliasing inputs and outputs? (We had this discussion in transpose PR)

Collaborator Author

Changed:

-            self.assertEqual(out[0], func(a, b)[0])
+            self.assertEqual(out, func(a, b))

And added a comparison of storage (it should be the same).
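A minimal sketch of such a storage comparison (the assumed shape of the check, not the test's exact code): after an in-place copy_, the returned tensor should alias the destination's memory.

```python
import torch

# Sketch of the storage check mentioned above (assumed form, not the
# test's exact code): after an in-place copy_, the result should alias
# the destination tensor's underlying allocation.
def shares_storage(a, b):
    # data_ptr() is the start address of the tensor's data in memory
    return a.data_ptr() == b.data_ptr()
```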

Collaborator

Didn't nvfuser give up on aliasing inputs and outputs? (We had this discussion in transpose PR)

Yeah, we sorta did.
Our assumption is that the functionalization pass will resolve this for us and we do not need to "return" outputs with the correct alias.
Since the test here doesn't use the functionalization pass, I'm trying to keep the tests cleaner.

for (const auto out_i : c10::irange(kernel->outputs().size())) {
// TODO: FIX this short-cut where we trivially forward inputs to outputs
if (kernel->outputs()[out_i]->isFusionInput()) {
TORCH_INTERNAL_ASSERT(false, "trivial input forwarding NOT IMPLEMENTED");
Collaborator Author

Changes added via IvanYashchuk#3

pytorchmergebot pushed a commit that referenced this pull request Oct 3, 2022
This PR adds nvFuser's implementation for batch_norm as there's no reference yet (#81191) and no in-place copy support (#84545).

Pull Request resolved: #85562
Approved by: https://github.com/kevinstephano, https://github.com/ngimel
@facebook-github-bot
Contributor

/easycla

As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details.

This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.

@linux-foundation-easycla

linux-foundation-easycla bot commented Oct 4, 2022

CLA Signed

The committers listed above are authorized under a signed CLA.

jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Oct 29, 2022
This PR adds nvFuser's implementation for batch_norm as there's no reference yet (pytorch/pytorch#81191) and no in-place copy support (pytorch/pytorch#84545).

Pull Request resolved: pytorch/pytorch#85562
Approved by: https://github.com/kevinstephano, https://github.com/ngimel
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Nov 10, 2022
This PR adds nvFuser's implementation for batch_norm as there's no reference yet (pytorch/pytorch#81191) and no in-place copy support (pytorch/pytorch#84545).

Pull Request resolved: pytorch/pytorch#85562
Approved by: https://github.com/kevinstephano, https://github.com/ngimel
@github-actions
Contributor

github-actions bot commented Dec 3, 2022

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Dec 3, 2022
@github-actions github-actions bot closed this Jan 2, 2023
@IvanYashchuk IvanYashchuk reopened this Jan 10, 2023
@IvanYashchuk IvanYashchuk marked this pull request as draft January 10, 2023 07:51

Labels

cla signed module: nvfuser module: primTorch no-stale oncall: jit Add this issue/PR to JIT oncall triage queue open source release notes: jit release notes category triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

6 participants