
Makes floor_divide a method, adds sparse floor division #34552

Closed
mruberry wants to merge 10 commits into master from floor_divide_method

Conversation

mruberry (Collaborator) commented Mar 10, 2020

(Updated per review feedback)

torch.floor_divide is currently a function that can operate on two tensors or a tensor and a scalar (scalar x scalar floor division is handled natively by Python and the JIT has a builtin function for it). This PR updates it to:

  • have an out variant: floor_divide(x, y, out=z)
  • be a method on a tensor: x.floor_divide(y)
  • have an in-place variant: x.floor_divide_(y)
  • work with sparse tensors (see the usage sketch below)
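
The snippet below sketches this new surface area (a minimal sketch; values are illustrative, and the exact supported sparse operand combinations follow the existing sparse division behavior, so they are best confirmed against the tests added in this PR):

```python
import torch

x = torch.tensor([5, 9, 10])
y = torch.tensor([2, 3, 4])

z = torch.empty_like(x)
torch.floor_divide(x, y, out=z)  # out variant: z is tensor([2, 3, 2])
w = x.floor_divide(y)            # method variant: returns a new tensor
x.floor_divide_(y)               # in-place variant: x is modified

# Sparse support (illustrative; mirrors the existing sparse div behavior):
s = torch.tensor([[4., 0.], [0., 9.]]).to_sparse()
torch.floor_divide(s, 2)
```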

Tests are added to test_sparse.py and test_torch.py for these new behaviors.

In addition, this PR:

  • cleans up the existing sparse division and true_division code and improves their error messages
  • adds testing of sparse true_division to test_sparse.py
  • extends existing floor_divide testing in test_torch.py to run on CUDA, too, not just the CPU

Unfortunately, making floor_divide a method requires breaking backwards compatibility, and floor_divide has been added to the BC whitelist since this is intentional. The BC issue is that the first parameter name to torch.floor_divide is changing from input to self. If you previously called torch.floor_divide with keyword arguments, e.g. torch.floor_divide(input=x, other=y), you will need to update to torch.floor_divide(self=x, other=y), or the more common torch.floor_divide(x, y).

The intent of this PR is to allow floor_divide to be substituted for division (torch.div, /) wherever division was previously used. In 1.6 we expect torch.div to perform true_division, and floor_divide is how users can continue to perform integer division with tensors.
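
For example, on integer tensors the two functions give the divergent behaviors described above (a minimal sketch; true_divide promotes to a floating point dtype, while floor_divide preserves the integer dtype):

```python
import torch

t = torch.tensor([5, 9])
torch.true_divide(t, 2)   # tensor([2.5000, 4.5000]) -- true division
torch.floor_divide(t, 2)  # tensor([2, 4]) -- integer division, dtype preserved
```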

There are two potential follow-up issues suggested by this PR:

  • the test framework might benefit from additional tensor construction classes, like one to create dividends and divisors for multiple dtypes
  • the test framework might benefit from a universal function test class. While methods have reasonable coverage as part of test_torch.py's TestTensorOp tests, function coverage is spotty. Universal functions are similar enough that it should be possible to generate tests for them.

dr-ci (Bot) commented Mar 11, 2020

💊 CircleCI build failures summary and remediations

As of commit 4fcf7c1 (more details on the Dr. CI page):


None of the build failures appear to be your fault 💚


  • 1/1 broken upstream at merge base 1afc584 on Mar 18 from 2:35pm to 7:57pm (10 commits; d927d58 - c747f09)

    Please rebase on the viable/strict branch:

    If your commit is newer than viable/strict, you can try basing on an older, stable commit:

    git fetch https://github.com/pytorch/pytorch viable/strict
    git rebase --onto FETCH_HEAD $(git merge-base origin/master HEAD)
    

    If your commit is older than viable/strict:

    git fetch https://github.com/pytorch/pytorch viable/strict
    git rebase FETCH_HEAD
    

    Check out the recency history of this "viable master" tracking branch.


🚧 1 upstream failure:

These were probably caused by upstream breakages:



@mruberry mruberry requested a review from gchanan March 11, 2020 05:12
@gchanan gchanan added the "module: bc-breaking" label Mar 11, 2020
gchanan (Contributor) commented Mar 11, 2020

looks like you have a merge conflict.

gchanan (Contributor) commented Mar 11, 2020

for the BC change, you should add the function to the exception list in this file: https://github.com/pytorch/pytorch/blob/master/test/backward_compatibility/check_backward_compatibility.py
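
For reference, a hypothetical sketch of such an entry; the actual list name and tuple format in check_backward_compatibility.py may differ, so treat this as illustrative only:

```python
import datetime

# Hypothetical excerpt from test/backward_compatibility/check_backward_compatibility.py;
# the real list name and entry format in that file may differ.
white_list = [
    # ...
    ('aten::floor_divide', datetime.date(2020, 4, 1)),
]
```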

gchanan (Contributor) commented Mar 11, 2020

BC note:

The first parameter name to torch.floor_divide changed from input to self. If you called torch.floor_divide with keyword arguments, e.g. torch.floor_divide(input=x, other=y), you need to update to torch.floor_divide(self=x, other=y), or the more common torch.floor_divide(x, y).

mruberry (Collaborator, Author):

> BC note:
>
> The first parameter name to torch.floor_divide changed from input to self. If you called torch.floor_divide with keyword arguments, e.g. torch.floor_divide(input=x, other=y), you need to update to torch.floor_divide(self=x, other=y), or the more common torch.floor_divide(x, y).

Is there some place this note should go?

Comment thread aten/src/ATen/native/native_functions.yaml
Comment thread test/onnx/test_pytorch_onnx_caffe2.py Outdated
Comment thread test/test_torch.py Outdated
Comment thread test/test_torch.py Outdated
Comment thread test/test_torch.py Outdated
Comment thread test/test_type_promotion.py Outdated
Comment thread aten/src/ATen/native/BinaryOps.cpp Outdated
Comment thread aten/src/ATen/native/BinaryOps.cpp Outdated

auto out = iter.output();
if (out.is_floating_point()) {
  return out.trunc();
Contributor:

nit: since this is out-of-place and we internally allocated the output, can't we do trunc_()? I guess it depends on what we do for autograd -- what do we do?

Collaborator (Author):

Yes we can! Great catch!

floor_divide doesn't have an autograd formula and this PR doesn't add one. I don't think it should, either, since we're already sneaking in a few additions (like sparse functionality) beyond the BC-breaking changes. The autograd formula for floor_divide would (based on our current formulas) return an all zeros grad tensor, too, because that's what torch.trunc() does.
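
For illustration, a Python analogy of the suggested change (the actual code is the C++ excerpt above; this sketch only shows the allocation-saving idea):

```python
import torch

x = torch.tensor([5., 9.])
out = torch.true_divide(x, 2)  # freshly allocated intermediate quotient
out.trunc_()                   # truncate in place: no second allocation,
                               # unlike out = out.trunc()
```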

Contributor:

The part I was concerned about was the in-place change triggering a version counter problem (i.e., that the version counter gets incremented so we can't backprop through floor_divide, even to get the tensor of all zeros). This would definitely not be a problem if we had a formula for floor_divide (that just set all zeros) because the output would be totally internal, but I'm unsure if it's a problem if we just rely on the autograd formulas for the underlying implementations.

Did you check if you can autograd through this version of floor_divide (to get all zeros)?

Contributor:

I also realize I contradicted myself a bit in saying you should minimize the changes here and then asking why there's no sparse implementation :).

Collaborator (Author):

You cannot backprop through floor_divide, you'll get an error when trying:

RuntimeError: derivative for floor_divide is not implemented

We could make it return all zeros in this PR (analogous to trunc). Would that address your concerns? It's a simple change.

(No worries.)
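
A minimal repro of the behavior described above (a hedged sketch; the error is the one quoted):

```python
import torch

x = torch.tensor([5., 9.], requires_grad=True)
y = torch.floor_divide(x, 2)  # attempting to differentiate through this fails:
y.sum().backward()
# RuntimeError: derivative for floor_divide is not implemented
```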

Comment thread test/test_jit.py Outdated
gchanan (Contributor) left a comment:

looks pretty reasonable, mostly minor changes suggested. But I think this should really be split up into a minimal BC-breaking change.

Aside: does "//" do the right thing? I assume it does, but just wanted to double check.

@mruberry mruberry force-pushed the floor_divide_method branch from 897a8ac to 396b238 on March 12, 2020 21:32
@mruberry mruberry changed the title from "Makes floor_divide a method, uses floor_divide for integer division" to "Makes floor_divide a method, adds sparse floor division" Mar 12, 2020
mruberry (Collaborator, Author):

> Aside: does "//" do the right thing? I assume it does, but just wanted to double check.

Yes, and it's also used in the new tests I wrote.

Thanks for the excellent review! The new PR is much leaner and more focused. It also improves testing, adds sparse functionality, and elaborates on the changes in the description.

@mruberry mruberry requested a review from gchanan March 12, 2020 22:30
  result.resize_as_(dividend_tmp);
  auto indices = result._indices();
- indices.resize_as_(dividend_tmp.indices());
+ indices.resize_as_(dividend_tmp._indices());
Contributor:

is this an efficiency change? (so you don't have to create a view?).

Collaborator (Author):

Yes. This was actually a typo in the true_divide PR. The sparse div implementation consistently uses _indices(). The difference between indices() and _indices() is documented in the "Sparse Methods API Design" note.
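
A small sketch of the distinction (assuming standard COO construction; the duplicate entry keeps the tensor uncoalesced):

```python
import torch

# Two entries at the same position (0, 1): the tensor starts uncoalesced.
s = torch.sparse_coo_tensor([[0, 0], [1, 1]], [1., 2.], (2, 2))
s._indices()            # raw indices; works even on uncoalesced tensors
s.coalesce().indices()  # public accessor; requires a coalesced tensor
```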


// Resizes and indexes result like dividend_tmp
result.resize_as_(dividend_tmp);
result._indices().resize_as_(dividend_tmp._indices());
Contributor:

this also makes me nervous about autograd for the same reason as the in-place change above.

Collaborator (Author):

You can't backward through floor_divide. We may want to address that as a separate issue, although I don't think there's a non-degenerate gradient formula for it.

Contributor:

@albanD what do you think is the right thing to do here?

In any case, since this isn't changing, this isn't a blocking concern for this PR. But let's file an issue if Alban suggests we do something different for consistency reasons.

Comment thread test/test_sparse.py
self.assertEqual(self.safeToDense(y2), expected)

# Note: true_divide does not have a method variant
y1 = torch.true_divide(x1, 37.5)
Contributor:

wait, why did we add a method for floor_divide but not true_divide?

Collaborator (Author):

Both true_divide and floor_divide are only functions, not methods, in NumPy. I think you actually brought this up when floor_divide was first being implemented. When true_divide was implemented it respected the same convention.

When we look forward to deprecating integer division using torch.div, however, we don't want users to lose functionality. In particular, we want to provide a mechanism to perform in-place integer division, hence Tensor.floor_divide_.

We could add Tensor.true_divide for consistency. We decided at the time that it was mostly moot since in the near future Tensor.div would be Tensor.true_divide, anyway.
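
A hedged sketch of the in-place use case this preserves (values illustrative):

```python
import torch

a = torch.tensor([5, 9])
a.floor_divide_(2)  # tensor([2, 4]); the result stays int64, so in-place works
# An in-place *true* division on an integer tensor would have no way to store
# the floating point result, which is why floor_divide_ fills this role.
```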

Collaborator (Author):

In the "future_div" PR I'll make true_divide a method.

Collaborator (Author):

See #34794

Comment thread test/test_torch.py Outdated
Comment thread test/test_torch.py Outdated
# (e.g. .9999 vs 1 post truncation is 0 vs 1)
('floor_divide', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1, 1, 1, _types),
('floor_divide', 'tensor', _small_3d,
    lambda t, d: [_small_3d(t, d, has_zeros=False)], 1, 1, 1, _types),
Contributor:

what are these actually comparing? CPU vs CUDA results? Those should be the same, right?

Collaborator (Author):

These tests do compare CPU and CUDA results; "should" is correct. If you expect division on the CPU, CUDA, and TPU to be bitwise equivalent then you wouldn't have an issue. In practice we could set the precision to our standard 1e-5 and it probably would never bother us, either. But if division isn't bitwise equivalent and the error is such that it could change 0.9999 to 1., for example, then the test would fail when that occurred.
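
A toy illustration of that failure mode (pure Python; values contrived):

```python
import math

math.trunc(0.9999999)  # 0
math.trunc(1.0)        # 1 -- a tiny rounding difference in the quotient flips the result
```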

Contributor:

I would optimistically assume they are bitwise equivalent and, if things start failing, back out our assumption.

Collaborator (Author):

Updated the precision. Turns out we still need a 1 for torch.half.

@mruberry mruberry requested a review from gchanan March 14, 2020 02:56
Comment thread aten/src/ATen/native/sparse/SparseTensorMath.cpp
Comment thread test/test_torch.py Outdated
gchanan (Contributor) left a comment:

I think all of the suggestions are pretty straightforward at this point, so no need for a re-review.

facebook-github-bot left a comment:

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@mruberry merged this pull request in b712905.

mruberry (Collaborator, Author):

Ninja unlanding, as a persistent Windows build issue occurred at the same time. That issue is not reflected in this PR's CI, suggesting a merge conflict or exogenous change.

@mruberry mruberry reopened this Mar 18, 2020
facebook-github-bot left a comment:

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot left a comment:

@mruberry is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot pushed a commit that referenced this pull request Mar 19, 2020
Summary:
Per title.

In the future we want to make div(), the division operator, and addcdiv perform true division as in Python 3, NumPy, and JAX. To do this without silently breaking users we plan to:

- Warn (once) in 1.5 when a user performs integer division using div or addcdiv
- RuntimeError in 1.6 when a user attempts to perform integer division using div or addcdiv
- Always perform true division in 1.7 using div, /, and addcdiv

Users can use true_divide or floor_divide today to explicitly specify the type of division they'd like.

A test for this behavior is added to test_type_promotion. Unfortunately, because we are only warning once (to avoid a deluge) the test only uses maybeWarnsRegex.

The XLA failure is real but will be solved by #34552. I'll be sure to land that PR first to avoid temporarily breaking the XLA build.
Pull Request resolved: #34570

Differential Revision: D20529211

Pulled By: mruberry

fbshipit-source-id: 65af5a9641c5825175d029e8413c9e1730c661d0
@mruberry mruberry deleted the floor_divide_method branch March 29, 2020 07:50