Restriding logic for structured out kernels looks suspect #53587

@ezyang

Description

Noticed this while reading #53535. Currently the logic is

{maybe_set_guard}
at::native::resize_output(outputs_[output_idx], sizes);
if (!strides.empty()) {
    TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
    at::native::as_strided_(outputs_[output_idx], sizes, strides);
} else if (options.memory_format_opt().has_value()) {
    outputs_[output_idx].get().unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
}

This looks a bit suspect. In particular, if the idiomatic way to say "this operator preserves input layout" is:

  set_output(full_output_size, input.options().memory_format(input.suggest_memory_format()));

(as seen in #53535), then as far as I can tell we will just clobber the old strides and use the suggested memory format. That is wrong-headed: the strides of the out= tensor should take precedence.
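The clobbering can be seen with a toy, pure-Python model of the quoted branch. The dict-based "tensors" and all function names here are illustrative, not the real codegen output:

```python
# Toy model of the structured-kernel restriding logic quoted above.
# "Tensors" are dicts with 'sizes' and 'strides'; names are illustrative.

def contiguous_strides(sizes):
    # Row-major strides for the given sizes.
    strides, acc = [0] * len(sizes), 1
    for i in reversed(range(len(sizes))):
        strides[i] = acc
        acc *= sizes[i]
    return strides

def channels_last_strides(sizes):
    # NHWC strides for a logically-NCHW 4-d size: the C dimension is innermost.
    n, c, h, w = sizes
    return [h * w * c, 1, w * c, c]

def set_output(out, sizes, strides=None, memory_format=None):
    # Mirrors the codegen'd branch: resize, then either as_strided_ or
    # empty_tensor_restride -- unconditionally, even when out= pre-existed.
    out["sizes"] = list(sizes)                        # resize_output
    if strides:
        assert memory_format is None
        out["strides"] = list(strides)                # as_strided_
    elif memory_format == "contiguous":
        out["strides"] = contiguous_strides(sizes)    # empty_tensor_restride
    elif memory_format == "channels_last":
        out["strides"] = channels_last_strides(sizes)

# out= was allocated channels_last, but the meta function requests the
# input's suggested (contiguous) format, so the old strides are clobbered:
out = {"sizes": [2, 2, 2, 2], "strides": channels_last_strides([2, 2, 2, 2])}
set_output(out, [2, 2, 2, 2], memory_format="contiguous")
assert out["strides"] == [8, 4, 2, 1]  # no longer [8, 1, 4, 2]
```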

How did I make this mistake? This logic was introduced in #48718 and I am pretty sure I got it by lifting it straight out of TensorIterator, which looks like this:

void TensorIterator::set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) {
  // NB: intentionally no superclass call
  auto& op = operands_[output_idx];
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
  if (!op.tensor.defined()) {
      if (strides.empty()) {
        op.tensor = at::empty(sizes, options);
      } else {
        op.tensor = at::empty_strided(sizes, strides, options);
      }
      op.current_dtype = op.target_dtype;
  } else if (op.will_resize) {
      at::native::resize_output(op.tensor, sizes);
      if (!strides.empty()) {
        TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
        op.tensor.as_strided_(sizes, strides);
      } else if (options.memory_format_opt().has_value()) {
        op.tensor.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
      }
  }   
  if (!names.empty()) {
    TORCH_INTERNAL_ASSERT(op.tensor.defined());
    namedinference::propagate_names(op.tensor, names);
  }   
}

The distinction I missed is that TensorIterator only restrides if it actually resized the tensor; if no resizing occurs, no restriding happens. The structured kernel logic fails to capture this distinction.
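The TensorIterator behavior can be modeled the same toy way (again, dict "tensors" and names are illustrative): the restride lives only on the resize path, so a correctly-sized out= tensor keeps its strides.

```python
# Toy model of TensorIterator::set_output above: restriding is gated on
# the will_resize path. Dict-based "tensors" and names are illustrative.

def row_major(sizes):
    strides, acc = [0] * len(sizes), 1
    for i in reversed(range(len(sizes))):
        strides[i] = acc
        acc *= sizes[i]
    return strides

def ti_set_output(op, sizes, strides=None, memory_format=None):
    if not op.get("defined"):
        # No out= tensor: allocate fresh (empty / empty_strided).
        op.update(defined=True, sizes=list(sizes),
                  strides=list(strides) if strides else row_major(sizes))
    elif op.get("will_resize"):
        # Only a tensor we are allowed to resize gets restrided.
        op["sizes"] = list(sizes)
        if strides:
            op["strides"] = list(strides)
        elif memory_format == "channels_last":
            n, c, h, w = sizes
            op["strides"] = [h * w * c, 1, w * c, c]
    # else: a correctly-sized out= tensor keeps its existing strides.

# A correctly-sized channels_last out= tensor is left untouched, even if
# the caller asks for a different memory format:
op = {"defined": True, "will_resize": False,
      "sizes": [2, 2, 2, 2], "strides": [8, 1, 4, 2]}
ti_set_output(op, [2, 2, 2, 2], memory_format="contiguous")
assert op["strides"] == [8, 1, 4, 2]
```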

This sort of suggests that add is broken. Is it? It is not.

>>> out = torch.empty((2,2,2,2), memory_format=torch.channels_last)
>>> x = torch.empty((2,2,2,2))
>>> torch.add(x, x, out=out)
tensor([[[[ 9.0366e+30,  9.1328e-41],
          [ 9.0366e+30,  9.1328e-41]],

         [[ 3.1131e+28,  3.5487e+28],
          [ 2.1018e-38,  1.3469e+23]]],


        [[[ 1.3424e+23,  0.0000e+00],
          [ 3.6154e-43,  0.0000e+00]],

         [[-2.3647e-01,  6.1385e-41],
          [-1.2387e-02,  6.1385e-41]]]])
>>> print(out.stride())
(8, 1, 4, 2)
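Those strides confirm out kept its channels_last layout: for a logically-NCHW size, channels_last (NHWC) makes the C dimension innermost, giving strides (H*W*C, 1, W*C, C). A quick sanity check (pure arithmetic, no torch needed):

```python
def channels_last_strides(n, c, h, w):
    # channels_last for a logically-NCHW size: C is the innermost dimension,
    # so the strides are (H*W*C, 1, W*C, C).
    return (h * w * c, 1, w * c, c)

print(channels_last_strides(2, 2, 2, 2))  # (8, 1, 4, 2), matching out.stride()
```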

I think the reason is that, for an output whose operand tensor is already defined, TensorIterator reads the strides directly off the original tensor:

        auto tensor_stride = invert_perm(op.stride_bytes);
        for (int dim = 0; dim < ndim(); dim++) {
          tensor_stride[dim] /= element_size;
        }
        set_output(i, tensor_shape, tensor_stride, op.options(), names_);

so the restride always looks "right" and ends up a no-op in this case. We could force meta functions to be written this way:

  const auto& t = maybe_get_output();
  if (t.defined()) {
    set_output(full_output_size, t.strides(), input.options());
  } else {
    set_output(full_output_size, input.options().memory_format(input.suggest_memory_format()));
  }

But this is ugly, long, and error-prone; better to change the semantics of set_output to do the right thing.
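One possible shape for those semantics, sketched with the same toy model (hypothetical, not the actual PyTorch change): the memory format is a hint used only when set_output allocates or actually resizes, never an override of a pre-existing out= layout.

```python
# Sketch of "do the right thing" set_output semantics (hypothetical, not
# the real PyTorch implementation). "Tensors" are dicts, as before.

def strides_for(sizes, memory_format):
    if memory_format == "channels_last":
        n, c, h, w = sizes
        return [h * w * c, 1, w * c, c]
    strides, acc = [0] * len(sizes), 1
    for i in reversed(range(len(sizes))):
        strides[i] = acc
        acc *= sizes[i]
    return strides

def dtrt_set_output(out, sizes, memory_format="contiguous"):
    resized = out.get("sizes") != list(sizes)
    out["sizes"] = list(sizes)
    if not out.get("defined"):
        # Fresh allocation: the memory format hint applies.
        out.update(defined=True, strides=strides_for(sizes, memory_format))
    elif resized:
        # A resize discarded the old layout, so the hint applies here too.
        out["strides"] = strides_for(sizes, memory_format)
    # else: a correctly-sized out= tensor keeps its caller-provided strides.

out = {"defined": True, "sizes": [2, 2, 2, 2], "strides": [8, 1, 4, 2]}
dtrt_set_output(out, [2, 2, 2, 2], memory_format="contiguous")
assert out["strides"] == [8, 1, 4, 2]  # out= layout wins
```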

cc @VitalyFedyunin @glaringlee @bdhirsh

Labels: module: TensorIterator, triaged