Expand the test of torch.addbmm and torch.baddbmm#47079
Conversation
Codecov Report

@@            Coverage Diff             @@
##           master    #47079       +/-   ##
===========================================
+ Coverage   36.03%    60.82%    +24.79%
===========================================
  Files         437      2748      +2311
  Lines       55239    254096    +198857
===========================================
+ Hits        19906    154565    +134659
- Misses      35333     99531     +64198
Hi @zasdfgbnm! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but we do not have a signature on file. In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!
@onlyOnCPUAndCUDA
@dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)))
@dtypes(*(torch.testing.get_all_complex_dtypes() + [torch.float, torch.double]))
@tf32_on_and_off(0.05)
floating_and_complex_types() seems to be missing on the previous line.
It's not missing; after offline discussion, we decided that instead of testing on floating_and_complex_types(), we test on all floating-point and complex dtypes and assertRaises for the unsupported ones. So this comment no longer applies.
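A minimal sketch of that pattern, assuming an illustrative helper (check_addbmm_dtype and the sizes below are made up for this sketch, not the PR's code): the op runs on every floating-point/complex dtype and is expected to raise RuntimeError wherever the backend does not support the dtype.

import torch

def check_addbmm_dtype(device, dtype, is_supported):
    num_batches, M, N, O = 2, 3, 4, 5
    b1 = torch.randn(num_batches, M, N, device=device).to(dtype)
    b2 = torch.randn(num_batches, N, O, device=device).to(dtype)
    t = torch.zeros(M, O, device=device, dtype=dtype)
    if is_supported:
        torch.addbmm(t, b1, b2)  # supported dtypes must run without error
    else:
        try:
            torch.addbmm(t, b1, b2)
        except RuntimeError:
            return  # unsupported dtypes are expected to fail loudly, not be skipped
        raise AssertionError("addbmm unexpectedly succeeded for {}".format(dtype))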
for perm1, perm2 in product(permutations((0, 1, 2)), repeat=2):
    for perm3 in permutations((0, 1)):
        b1 = torch.randn(num_batches, M, N, dtype=dtype, device=device)
Same as for baddbmm: please compare with numpy instead of this convoluted test logic. Also, if it makes sense to unify this test with baddbmm (at first glance it looks similar), please do so.
I think it's better to test separately. The reason is that the output of addbmm is 2D but the output of baddbmm is 3D, so it is

for perm1, perm2 in product(permutations((0, 1, 2)), repeat=2):
    for perm3 in permutations((0, 1)):

vs.

for perm1, perm2, perm3 in product(permutations((0, 1, 2)), repeat=3):

and I don't know how these two can be combined easily.

Edit: maybe it is possible; I think I found a solution.
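One hypothetical way to share the loop (iter_layout_perms is an illustrative helper, not code from this PR): parameterize the rank of the output, so the only difference between the addbmm and baddbmm tests is whether the output permutation ranges over two or three dimensions.

from itertools import permutations, product

def iter_layout_perms(out_ndim):
    # b1 and b2 are always 3-D batched inputs; only the output rank differs,
    # so the output permutation is drawn from range(out_ndim).
    for perm1, perm2 in product(permutations((0, 1, 2)), repeat=2):
        for perm3 in permutations(tuple(range(out_ndim))):
            yield perm1, perm2, perm3

# iter_layout_perms(2) would drive the addbmm test, iter_layout_perms(3) the baddbmm test.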
Also, I don't think numpy has something like addbmm; the only difference would be torch.bmm(b1, b2).sum(0, False) vs numpy.matmul(b1, b2).sum(0, False). In this case, does it still make sense to change to numpy?
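For what it's worth, a NumPy reference check would presumably look like the sketch below (shapes and the tolerance are illustrative); since NumPy has no addbmm, the reference is just matmul followed by a sum over the batch dimension plus the added matrix.

import numpy as np
import torch

num_batches, M, N, O = 2, 3, 4, 5
b1 = torch.randn(num_batches, M, N)
b2 = torch.randn(num_batches, N, O)
t = torch.randn(M, O)

torch_out = torch.addbmm(t, b1, b2)  # t + sum_i b1[i] @ b2[i]
numpy_ref = np.matmul(b1.numpy(), b2.numpy()).sum(0) + t.numpy()
assert np.allclose(torch_out.numpy(), numpy_ref, atol=1e-5)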
💊 CI failures summary and remediations

As of commit 781f4af (more details on the Dr. CI page):

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:
Broadcasting test added.
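A minimal sketch of what such a broadcasting check could look like for baddbmm (shapes chosen for illustration; this is not necessarily the exact test that was added): the added input is broadcast against the (num_batches, M, O) result of the batched matmul.

import torch

num_batches, M, N, O = 2, 3, 4, 5
b1 = torch.randn(num_batches, M, N)
b2 = torch.randn(num_batches, N, O)
t = torch.randn(1, 1, O)  # broadcastable against the (num_batches, M, O) result

out = torch.baddbmm(t, b1, b2)
ref = torch.bmm(b1, b2) + t
assert torch.allclose(out, ref)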
@ngimel I have resolved most of your comments:
except that
@ngimel My CLA is resolved.
This should be ready.
res2.addbmm_(b1, b2)
self.assertEqual(res2, res.sum(0, False))
getattr(res2, func + "_")(b1, b2)
Consider explicitly passing the variants you want (i.e., passing torch.foo, torch.Tensor.foo, and torch.Tensor.foo_).
I cannot, because I need a string with the name of the operator to assert on warning messages.
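A rough sketch of the string-based dispatch being discussed (run_variants is an illustrative helper, not the PR's code): keeping the operator name as a string lets the same value drive the getattr lookups for the out-of-place and in-place variants and also appear in warning-message assertions.

import torch

def run_variants(func_name, t, b1, b2):
    res = getattr(torch, func_name)(t, b1, b2)     # out-of-place, e.g. torch.addbmm

    res_inplace = t.clone()
    getattr(res_inplace, func_name + "_")(b1, b2)  # in-place, e.g. Tensor.addbmm_
    assert torch.allclose(res, res_inplace)
    return res

# The string form is what allows assertions on deprecation-warning text built
# from func_name, which passing torch.addbmm as a callable alone would not give.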
self.assertEqual(res2, ref)
res3.copy_(res2)
with self.maybeWarnsRegex(
Reviewer's note (not for this PR): we need a way of always triggering these warnings. "maybeWarnsRegex" doesn't do anything, currently.
if self.device_type == 'cpu':
    is_supported = True
    if dtype == torch.bfloat16:
        self.precision = 1  # 43 vs 43.75
Reviewer's note: It'd be nicer if we supported per-dtype precision overrides (also atol and rtol overrides)
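As an illustration of that note, a per-dtype tolerance override could hypothetically look like the decorator below; tolerance_override is invented for this sketch and is not existing PyTorch API, and a real test harness would have to read the stored mapping when comparing results.

import torch

def tolerance_override(overrides):
    # overrides: dict mapping torch.dtype -> (atol, rtol)
    def decorator(test_fn):
        test_fn._tolerance_overrides = overrides  # a harness would read this attribute
        return test_fn
    return decorator

@tolerance_override({torch.bfloat16: (1.0, 0.0), torch.half: (1e-1, 1e-2)})
def test_addbmm(self, device, dtype):
    ...  # comparisons would use the per-dtype atol/rtol instead of self.precision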
facebook-github-bot left a comment:
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: This is to satisfy the request at pytorch#42553 (comment). See also pytorch#47124.

Pull Request resolved: pytorch#47079
Reviewed By: ejguan
Differential Revision: D24735356
Pulled By: ngimel
fbshipit-source-id: 122fceb4902658f350c2fd6f92455adadd0ec2a4
This is to satisfy the request at #42553 (comment). See also #47124