Revamped test cases for Gemm #2060
* Increased number of tests from 2 to 8.
* The tests each test a specific attribute or input case.
* Generated test data.
* Updated the operators doc.
* Updated the test coverage.
@JamesAllingham welcome and thank you for the contribution to ONNX. It's ok to remove some test cases as long as the scenarios/cases are also covered by new ones. Please fix the CI failures: +flake8
@linkerzhang Thanks for the welcome. I'm very keen to be getting involved with this project. And thanks for the feedback. What my test cases don't cover (that the old ones did) is whether the |
Any feedback on this? Would be great to get this closed so that I can go ahead with some more PRs (would be nice to know I'm doing things correctly before doing more 😃)
Hey @linkerzhang any feedback on this? I also have another PR which could use some love. I have quite a few more tests I'd like to submit PRs for but I really would like to get some feedback on what I'm doing before submitting!
postrational left a comment
This looks good to me. I think we should merge it.
@JamesAllingham, could you update your branch?
@postrational I've updated the branch. Let me know if I need to do anything else to get this merged.
a = np.random.ranf([3, 5]).astype(np.float32)
b = np.random.ranf([5, 4]).astype(np.float32)
c = np.zeros([1, 4]).astype(np.float32)
y = 0.5 * np.dot(a, b) + c
[nit] This is good enough for tests. On the other hand, it'd be nice if you can create a reference implementation of Gemm using numpy and use it in all tests. For example,
y = your_gemm_reference_impl(a, b, c, ...)
I agree, it will be good to have a reference implementation, I'll add one. 👍
b = np.random.ranf([5, 4]).astype(np.float32)
c = np.zeros([1, 4]).astype(np.float32)
y = np.dot(a, b) + c
expect(node, inputs=[a, b, c], outputs=[y],
No bias? Do you mean that c will be removed?
What I meant was that the bias term has no impact on the result. I think this is a useful test case because in many frameworks the bias term of the fully-connected layer is optional with a default of 0. I'll rename the test to make this more clear.
With that in mind, do you think it is worth changing the spec for Gemm so that the bias is optional and has a default of 0? I'm happy to do so, and I don't think it is too much work or too big a change. (Then I can add a test which legitimately has no bias term).
It totally makes sense to me. Could we do it in another PR? I will try to get it reviewed (and then hopefully merged) before the deadline (9/19 is the targeted release date). I remember that I had to insert zero initializers when converting cases without bias. The case with C=0 is also common in practice, so we should have a one-to-one mapping in the IR.
Sure, I'll make another PR for it. I'm afraid it will have to be about 11 hours from now though!
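To illustrate the zero-bias case discussed in this thread, here is a minimal sketch using the shapes from the diff ([3, 5] and [5, 4] inputs with a [1, 4] zero bias). The fixed seed and the `zero_bias_is_noop` name are assumptions added for illustration and are not part of the PR.

```python
import numpy as np

np.random.seed(0)  # assumption: seeded only so the sketch is reproducible

a = np.random.ranf([3, 5]).astype(np.float32)
b = np.random.ranf([5, 4]).astype(np.float32)
c = np.zeros([1, 4]).astype(np.float32)  # zero bias, broadcast over the 3 rows

y = np.dot(a, b) + c

# Adding an all-zero bias leaves the matrix product unchanged,
# which is why the test was renamed to "zero bias" rather than "no bias".
zero_bias_is_noop = np.array_equal(y, np.dot(a, b))
```

The bias input is still passed to the node here; making it truly optional is the spec change proposed for a separate PR.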
)
a = np.random.ranf([2, 3]).astype(np.float32)
b = np.random.ranf([3, 4]).astype(np.float32)
c = np.random.ranf([1]).astype(np.float32)
[1] is not a scalar. It's a 1-element vector. Maybe we can use np.array(1), whose shape is (). As mentioned in my comment above, it'd be even nicer if we had a reference implementation.
You are totally correct. I've fixed this. And I am going to do a reference implementation.
'Gemm',
inputs=['a', 'b', 'c'],
outputs=['y'],
alpha=0.5,
May we keep this one? It represents a case where all attributes are non-default.
Agreed and done 👍
(I renamed this test, it is now called 'test_gemm_all_attributes')
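As a rough illustration of what an "all attributes non-default" case computes, here is a sketch in plain NumPy. Only `alpha=0.5` appears in the diff above; the values `beta=0.5`, `transA=1`, `transB=1` and all shapes are assumptions chosen for the example.

```python
import numpy as np

np.random.seed(0)  # assumption: seeded only so the sketch is reproducible

alpha, beta = 0.5, 0.5  # beta value is an assumption

# With transA=1 and transB=1 (assumed), both inputs are supplied transposed.
a = np.random.ranf([4, 3]).astype(np.float32)  # used as a.T, shape (3, 4)
b = np.random.ranf([5, 4]).astype(np.float32)  # used as b.T, shape (4, 5)
c = np.random.ranf([1, 5]).astype(np.float32)  # broadcast over the 3 rows

y = alpha * np.dot(a.T, b.T) + beta * c
```

Because every attribute differs from its default, a backend that ignores any one of them should produce a visibly different result.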
1. Differentiated between scalar and single element vector biases.
2. Changed the name of 'test_gemm_default_no_bias' to 'test_gemm_default_zero_bias' to be more clear.
3. Added back the test case for specifying all attributes as non-default.
Cool, I think I've addressed all of your comments, please let me know if I misunderstood anything. Also, let me know what you think about allowing the bias input ( |
Also added the mypy type annotation for the Gemm reference implementation.
c = np.zeros([1, 4]).astype(np.float32)
a = np.random.ranf([2, 3]).astype(np.float32)
b = np.random.ranf([3, 4]).astype(np.float32)
c = np.random.ranf(1).astype(np.float32)
This is still not a scalar.
>>> np.random.ranf(1).shape
(1,)
As you already have a reference implementation, could we replace
c = np.random.ranf(1).astype(np.float32)
with
c = 3.14 or c = np.array(3.14)
?
Whoops! Let me fix that!
Okay, I think I've fixed this for real now, sorry about that :)
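The scalar versus single-element-vector distinction raised in this thread can be checked directly. The sketch below is illustrative only; the variable names are hypothetical and the 3.14 value is taken from the suggestion above.

```python
import numpy as np

vector_bias = np.random.ranf(1).astype(np.float32)    # shape (1,)
also_vector = np.random.ranf([1]).astype(np.float32)  # shape (1,), same as above
scalar_bias = np.array(3.14, dtype=np.float32)        # shape (), a true scalar

shapes = (vector_bias.shape, also_vector.shape, scalar_bias.shape)
ndims = (vector_bias.ndim, scalar_bias.ndim)

# Both kinds of bias broadcast against a (2, 4) product,
# which is why the mistake is easy to miss in the expected output.
product = np.ones((2, 4), dtype=np.float32)
same_result_shape = (product + vector_bias).shape == (product + scalar_bias).shape
```

The output shape is identical in both cases, so only the shape of the bias input itself distinguishes the two test cases.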
A = A if transA == 0 else A.T
B = B if transB == 0 else B.T

Y = alpha * np.dot(A, B) + beta * C
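For context, the fragment above could be wrapped into a complete function along the following lines. This is a sketch, not necessarily the code merged in the PR: the function name `gemm_reference` and the exact signature are assumptions, while the defaults (alpha and beta of 1.0, transA and transB of 0) follow the Gemm operator's documented attribute defaults.

```python
import numpy as np


def gemm_reference(A, B, C, alpha=1.0, beta=1.0, transA=0, transB=0):
    """Compute Y = alpha * A' * B' + beta * C with NumPy (hypothetical helper)."""
    A = A if transA == 0 else A.T
    B = B if transB == 0 else B.T
    return alpha * np.dot(A, B) + beta * C


# Usage with a true scalar bias, as requested in the review.
a = np.arange(6, dtype=np.float32).reshape(2, 3)
b = np.ones((3, 4), dtype=np.float32)
c = np.array(3.14, dtype=np.float32)

y = gemm_reference(a, b, c)

# Passing the transposed input together with transA=1 gives the same result.
y_trans = gemm_reference(a.T, b, c, transA=1)
```

Using one helper for every test keeps the expected outputs consistent and makes each test differ only in the attributes it passes.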
wschin left a comment
Overall looks good and it improves the document considerably. Just a last few comments to address.
Force-pushed from 6d91ea2 to 076cf51.
One unfortunate thing about the documentation is that the reference implementation is not shown in |
* Revamped test cases for Gemm
* Increased number of tests from 2 to 8.
* The tests each test a specific attribute or input case.
* Generated test data.
* Updated the operators doc.
* Updated the test coverage.
* Fixed linting errors causing CI failures
* Updated the TestCoverage and Operators docs for the newly re-linted Gemm tests.
* Tweaks to Gemm test cases:
  1. Differentiated between scalar and single element vector biases.
  2. Changed the name of 'test_gemm_default_no_bias' to 'test_gemm_default_zero_bias' to be more clear.
  3. Added back the test case for specifying all attributes as non-default.
* Added a reference implementation for the Gemm op
* Fixed some mistakes with the types of transA and transB. Also added the mypy type annotation for the Gemm reference implementation.
* The Gemm scalar bias test actually adds a scalar bias now
Hi all,
I'm currently working on an ONNX importer for the neural network framework in Wolfram Mathematica. While developing I've found it useful to add more test cases to make sure that I'm following the ONNX spec properly, so I figured I should contribute them back upstream so that others can also use them. I'll definitely be adding more in the near future, I just wanted to get my hands dirty with this to make sure I'm doing things correctly. So assuming that these changes are welcome I'll be adding more soon. :)
This is my first time contributing to the project, so please do let me know if I'm doing anything wrong! After adding the test cases, I re-installed ONNX, ran the test suite as well as the type checker, and didn't see any issues. I also linted the code to make sure it was all looking good.
Here is what the PR changes:
1. Increases the number of tests for Gemm from 2 to 8.
2. The tests now each test a specific attribute or input case.
3. Adds test data.
4. Updates the operators doc.
5. Updates the test coverage.
Regarding point 2, I wanted to make it easier for myself, and other developers, to incrementally add support for the attributes and other features of the Gemm operator. By splitting the tests up so that each one only covers a particular attribute or input case I think this goal is accomplished. It also makes it immediately clear what has broken when a test case regression occurs.
One thing I was a little worried about when making this change was that removing existing test cases might not be allowed. If this is the case I'll add the old cases back -- I simply removed them because they were a bit redundant. I also imagine there might be some feedback about how I've named the test cases.