
mul: convert inputs to result type. #7130

Merged

ysiraichi merged 3 commits into master from ysiraichi/fix-mul-dtype-promotion
Jun 3, 2024

Conversation

@ysiraichi
Collaborator

Fix: #7084

This PR fixes a data-type promotion problem in the mul operation. It does so by introducing a structure, OpConfig, that behaves similarly to DoBinaryOp. The difference is that OpConfig takes care of pre/post-processing of inputs and outputs, casting them to the correct data-type.

Problem

t = torch.rand(10, dtype=torch.half).to(xm.xla_device())
s = torch.tensor(10, dtype=torch.double).to(xm.xla_device())
out = t.mul_(s)
  • Tensor.mul_ is dispatched to its CompositeExplicitAutograd kernel
    • It wraps the scalar into a tensor and calls torch.mul (the functional version)
  • DoBinaryOp is called
    • It computes at::result_type (let's call it common_dtype) and passes it on to bin_op
    • Note that UnwrapNumber does nothing, since s is a tensor with is_wrapped_number_ unset
  • The computed common_dtype is passed on to tensor_methods::mul
    • It creates an IR node with data-type common_dtype
    • It does nothing with its inputs
  • Later, when BuildMul is called, we have two XlaOps with different data-types: f16 and f64
    • BuildMul promotes f16 to f64
  • The output's advertised dtype is common_dtype (torch.float16), but the actual XlaOp is f64
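The mismatch in the last step can be observed on plain CPU tensors, without an XLA device. This is a sketch: torch.result_type is the Python binding of at::result_type, and torch.promote_types stands in for the pure dtype-lattice promotion that BuildMul effectively performs on the XlaOps.

```python
import torch

t = torch.rand(10, dtype=torch.half)       # dimensioned f16 tensor
s = torch.tensor(10, dtype=torch.double)   # 0-dim f64 tensor

# at::result_type treats the 0-dim tensor like a scalar of the same
# kind, so the common dtype stays float16 ...
common_dtype = torch.result_type(t, s)
print(common_dtype)  # torch.float16

# ... while promoting the raw operand dtypes, the way BuildMul does
# at the XlaOp level, picks float64 instead.
promoted = torch.promote_types(t.dtype, s.dtype)
print(promoted)  # torch.float64
```

So the IR node advertises torch.float16 while the lowered computation produces f64, which is exactly the inconsistency described above.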

Solution

Following PyTorch behavior [1, 2, 3], I created OpConfig: a structure that lets us specify common pre/post-processing on inputs and outputs.
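The actual OpConfig lives in C++, but the idea can be sketched in Python (all names here are hypothetical, not the PR's API): compute the common dtype once, cast every input to it before building the op, and report that same dtype for the output, so the lowered computation and the advertised result type can no longer diverge.

```python
import torch

def binary_op_with_casts(op, a, b):
    """Hypothetical sketch of OpConfig's pre/post-processing.

    Inputs are cast to a single common dtype before the op runs, so
    the underlying computation sees matching dtypes and no hidden
    promotion (like BuildMul's f16 -> f64) can occur.
    """
    common = torch.result_type(a, b)
    # Pre-process: cast every input to the common dtype.
    a, b = a.to(common), b.to(common)
    out = op(a, b)
    # Post-process: the output dtype matches what result_type reported.
    return out.to(common)

t = torch.rand(10, dtype=torch.half)
s = torch.tensor(10, dtype=torch.double)
print(binary_op_with_casts(torch.mul, t, s).dtype)  # torch.float16
```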

Affected Models

  • timm_nfnet (training+nondynamo)

cc @miladm @JackCaoG @lezcano

@ysiraichi ysiraichi requested review from JackCaoG and lezcano May 28, 2024 23:34
@JackCaoG
Collaborator

Hmm, I am surprised that bf16 and f64's at::result_type is f64.

Comment thread torch_xla/csrc/init_python_bindings.cpp
Comment thread torch_xla/csrc/aten_xla_type.cpp Outdated
Comment thread torch_xla/csrc/aten_xla_type.cpp
Comment thread torch_xla/csrc/aten_xla_type.cpp
@ysiraichi
Collaborator Author

> Hmm, I am surprised that bf16 and f64's at::result_type is f64.

This is a bit confusing, so let me try and clarify the cases, given op(bf16, f64):

| bf16   | f64    | at::result_type | PromoteType |
|--------|--------|-----------------|-------------|
| tensor | scalar | bf16            | f64         |
| scalar | tensor | f64             | f64         |
| tensor | tensor | f64             | f64         |
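The at::result_type column can be checked on plain CPU tensors. In this sketch the "scalar" rows use 0-dim tensors, which PyTorch's promotion treats like scalars: they only promote a dimensioned tensor when they are of a higher kind, and bf16 and f64 are both floating-point.

```python
import torch

bf16_t  = torch.rand(3, dtype=torch.bfloat16)      # dimensioned bf16 tensor
f64_t   = torch.rand(3, dtype=torch.float64)       # dimensioned f64 tensor
bf16_0d = torch.tensor(1.0, dtype=torch.bfloat16)  # 0-dim bf16 ("scalar")
f64_0d  = torch.tensor(1.0, dtype=torch.float64)   # 0-dim f64  ("scalar")

print(torch.result_type(bf16_t, f64_0d))   # torch.bfloat16 (tensor, scalar)
print(torch.result_type(bf16_0d, f64_t))   # torch.float64  (scalar, tensor)
print(torch.result_type(bf16_t, f64_t))    # torch.float64  (tensor, tensor)
```

The PromoteType column is f64 in every row because it looks only at the two dtypes, ignoring whether an operand is a scalar: torch.promote_types(torch.bfloat16, torch.float64) is torch.float64.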

@ysiraichi ysiraichi force-pushed the ysiraichi/fix-mul-dtype-promotion branch from 7221e69 to 7e28074 on May 29, 2024 15:39
@lezcano
Collaborator

lezcano commented May 29, 2024

The tl;dr is: Scalars don't promote unless they are of a different kind.
Here are the exact rules: https://pytorch.org/docs/stable/tensor_attributes.html
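The "different kind" rule is easy to demonstrate in a quick sketch: an integral scalar on an integral tensor does not promote, while a floating scalar on an integral tensor promotes to the default floating dtype.

```python
import torch

int_t = torch.ones(3, dtype=torch.int32)

# Same kind (integral scalar, integral tensor): no promotion.
print(torch.result_type(int_t, 5))    # torch.int32

# Different kind (floating scalar, integral tensor): promotes to the
# default floating dtype (torch.float32 unless changed).
print(torch.result_type(int_t, 5.0))  # torch.float32
```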

@ysiraichi ysiraichi force-pushed the ysiraichi/fix-mul-dtype-promotion branch from 7e28074 to a1b67f2 on May 29, 2024 20:36
@bhavya01
Collaborator

@ysiraichi This PR causes the mul operation to fail on TPUs.

It fails this check https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L161

>>> import torch
>>> import torch_xla
>>> x = torch.tensor([1,2,3]).to('xla')
>>> y = torch.tensor([2,4,5]).to('xla')
>>> x*y
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: torch_xla/csrc/aten_xla_type.cpp:161 : Check failed: it != inputs_.end() 
*** Begin stack trace ***
	tsl::CurrentStackTrace[abi:cxx11]()
	torch_xla::XLANativeFunctions::mul(at::Tensor const&, at::Tensor const&)
	
	at::_ops::mul_Tensor::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&)
	
	
	at::_ops::mul_Tensor::call(at::Tensor const&, at::Tensor const&)
	
	PyNumber_Multiply
	_PyEval_EvalFrameDefault
	
	PyEval_EvalCode
	
	_PyRun_InteractiveLoopObject
	
	PyRun_AnyFileExFlags
	
	Py_BytesMain
	
	__libc_start_main
	
*** End stack trace ***

@ysiraichi
Collaborator Author

Apparently, you are not the only one: #7266

@vanbasten23
Collaborator

This PR is also impacting DDP:
[image: screenshot of the DDP failure]

@JackCaoG
Collaborator

Let me take a look this afternoon

@JackCaoG
Collaborator

I can't repro this issue, but I do see that half of our internal TPU tests crashed because of it. Let me revert this PR for now while we figure out what happened.


Successfully merging this pull request may close these issues.

[torchbench] timm_nfnet training failing on non-dynamo.
