Conversation
at::Tensor XLANativeFunctions::stack(at::TensorList tensors, int64_t dim) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
  at::ScalarType result_type = at::native::result_type(tensors);
  std::vector<at::Tensor> c_tensors(tensors.size());
Is stack expecting the input tensors to be on CPU? Won't std::vector<at::Tensor> c_tensors return a list of tensors on CPU?
I don't think so. Unless I'm missing something, they are cast tensors, on XLA.
Then I am a bit confused. Reading your code, you init the c_tensors vector, which I assume will hold CPU tensors since you didn't provide the device type. The later code only updates the dtype of these c_tensors; I don't see when they are moved to the XLA device.
Here's a summary of what this code is doing: given the arguments tensors (a list of XLA tensors) and dim, the function:
- Computes the common data-type of all tensors: result_type
- Converts each tensor to the common data-type, storing the result in c_tensors (as in "cast tensors")
- Calls tensor_methods::stack with the cast tensors
Oh, I see. std::transform is called with tensors.begin(), so the cast results are written into c_tensors without ever leaving the XLA device.
Force-pushed 57352fb to 5fbcdd9
Fix: #7083
This PR adds data-type promotion to the stack operation. Previously, there was none, so the kernel implicitly expected all arguments to have the same data-type. This might not be the case when using AMP.

cc @miladm @JackCaoG