Implement hstack, vstack, dstack#42799
Conversation
💊 CI failures summary and remediations — As of commit a03e866 (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patterns. The following CI failures do not appear to be due to upstream breakages:
@mruberry PTAL
Nit: these examples are excellent but maybe
a = torch.tensor([[1],[2],[3]])
b = torch.tensor([[4],[5],[6]])
would be clearer?
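As a sketch of why column-shaped inputs make the doc examples clearer, here are the NumPy equivalents (which these torch ops mirror) on the suggested tensors; each op's concatenation axis becomes visible in the result shape:

```python
import numpy as np

# Suggested doc-example inputs as (3, 1) column arrays
a = np.array([[1], [2], [3]])
b = np.array([[4], [5], [6]])

h = np.hstack((a, b))  # joins along axis 1 -> shape (3, 2)
v = np.vstack((a, b))  # joins along axis 0 -> shape (6, 1)
d = np.dstack((a, b))  # joins along axis 2 -> shape (3, 1, 2)
```

The torch versions produce the same shapes on `torch.tensor` inputs.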
See numbering suggestion below.
See numbering suggestion above.
These tests are good, but this test's case generation is limited to replicating the same tensor shape for each element of the input list. Here are some cases I was thinking about:
- op(t)
  - the behavior of `np.hstack(a)` is strange and `np.hstack(a) != np.hstack((a,))` (the same is true for `np.dstack`) - do we even support non-tuple arguments? if not we should validate this throws a runtime error
  - if we support single tensor arguments, is `np.hstack(a)`'s and `np.dstack(a)`'s behavior correct?
- op((a, b, c, ...))
  - validating that if they differ on an unexpected dim an error is thrown (maybe `_test_special_stacks` should take a dim argument corresponding to the op?)
  - validating that if they differ only on the expected dim the result is equivalent to NumPy
- are tensors with a size zero dim handled correctly? (if not that's OK, but let's assert it doesn't work)
- `np.hstack` has special handling of 1D tensors (as your implementation does) - does `test_hstack` need a custom elaboration to test that behavior, in particular?
- validating that tensors with different shapes, but the same shapes after `atleast_Xd`, work
For an example of the last bullet:
```python
a = np.array([[[1],[2],[3]]])
b = np.array((4, 5, 6))
np.dstack((a, b))
# array([[[1, 4],
#         [2, 5],
#         [3, 6]]])
```
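The single-tensor case from the first bullet can be checked the same way; a sketch in NumPy of why `np.hstack(a)` is strange — a bare array is treated as a sequence of its rows rather than as a one-element tuple:

```python
import numpy as np

a = np.array([[1], [2], [3]])  # shape (3, 1)

wrapped = np.hstack((a,))  # one-element tuple: returns the array unchanged, (3, 1)
bare = np.hstack(a)        # bare array: its rows are concatenated into 1-D, (3,)
```

So `np.hstack(a)` and `np.hstack((a,))` disagree on both shape and contents, which is why the review asks whether the torch ops should accept non-tuple arguments at all.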
This is a good number of cases but validating each one by hand shouldn't be too laborious, I hope.
What are your thoughts? Are there other cases I missed?
mruberry left a comment
Overall looks excellent. A couple minor nits about the doc examples and questions about test coverage.
Force-pushed from 315a28d to 18e9013
@mruberry I have added tests that I think cover all of the cases. They cover:
For the last two, those are tested from 1 to 4 dimensions, so the special behavior for hstack is included with that. I have also added some autograd tests in a similar manner to the existing stack autograd test. Does this sound good, or do I need more tests?
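The hstack special case being exercised across 1 to 4 dimensions can be sketched with a hypothetical NumPy reference (`hstack_ref` is illustrative, not the PR's actual helper): for 1-D inputs hstack joins along axis 0, otherwise along axis 1.

```python
import numpy as np

def hstack_ref(tensors):
    # Hypothetical reference: hstack concatenates along axis 1,
    # except for 1-D inputs, which are joined along axis 0.
    tensors = [np.atleast_1d(t) for t in tensors]
    axis = 0 if tensors[0].ndim == 1 else 1
    return np.concatenate(tensors, axis=axis)

# Check the reference against np.hstack from 1 to 4 dimensions
for ndim in range(1, 5):
    shape = (2,) * ndim
    a, b = np.ones(shape), np.zeros(shape)
    assert np.array_equal(hstack_ref((a, b)), np.hstack((a, b)))
```

A test parametrized this way covers the 1-D branch without a separate elaboration for `test_hstack`.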
```python
else:
    # Invalid dimensions, test for error
    with self.assertRaisesRegex(RuntimeError, "Sizes of tensors must match except in dimension"):
        torch_fn(torch_input)
```
Would you add an assert that NumPy also throws a runtime error in this case? You don't need to assert a string is thrown:
```python
with self.assertRaises(RuntimeError):
    np_fn(np_input)
```
mruberry left a comment
Nice work, @muthuArivoli!
Would you just fix that one minor nit on the tests and we'll get this merged?
Let me know if you're interested in working on a new problem.
@mruberry I added the numpy error check. Is it OK that numpy throws a ValueError, while we throw a RuntimeError? Yes, I'm interested in working on a new problem, do you have any recommendations?
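A minimal sketch of the exception mismatch being discussed — NumPy signals incompatible shapes with `ValueError`, while the torch ops raise `RuntimeError` for the same inputs:

```python
import numpy as np

# Arrays that differ on an unexpected dim for hstack (axis-0 sizes 2 vs 3)
a = np.ones((2, 2))
b = np.ones((3, 2))

try:
    np.hstack((a, b))
    raised = False
except ValueError:  # torch.hstack would raise RuntimeError here instead
    raised = True
```

A cross-library test can simply assert each exception type separately, as the review suggests.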
facebook-github-bot left a comment
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Absolutely OK. Nice work.
For symmetry there are the split functions, hsplit, vsplit, and dsplit. A slightly more challenging binary function is divmod, because it returns two tensors. There are the "polynomial" functions, like polyadd and polyder, but I'm hoping someone will write all of them near-simultaneously because they have a lot of common structure. There are also unary functions, like nan_to_num, that would be very helpful. If you'd like something more exotic or especially numerically challenging, there are also functions like the Kaiser windowing function.
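The stack/split symmetry mentioned here can be sketched in NumPy (whose semantics the proposed torch functions would mirror): hsplit undoes hstack when the pieces have equal width.

```python
import numpy as np

a = np.arange(6).reshape(2, 3)
b = np.arange(6, 12).reshape(2, 3)

stacked = np.hstack((a, b))          # shape (2, 6)
left, right = np.hsplit(stacked, 2)  # split along axis 1 back into two (2, 3) halves
```

vsplit and dsplit relate to vstack and dstack the same way, along axes 0 and 2.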
Two questions:
Excellent questions.
Summary: Related to pytorch#38349
Pull Request resolved: pytorch#42799
Reviewed By: izdeby
Differential Revision: D23140704
Pulled By: mruberry
fbshipit-source-id: 6a36363562c50d0abce87021b84b194bb32825fb
Related to #38349