Port tensor variants of normal to structured kernel#69628
nkaretnikov wants to merge 1 commit into gh/nkaretnikov/1/base
Conversation
- Refactor tensor variants to use structured in native_functions.yaml
- Other variants don't fit this model well, so not doing those for now
- Remove some normal templates and RNG tests that relied on them

See #69386. [ghstack-poisoned]
As of commit bac7d8c, Dr. CI reported 11 new CI failures recognized by patterns; these do not appear to be due to upstream breakages.
lezcano left a comment:
Just left two small comments, but overall this looks good to me.
    TORCH_META_FUNC2(normal, Tensor_Tensor) (
      Tensor const& mean,
      Tensor const& std,
      c10::optional<Generator> gen
    ) {
Don't we need some checks here to make sure that mean and std have compatible dtypes? Or does this operation work with arbitrary dtypes?
Oh, I see that the checks are currently done within the normal_out_impl function.
Those checks should be moved to the TORCH_META_FUNCs, which is what you do in the next PR. For this PR to stand on its own, perhaps we could merge the next PR into this one and submit both of them as one?
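To make the point concrete, here is a minimal plain-C++ sketch of the kind of dtype check being discussed for the META function. Everything here is illustrative (the `ScalarType` enum and `check_normal_dtypes` name are hypothetical stand-ins, not real ATen code):

```cpp
#include <stdexcept>

// Hypothetical stand-in for ATen's scalar types; this is an
// illustrative sketch, not real ATen code.
enum class ScalarType { Float, Double, Half };

// The kind of check a TORCH_META_FUNC2(normal, Tensor_Tensor) would
// run: verify mean and std have compatible dtypes before any output
// allocation happens.
void check_normal_dtypes(ScalarType mean_dtype, ScalarType std_dtype) {
  if (mean_dtype != std_dtype) {
    throw std::invalid_argument(
        "normal expects mean and std to have the same dtype");
  }
}
```

Running the check in META means it fires for every calling convention (functional, in-place, out=) without being duplicated in each IMPL body.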
The reason the checks are in templates is that those are used by RNG tests, so I'm not sure it's a good idea to remove them if they can be accessed other than via the structured API. Or is it the caller's responsibility to ensure they are calling the right thing (see aten/src/ATen/test/cpu_rng_test.cpp)?
    m.impl("normal.Tensor_float_out", normal_Tensor_float_out);
    m.impl("normal.float_Tensor_out", normal_float_Tensor_out);
    m.impl("normal.Tensor_Tensor_out", normal_Tensor_Tensor_out);
    m.impl("normal.Tensor_float", normal_Tensor_float);
    m.impl("normal.float_Tensor", normal_float_Tensor);
    m.impl("normal.Tensor_Tensor", normal_Tensor_Tensor);
How come just half of them were deleted? In fact, are these registrations still necessary, given that these operations are now implemented as structured kernels? cc @ysiraichi @peterbell10
It's because these previously used the template version directly (to pass a custom RNG for testing and to demo this functionality). Not sure what to do here since these are gone now.
I think the distribution templates are part of the public generator API, so can't be removed. e.g. the cryptographic PRNG uses them:
https://github.com/pytorch/csprng/blob/5a6d9458c142190d5d713744687434c73c06ad01/torchcsprng/csrc/kernels_body.inc#L257
What do you think we should do here @mruberry ? Should we do the checks twice in these functions?
See #69628 (comment) for context.
I think, ideally, we'd want those checks in the META function. One way out of this is to factor out the implementation (code after the type checks) into a new function. Then, we would have something like this:
- normal_impl_out: dtype checks, then calls normal_impl_impl_out (not a very good name)
- normal_impl_impl_out: executes the rest of the implementation

Then, the IMPL function can just call normal_impl_impl_out directly, which would bypass the dtype checks (these can be factored into a function of their own, and called in META, too).
Not sure whether the extra indirection is worth it, though.
@bdhirsh will take a look soon -- I think he's the best person to help answer this question
cc @pbelevich (who I think wrote the RNG API).
It looks like the templates are public API to help write out-of-tree kernel extensions for distribution ops, so we can't easily get rid of them (without finding all external usages and making them structured too, which... would require external codegen and doesn't seem super beneficial to do). There's also more context described here.
If that's right, then I don't think that porting normal_* ops to structured will really help to clean up much code - we have to keep all of the functional/inplace/out= template variants around. @ysiraichi is also right, you'd need to make sure that all of the error checking logic currently in the template is run in the meta function (and also directly in the template, since out-of-tree kernel writers still need to rely on them).
Given all of that, it sounds to me like it would be easiest to just directly write meta kernels for all of the distribution ops.
> Given all of that, it sounds to me like it would be easiest to just directly write meta kernels for all of the distribution ops.
I'm going to do this and will create a new stack with the changes. This stack will stay open for now for reference.
I'll also fix the broadcasting issue that I introduced, which breaks BC.
    @@ -7768,28 +7768,28 @@
      Meta: normal_meta_
Why is this one not ported to structured kernels as well? I reckon that we should have all the combinations of functions here (in-place / out-of-place / _out) for all the types of inputs (Tensor / float) for (mean / std), right?
I'll look into it and follow up later. At first, it looked like it wasn't possible for some reason; maybe I just got confused.
      Tensor const& std,
      c10::optional<Generator> gen
    ) {
      auto shape = at::infer_size(mean.sizes(), std.sizes());
Just so we don't forget: probably we want to do something like what resize_output_for_normal does here, inside META. Since we still have the same problem as the dtype checks, we should probably wait for Brian.
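For reference, here is a minimal plain-C++ sketch of the broadcasting rule that at::infer_size applies to mean and std here (the real implementation lives in ATen; the `broadcast_shapes` name below is hypothetical and only illustrates the shape computation):

```cpp
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

// NumPy-style broadcasting, as at::infer_size computes it: align the
// two shapes from the trailing dimension, treat missing dims as 1, and
// require each pair of sizes to match or one of them to be 1.
std::vector<int64_t> broadcast_shapes(const std::vector<int64_t>& a,
                                      const std::vector<int64_t>& b) {
  const size_t ndim = std::max(a.size(), b.size());
  std::vector<int64_t> result(ndim);
  for (size_t i = 0; i < ndim; ++i) {
    const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1) {
      throw std::invalid_argument("shapes are not broadcastable");
    }
    result[ndim - 1 - i] = std::max(da, db);
  }
  return result;
}
```

So normal(mean of shape {3, 1}, std of shape {1, 4}) produces an output of shape {3, 4}, which is why the output-resizing logic needs to run in META as well.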
    at::native::templates::normal_out_impl<NormalStub, Generator>(
        const_cast<Tensor&>(out), mean, std, gen);
I think it's a good idea to propagate the const to normal_out_impl functions (might be BC-breaking, not sure), instead of const_cast-ing.
FWIW, this just mimics what some other operator already does, but yeah, I agree.
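The const_cast pattern under discussion has a simple plain-C++ analogue (the `Buf` type and function names are hypothetical; this only illustrates the calling-convention mismatch):

```cpp
// Structured kernels hand the IMPL function a const reference to an
// output that is in fact mutable, while the legacy template API takes
// a mutable reference.
struct Buf {
  int value = 0;
};

// Legacy-style helper expecting a mutable reference (playing the role
// of the current normal_out_impl template).
void legacy_fill(Buf& out, int v) { out.value = v; }

// Structured-style wrapper: casts const away today. The suggestion is
// to instead change the legacy API to accept `const Buf&`, which could
// be BC-breaking for external callers that pass mutable references.
void structured_fill(const Buf& out, int v) {
  legacy_fill(const_cast<Buf&>(out), v);
}
```

The const_cast is well-defined here only because the underlying object is not actually const; pushing const into the legacy signature would make that invariant explicit in the types.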
To avoid confusion, I will open a new stack to address issues related to BC and broadcasting.
Stack from ghstack:
See #69386.