Skip to content

Update Pow input types in Opset 12#2666

Merged
ebarsoum merged 6 commits intoonnx:masterfrom
lara-hdr:lahaidar/pow_opset12
Mar 20, 2020
Merged

Update Pow input types in Opset 12#2666
ebarsoum merged 6 commits intoonnx:masterfrom
lara-hdr:lahaidar/pow_opset12

Conversation

@lara-hdr
Copy link
Copy Markdown
Contributor

Update Pow in opset 12 to allow operands of different types as inputs, and allow int types inputs.

@lara-hdr lara-hdr requested a review from a team as a code owner March 17, 2020 23:19
Comment thread onnx/defs/math/defs.cc Outdated
Copy link
Copy Markdown
Member

@snnn snnn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested it locally, looks good.

@snnn
Copy link
Copy Markdown
Member

snnn commented Mar 18, 2020

So, this bug fix is for test_mean_square_distance_mean_3d_expanded and some others, not for test_softmax_cross_entropy_mean_3d, right?

@lara-hdr
Copy link
Copy Markdown
Contributor Author

@snnn yes, this is the link for softmax cross entropy: microsoft/onnxruntime#3237.
@codemzs mentioned still seeing a "Could not find an implementation for the node Max(12)" error after the fix.

@snnn
Copy link
Copy Markdown
Member

snnn commented Mar 18, 2020

@snnn yes, this is the link for softmax cross entropy: microsoft/onnxruntime#3237.
@codemzs mentioned still seeing a "Could not find an implementation for the node Max(12)" error after the fix.

That's normal. Because onnxruntime doesn't have an implementation for the OP yet, that's expected.

Comment thread onnx/defs/math/defs.cc
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input and output types to float tensors.")
{"tensor(int32)",
Copy link
Copy Markdown
Contributor

@spandantiwari spandantiwari Mar 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The set of allowed types T is more limited than T1 for exponent. Is there a strong reason not to support other integer types?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I restricted the types of T following this comment #2666 (comment) to avoid overflow.
Another option would be to accept any type for X as well.
But then the output type would be impossible to infer..

Copy link
Copy Markdown
Contributor

@spandantiwari spandantiwari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Added a comment on type support.

@lara-hdr
Copy link
Copy Markdown
Contributor Author

@ebarsoum @wschin for review

Comment thread docs/Operators.md
@ebarsoum ebarsoum merged commit 46fe392 into onnx:master Mar 20, 2020
linkerzhang added a commit that referenced this pull request Mar 31, 2020
* Fix Greater/LessOrEqual function definition (#2645)

* Fix Greater/LessOrEqual function definition

* Update test data

Co-authored-by: Ke Zhang <kezhan@microsoft.com>

* Suppress a warning in unsqueeze (#2637)

I keep getting this warning when building PyTorch:

```
In file included from
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/tensor/utils.h:6,
                 from
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/tensor/defs.cc:4:
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/tensor/defs.cc: In
lambda function:
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/tensor/defs.cc:1414:22:
warning: unnecessary parentheses in declaration of �i�
[-Wparentheses]
           for (size_t(i) = 0; i < axes.size(); ++i) {
                      ^
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/schema.h:959:12:
note: in definition of macro �ONNX_OPERATOR_SET_SCHEMA_EX�
     return impl.SetName(#name)
\
            ^~~~
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/tensor/defs.cc:1369:1:
note: in expansion of macro �ONNX_OPERATOR_SET_SCHEMA�
 ONNX_OPERATOR_SET_SCHEMA(
```

This commit should fix it and modernize the code a bit.

Co-authored-by: Ke Zhang <kezhan@microsoft.com>

* [Training] Add Adagrad optimizer operator (#1955)

* Adagrad draft

* MIMO

* Support multiple tensors to be optimized

* Address comments

* Move optimizers to a new place

Remove copied

Add momentum

Save

Remove momentum

Fix

Move constants to attributes

* Fix build

* Add shape test

Add two node tests

Update test coverage

* Fix shape inf

* Fix shape inf

* fix shape inf

* Format

* Add function type

* Merge lines

* Format

* Fix version number

* Update op version in model files

* Fix a test function and update related test files

* Update onnx/backend/test/case/node/adagrad.py

* Remove unused file

* sync docs

* Fix shape test

* sync doc

* sync with master

* Update onnx/defs/training/defs.cc

Co-Authored-By: Michał Karzyński <postrational@users.noreply.github.com>

* sync doc

* address comments

* address a minor comment

* Polish one line

Co-authored-by: Michał Karzyński <postrational@users.noreply.github.com>

* [Training] SG with Momentum Optimizer (#1959)

* SG with Momentum

* Registrate Op

Fix

Update other docs

* Add shape inference code and polish definition

* Update docs

* Add test cases and fix several bugs

* Remove accidently added copy

* Alpha -> alpha & Beta -> beta

* Clarify an attribute

* Fix an attribute

* Fix bug

* Fix missing attributes

* sync doc

* Remove unused domain

* sync with master

Co-authored-by: Chin Huang <chhuang@us.ibm.com>

* Change type of label tensor to int32/int64 in SoftmaxCrossEntropyLoss spec. (#2667)

* Update Pow input types in Opset 12 (#2666)

* Update Pow input types in Opset 12

* gen doc and tests

* remove uints and 8 bit ints

* add tests

* remove uint int x tets

* Adding CI for ONNX Debug mode (Linux, OSX) (#2651)

* adding an osx build, linux build, with and without onnx_ml for debug mode

* test debug mode with ONNX_ML=1

* Rename OPTIONAL to OPTIONAL_VALUE (#2682)

Co-authored-by: G. Ramalingam <grama@microsoft.com>

* Update Batchnorm test (#2674)

* Update Batchnorm test

* relax shape inference on scalar

* Remove unnecessary copies and std::move (#2684)

* Update sequence test case so input is not scalar and splits are specified (#2675)

* Update sequence test case to input is not scalar and splits are specified

* Add spaces to make the checker happy

* Use cmake GNUInstallDirs (#2661)

https://cmake.org/cmake/help/latest/module/GNUInstallDirs.html
this make allow install the libraries (and headers) in different location than `lib` (Gentoo uses lib64 for 64-bits libs)
also change the .cmake files for avoid conclicts if build both 32-bis and 64-bits (avoids conflict/overwrite files)

Co-authored-by: Ke Zhang <kezhan@microsoft.com>

* Add 'ignore_index' input in the spec for SoftmaxCrossEntropyLoss and NLLLoss. (#2680)

* Add 'ignore_index' input in the spec for SoftmaxCrossEntropyLoss and NLLLoss.

* Add tests.

* build break.

* build break.

* clean up.

* build break.

* Change ignore_index to attribute.

* Change ignore_index to attribute.

* PR feedback.

* PR feedback.

* Make ignore_index optional in NLLLoss.

* Build break.

* remove trailing spaces to fix build break.

* Build break.

* Update spec doc.

* Fix NLLLoss function definition to fix test: test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_ignore_index_expanded

* PR feedback.

* Fix test for softmax cross entropy loss to exclude ignored_index'ed weights from the sum of weights.

* Build break.

* Reduce binary size of libraries consuming ONNX (part 1/2) (#2643)

* Change the return type for the zipmap operator to match the description in the spec.

* Reduce binary size of libraries consuming ONNX (part 1/2)

* Fix build error

* Replace separate Get*Doc() functions with easy macro for greater convenience

* Add one more macro for complicated operator doc documentation.

Co-authored-by: Ke Zhang <kezhan@microsoft.com>

* Update pybind (#2340) (#2688)

* Change version number for release verification

Change version number for release verification

Co-authored-by: Takeshi Watanabe <take-cheeze@users.noreply.github.com>
Co-authored-by: Ke Zhang <kezhan@microsoft.com>
Co-authored-by: Hong Xu <hong@topbug.net>
Co-authored-by: Wei-Sheng Chin <wschin@outlook.com>
Co-authored-by: Michał Karzyński <postrational@users.noreply.github.com>
Co-authored-by: M. Zeeshan Siddiqui <mzs@microsoft.com>
Co-authored-by: Lara Haidar <haidar.lara@gmail.com>
Co-authored-by: Vinitra Swamy <vinitras@gmail.com>
Co-authored-by: Changming Sun <chasun@microsoft.com>
Co-authored-by: G. Ramalingam <grama@microsoft.com>
Co-authored-by: Changming Sun <me@sunchangming.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>
Co-authored-by: Gustavo Alvarez <462213+sl1pkn07@users.noreply.github.com>
Co-authored-by: Pranav Sharma <prs@microsoft.com>
@chinhuang007 chinhuang007 added this to the 1.7 milestone Mar 31, 2020
jcwchen pushed a commit to jcwchen/onnx that referenced this pull request Sep 23, 2020
* Update Pow input types in Opset 12

* gen doc and tests

* remove uints and 8 bit ints

* add tests

* remove uint int x tets
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants