
Add data-type promotion to stack. #7091

Merged
ysiraichi merged 2 commits into master from ysiraichi/fix-stack-dtype-promotion
May 23, 2024

Conversation

@ysiraichi
Collaborator

Fix: #7083

This PR adds data-type promotion to the stack operation. Previously, there was none, so the kernel implicitly expected all arguments to have the same data-type. That assumption might not hold when using AMP.

cc @miladm @JackCaoG

@ysiraichi ysiraichi requested a review from JackCaoG May 22, 2024 00:42
```cpp
at::Tensor XLANativeFunctions::stack(at::TensorList tensors, int64_t dim) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
  // Compute the common data-type of all input tensors.
  at::ScalarType result_type = at::native::result_type(tensors);
  // Will hold the inputs cast to the common type ("c" as in "cast");
  // filled by a std::transform later in the diff.
  std::vector<at::Tensor> c_tensors(tensors.size());
```
Collaborator

Is stack expecting the input tensors to be on CPU? `std::vector<at::Tensor> c_tensors` will return a list of tensors on CPU, right?

Collaborator Author

@ysiraichi ysiraichi May 22, 2024

I don't think so. Unless I'm missing something, they are cast tensors, on XLA.

Collaborator

Then I am a bit confused. Reading your code, you init the `c_tensors` vector, which I assume will be CPU tensors since you didn't provide the device type. In the later code you only update the dtype of these `c_tensors`; I don't know when they are moved to the XLA device.

Collaborator Author

Here's a summary of what this code is doing. Given the arguments `tensors` (a list of XLA tensors) and `dim`, the function:

  1. Computes the common data-type of all tensors: `result_type`
  2. Converts each tensor to the common data-type, storing the result in `c_tensors` (as in "cast tensors")
  3. Calls `tensor_methods::stack` with the cast tensors
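The three steps above can be sketched outside of PyTorch. This is a toy model, not the actual XLA code: the `Tensor` struct, the `DType` promotion lattice, and `stack_with_promotion` are all simplified stand-ins for `at::Tensor`, `at::native::result_type`, and the real kernel.

```cpp
#include <algorithm>
#include <vector>

// Toy stand-ins for at::ScalarType / at::Tensor: enum order encodes the
// (simplified) promotion lattice, and the payload is just a double vector.
enum DType { Int32 = 0, Float16 = 1, Float32 = 2 };

struct Tensor {
  DType dtype;
  std::vector<double> data;
};

// Step 1: common dtype = highest-ranked dtype among the inputs
// (a toy version of at::native::result_type).
DType result_type(const std::vector<Tensor>& tensors) {
  DType common = tensors.front().dtype;
  for (const Tensor& t : tensors) common = std::max(common, t.dtype);
  return common;
}

// Steps 2-3: cast every input to the common dtype, then "stack".
std::vector<Tensor> stack_with_promotion(const std::vector<Tensor>& tensors) {
  DType common = result_type(tensors);
  std::vector<Tensor> c_tensors(tensors.size());  // "c" as in "cast"
  std::transform(tensors.begin(), tensors.end(), c_tensors.begin(),
                 [=](const Tensor& t) { return Tensor{common, t.data}; });
  return c_tensors;  // a real stack would also concatenate along `dim`
}
```

With a `Float16` and a `Float32` input (the AMP scenario from the description), every element of `c_tensors` ends up as `Float32` before the stack happens.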

Collaborator

Oh I see, `std::transform` is called with `tensors.begin()`.
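The point that resolved the confusion: `std::vector<T> v(n)` only creates `n` value-initialized placeholders, and `std::transform` then overwrites every one of them with the lambda's result, so the initial placeholder elements never survive. A minimal illustration of the same pattern with plain ints (not the PR's code):

```cpp
#include <algorithm>
#include <vector>

// std::transform overwrites every placeholder in the pre-sized output
// vector, just as c_tensors is overwritten with the cast XLA tensors.
std::vector<int> times_ten(const std::vector<int>& in) {
  std::vector<int> out(in.size());  // value-initialized placeholders (zeros)
  std::transform(in.begin(), in.end(), out.begin(),
                 [](int x) { return x * 10; });
  return out;  // no zeros remain: every slot was assigned
}
```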

@ysiraichi ysiraichi force-pushed the ysiraichi/fix-stack-dtype-promotion branch from 57352fb to 5fbcdd9 on May 22, 2024 14:18
@ysiraichi ysiraichi requested a review from JackCaoG May 22, 2024 18:54


Development

Successfully merging this pull request may close these issues.

[torchbench] timm_efficientdet training failing on non-dynamo.

2 participants