Skip to content

Fix overflow for div arguments.#7081

Merged
qihqi merged 1 commit intomasterfrom
ysiraichi/fix-div-overflow
May 20, 2024
Merged

Fix overflow for div arguments.#7081
qihqi merged 1 commit intomasterfrom
ysiraichi/fix-div-overflow

Conversation

@ysiraichi
Copy link
Copy Markdown
Collaborator

This PR fixes the div(Tensor, Scalar) operation implementation.

Problem: consider div(tensor(..., dtype=half), 1_000_000)

  • GetIrValueForScalar will attempt to convert the scalar into a tensor of dtype=half
  • Fails because 1_000_000 is beyond half max value

Solution: use another type for these mathematical operations

  • PyTorch makes use of at::OpMathType trait
  • Cast the arguments to that type, and then cast the result back

Affected Benchmarks

  • (non-dynamo) Super_SloMo training

cc @miladm @JackCaoG

@qihqi qihqi merged commit a2540ac into master May 20, 2024
@qihqi qihqi deleted the ysiraichi/fix-div-overflow branch May 20, 2024 17:20
Comment thread torch_xla/csrc/tensor_methods.cpp
zpcore pushed a commit that referenced this pull request May 20, 2024
@vanbasten23
Copy link
Copy Markdown
Collaborator

I wonder what error did you see before this fix.

@ysiraichi
Copy link
Copy Markdown
Collaborator Author

Traceback (most recent call last):
  File "xla/benchmarks/experiment_runner.py", line 945, in <module>
    main()
  File "xla/benchmarks/experiment_runner.py", line 941, in main
    runner.run()
  File "xla/benchmarks/experiment_runner.py", line 61, in run
    self.run_single_config()
  File "xla/benchmarks/experiment_runner.py", line 256, in run_single_config
    metrics, last_output = self.run_once_and_gather_metrics(
  File "xla/benchmarks/experiment_runner.py", line 345, in run_once_and_gather_metrics
    output, _ = loop(iter_fn=self._default_iter_fn)
  File "xla/benchmarks/experiment_runner.py", line 302, in loop
    output, timing, trace = iter_fn(benchmark_experiment, benchmark_model,
  File "xla/benchmarks/experiment_runner.py", line 218, in _default_iter_fn
    output = benchmark_model.model_iter_fn(
  File "xla/benchmarks/torchbench_model.py", line 411, in train
    super().train(inputs, collect_full_output=collect_full_output)
  File "xla/benchmarks/benchmark_model.py", line 160, in train
    loss.backward()
  File "torch/_tensor.py", line 523, in backward
    torch.autograd.backward(
  File "torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "torch/autograd/graph.py", line 767, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: value cannot be converted to type at::Half without overflow

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants