Function request: support returning multiple values in CPU kernel #51108

@RockingJavaBean

🚀 Feature

Allow users to return multiple values in the CPU kernel.

Motivation

In the NumPy-like functionality rollup list #38349, there are some functions that require returning two tensors, such as divmod and frexp.

These functions are harder to implement than other NumPy-like functions because the current CPU elementwise kernels, such as cpu_kernel, cpu_kernel_vec and cpu_serial_kernel, support only one output tensor.
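
To illustrate why such functions do not fit a single-output kernel, here is a minimal standalone sketch of the per-element computations. The helper names frexp_pair and divmod_pair are hypothetical, and the floor-based divmod semantics are an assumption modeled on Python's divmod; none of this is taken from the ATen sources.

```cpp
#include <cassert>
#include <cmath>
#include <tuple>

// Hypothetical scalar helper: splits x into a mantissa and an exponent such
// that x == mantissa * 2^exponent. Both values are produced together.
inline std::tuple<float, int> frexp_pair(float x) {
  int exponent = 0;
  const float mantissa = std::frexp(x, &exponent);
  return std::make_tuple(mantissa, exponent);
}

// Hypothetical scalar helper: floor-style quotient and remainder
// (assumed semantics, modeled on Python's divmod).
inline std::tuple<float, float> divmod_pair(float a, float b) {
  assert(b != 0.0f);  // division by zero is out of scope for this sketch
  const float quotient = std::floor(a / b);
  const float remainder = a - quotient * b;
  return std::make_tuple(quotient, remainder);
}
```

Each helper yields two values per element, so an elementwise kernel built around it needs two output tensors.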

Adding a new kernel function in aten/src/ATen/native/cpu/Loops.h that supports multiple outputs could decrease the complexity of implementing such NumPy-like functions, and may help developers implement torch functions with multiple outputs more conveniently in the future.

Pitch

PR: #51097
Implement a new kernel function, cpu_kernel_multiple_outputs. Instead of returning a single scalar, the functor passed to it returns its output values as a std::tuple.

Example code:

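// Outputs are registered before inputs.
auto iter = at::TensorIteratorConfig()
  .add_output(out1)
  .add_output(out2)
  .add_input(in1)
  .add_input(in2)
  .build();
// The lambda takes one element from each input and returns one value per
// output, in the order the outputs were added.
at::native::cpu_kernel_multiple_outputs(iter,
  [=](float a, float b) -> std::tuple<float, float> {
    float add = a + b;  // written to out1
    float mul = a * b;  // written to out2
    return std::tuple<float, float>(add, mul);
  }
);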

After the call, out1 equals torch.add(in1, in2) and out2 equals torch.mul(in1, in2).
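
The idea of unpacking a returned tuple into separate outputs can be sketched without ATen. The following is a simplified illustration over std::vector, not the implementation from the PR; elementwise_two_outputs and add_and_mul are hypothetical names, and broadcasting, striding and dtype handling are deliberately left out.

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <vector>

// Hypothetical stand-in for a two-output elementwise kernel. It applies `op`
// to each pair of input elements and scatters the two tuple elements into
// separate output buffers.
template <typename T, typename Op>
void elementwise_two_outputs(const std::vector<T>& in1,
                             const std::vector<T>& in2,
                             std::vector<T>& out1,
                             std::vector<T>& out2,
                             Op op) {
  assert(in1.size() == in2.size());  // no broadcasting in this sketch
  out1.resize(in1.size());
  out2.resize(in1.size());
  for (std::size_t i = 0; i < in1.size(); ++i) {
    // The functor returns both results at once.
    const std::tuple<T, T> result = op(in1[i], in2[i]);
    out1[i] = std::get<0>(result);
    out2[i] = std::get<1>(result);
  }
}

// The add/mul functor from the example, written as a free function.
inline std::tuple<float, float> add_and_mul(float a, float b) {
  return std::make_tuple(a + b, a * b);
}
```

Calling elementwise_two_outputs(a, b, sum, product, add_and_mul) fills sum with the elementwise sums and product with the elementwise products, mirroring out1 and out2.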

Alternatives

Without such a kernel function, developers have to fall back on the more primitive for_each function of TensorIterator.
This requires manually handling details such as data type casting and stride-based offset calculation.
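
For comparison, here is a rough standalone sketch of what such a hand-written loop involves. It mimics the shape of a for_each loop callback (base pointers, byte strides, element count) but does not use ATen. The function name add_mul_loop, the operand order (out1, out2, in1, in2) and the hard-coded float type are all assumptions made for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical loop body in the style of a for_each callback.
// data[i] is the base pointer of operand i; strides[i] is its stride in bytes.
// Assumed operand order: out1, out2, in1, in2. Everything is hard-coded to
// float; a real loop would also have to handle dtype casting.
inline void add_mul_loop(char** data, const std::int64_t* strides,
                         std::int64_t n) {
  assert(n >= 0);
  char* out1 = data[0];
  char* out2 = data[1];
  const char* in1 = data[2];
  const char* in2 = data[3];
  for (std::int64_t i = 0; i < n; ++i) {
    // Offsets are computed by hand from the byte strides.
    const float a = *reinterpret_cast<const float*>(in1 + i * strides[2]);
    const float b = *reinterpret_cast<const float*>(in2 + i * strides[3]);
    *reinterpret_cast<float*>(out1 + i * strides[0]) = a + b;
    *reinterpret_cast<float*>(out2 + i * strides[1]) = a * b;
  }
}
```

All pointer arithmetic and casting sits in user code here, which is the bookkeeping a multiple-output kernel function would hide.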

Additional context

cc @mruberry @heitorschueroff

Metadata

Labels: function request, module: TensorIterator, module: reductions, triaged
