Add proper shape checking to torch.cat#4087
Conversation
Asserts that the inputs have the same size except in the cat dimension or are empty (or a mix of both).
| int64_t first_dims = first->nDimension; | ||
| int64_t second_dims = second->nDimension; | ||
| THArgCheck(first_dims == second_dims, 0, | ||
| "Tensors must have same number of dimensions: got %d and %d", |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
I hate to ask this now that the code is all written, but is there a compelling reason why this couldn't have been implemented in ATen? (I guess, no easy way to insert the checks into the generated code?) |
|
Good point, I hadn't thought about implementing the shape checks in aten. I think it shouldn't be too bad to do that: I could rename The two downsides I see with this approach are that:
Alternatively I could just rewrite |
|
we should keep them in TH/THC. there's no upside (now that the code is written) to do this in ATen. |
* Fix catArray in THTensor Asserts that the inputs have the same size except in the cat dimension or are empty (or a mix of both). * Fix catArray for THCTensor * Document torch.cat shape checks * Fix types
* Fix catArray in THTensor Asserts that the inputs have the same size except in the cat dimension or are empty (or a mix of both). * Fix catArray for THCTensor * Document torch.cat shape checks * Fix types
Fixes #4071.
THTensor_(catArray)andTHCTensor_(catArray)do some strange thing where the shapes of the tensors don't have to be the same in some dimensions in some cases. This adds full shape checking.Non-empty tensor arguments to torch.cat must all have the same shape, except in the cat dimension.
Test Plan
New unit tests: