Fix precision issue of pow(int, float) #6103

Merged

qihqi merged 3 commits into master from qihqi/pow on Dec 15, 2023
Conversation

@qihqi
Collaborator

@qihqi qihqi commented Dec 12, 2023

fixes #5887

@qihqi qihqi changed the title Qihqi/pow Fix precision issue of pow(int, float) Dec 12, 2023
@qihqi qihqi requested a review from JackCaoG December 12, 2023 00:08
Comment thread on torch_xla/csrc/tensor_methods.cpp (outdated)
auto* xla_node = dynamic_cast<XlaNode*>(node.get());
at::ScalarType dtype =
    TorchTypeFromXlaType(xla_node->xla_shape().element_type());
return input->CreateFrom(node, dtype);
Collaborator
I think you can do something similar to input->CreateFrom(node, /*logical_element_type=*/nullptr) to make sure logical_element_type is not inherited from input; it will then default to the xla_shape's element type.

Collaborator Author

done.

@JackCaoG (Collaborator) left a comment

LGTM if CI is green.

@wonjoo-wj (Collaborator) left a comment

LGTM. Unit tests CI should be fixed with a rebase.

@qihqi qihqi force-pushed the qihqi/pow branch 8 times, most recently from 2c8e384 to e2bafc2 on December 13, 2023 21:47
Currently we cast the float scalar to an int32 tensor (to
match input1). This is incorrect: the power is then computed
in int and gives wrong results.

The fixed version casts both operands to float and does the
math in float. The return value is a float tensor instead of
an int tensor.
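The commit message above can be illustrated with a small eager-PyTorch sketch (a hypothetical repro, not code from this PR): eager PyTorch never truncates a float exponent, and XLA should match that behavior after this fix.

```python
import torch

base = torch.tensor([2, 3], dtype=torch.int32)

# Simulating the buggy behavior: the exponent 0.5 cast to int32 becomes 0,
# so pow() returns 1 for every element.
wrong = torch.pow(base, torch.tensor(0.5).to(torch.int32))

# Correct (eager) behavior: the math is done in float and the result is
# promoted to a float tensor, not cast back to the int input's dtype.
fixed = torch.pow(base, 0.5)

print(wrong.tolist())  # [1, 1]
print(fixed.dtype)     # torch.float32
```

Note how the truncated-exponent version silently returns all ones, which is the precision issue #5887 reports.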
@qihqi qihqi merged commit e500129 into master Dec 15, 2023
qihqi added a commit that referenced this pull request Dec 15, 2023
@qihqi qihqi deleted the qihqi/pow branch April 29, 2024 21:18


Development

Successfully merging this pull request may close these issues.

[Core ATen Opset] Lower aten_pow_Tensor_Tensor

3 participants