Implement torch.broadcast_tensors #10075
Conversation
Nice!

Yes, let's match numpy and do that.
Should it be named "broadcast_arrays" or "broadcast_tensors"? @fmassa and I were thinking broadcast_tensors because we call our data "tensors" but numpy calls their data "arrays". I'll also try to implement varargs for this via a python wrapper function; it shouldn't be too bad.
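For reference, the varargs plumbing on the Python side can stay very small. A minimal sketch of the pattern (illustrative only; the shape logic below is a stand-in for the native expand_outplace call, not the PR's implementation):

```python
import torch

def broadcast_tensors(*tensors):
    # Sketch of a varargs wrapper: compute the common broadcast shape in
    # Python, then expand every input to it. The real op does this in
    # native code; only the *args handling is the point here.
    ndim = max(t.dim() for t in tensors)
    shape = [1] * ndim
    for t in tensors:
        for i, size in enumerate(reversed(t.shape)):
            j = ndim - 1 - i
            if size != 1:
                if shape[j] not in (1, size):
                    raise RuntimeError("shape mismatch: tensors are not broadcastable")
                shape[j] = size
    return tuple(t.expand(*shape) for t in tensors)

a, b = torch.randn(3, 1), torch.randn(4)
x, y = broadcast_tensors(a, b)
print(x.shape, y.shape)  # torch.Size([3, 4]) torch.Size([3, 4])
```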
Nice! The additional behavior of `broadcast_all` could then be implemented as:

```python
# in torch/distributions/utils.py
def broadcast_all(*values):
    """docstring"""
    if not all(map(torch.is_tensor, values)):
        # promote floats to tensors
        new_tensor = torch.tensor
        for value in values:
            if torch.is_tensor(value):
                new_tensor = value.new_tensor
                break
        values = [v if torch.is_tensor(v) else new_tensor(v) for v in values]
    return torch.broadcast_arrays(*values)
```

This would also be a great test to see that distributions are compatible with the new version 😄
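For a sense of what that buys the distributions code, a small illustration (assuming the promote-then-broadcast behavior sketched above, which is how `broadcast_all` in torch.distributions.utils behaves):

```python
import torch
from torch.distributions.utils import broadcast_all

# A Python number mixed with a (2, 3) tensor: the number is promoted to a
# tensor and both arguments come back broadcast to a common shape.
loc, scale = broadcast_all(0.0, torch.ones(2, 3))
print(loc.shape, scale.shape)  # torch.Size([2, 3]) torch.Size([2, 3])

# This is what lets e.g. Normal(0.0, torch.ones(2, 3)) work without the
# caller expanding anything by hand.
dist = torch.distributions.Normal(0.0, torch.ones(2, 3))
print(dist.batch_shape)  # torch.Size([2, 3])
```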
Or even put the scalar broadcasting in …
@vadimkantorov I think there was a discussion around using …
This exposes expand_outplace to python. Fixes pytorch#8076. Fixes pytorch#10041. I didn't name it torch.broadcast because numpy.broadcast does something slightly different (it returns an object with the correct shape information).
Test Plan: new test_torch, test_autograd tests.
- s/broadcast_all/broadcast_tensors/
- broadcast_tensors now takes varargs
Force-pushed from 2909f12 to c5e418b.
This should be good for review, despite the hanging tests. I updated the following:
torch/distributions/utils.py (outdated excerpt; the attached review comments were marked as off-topic):

```
return values
if not all(map(torch.is_tensor, values)):
# promote numbers to tensors of dtype torch.get_default_dtype()
def default_promotion(v):
```
torch/distributions/utils.py (outdated excerpt; the attached review comments were marked as off-topic):

```
scalar_idxs = [i for i in range(len(values)) if isinstance(values[i], Number)]
tensor_idxs = [i for i in range(len(values)) if values[i].__class__.__name__ == 'Tensor']
if len(scalar_idxs) + len(tensor_idxs) != len(values):
if not all(map(lambda v: torch.is_tensor(v) or isinstance(v, Number), values)):
```
facebook-github-bot left a comment:
zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: This exposes expand_outplace to python. Fixes #8076. Fixes #10041. I didn't name it torch.broadcast because numpy.broadcast does something slightly different (it returns an object with the correct shape information).
Pull Request resolved: pytorch/pytorch#10075
Differential Revision: D9125816
Pulled By: zou3519
fbshipit-source-id: ebe17c8bb54a73ec84b8f76ce14aff3e9c56f4d1
This change is great. How can I start to use it? It does not appear to have landed in master yet. EDIT: Sorry, I was pointing to a fork 😊
It should be on master; the following works for me on a latest checkout:
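(The snippet itself didn't survive the copy; a minimal example of the kind of call that works once the PR is on master — illustrative, not the original paste:)

```python
import torch

a = torch.randn(3, 1)
b = torch.randn(1, 4)
x, y = torch.broadcast_tensors(a, b)
print(x.shape, y.shape)  # torch.Size([3, 4]) torch.Size([3, 4])
```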
Summary: This uses zou3519's new `torch.broadcast_tensors()` #10075 to make `Categorical.log_prob()` and the `*Normal.__init__()` methods jittable. Previously `.log_prob()` was failing due to calls to `torch._C.infer_size()` with errors like

```
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
>       value_shape = torch._C._infer_size(value.size(), self.batch_shape) if self.batch_shape else value.size()
E       RuntimeError: expected int at position 0, but got: Tensor
```

After this change I'm able to jit many more of Pyro's tests.
Reviewed By: ezyang
Differential Revision: D9477487
Pulled By: apaszke
fbshipit-source-id: 5f39b29c6b8fa606ad30b02fefe2dfb618e883d6
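A rough sketch of the pattern that commit moves to (illustrative; not the exact diff): rather than computing a broadcast shape with `torch._C._infer_size` and expanding by hand, the value and the distribution parameters are broadcast together with `torch.broadcast_tensors`, which keeps everything as tensor ops the jit can handle:

```python
import math
import torch

def normal_log_prob(value, loc, scale):
    # Broadcast value and parameters to a common shape up front, instead of
    # calling torch._C._infer_size(value.size(), batch_shape) and expanding.
    value, loc, scale = torch.broadcast_tensors(value, loc, scale)
    var = scale.pow(2)
    return -(value - loc).pow(2) / (2 * var) - scale.log() - 0.5 * math.log(2 * math.pi)

lp = normal_log_prob(torch.zeros(5), torch.zeros(1), torch.ones(1))
print(lp.shape)  # torch.Size([5])
```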
This exposes expand_outplace to python. Fixes #8076. Fixes #10041.
I didn't name it torch.broadcast because numpy.broadcast does something
slightly different (it returns an object with the correct shape
information).
Test Plan: new test_torch, test_autograd tests.
cc @soumith @fritzo
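For reference, the numpy distinction the name follows: `np.broadcast` describes the broadcast, while `np.broadcast_arrays` (like `torch.broadcast_tensors`) returns the expanded arrays. A small illustration, not part of the PR:

```python
import numpy as np

a = np.ones((3, 1))
b = np.ones((1, 4))

# numpy.broadcast returns an object carrying the broadcast shape and an
# iterator, not expanded arrays.
print(np.broadcast(a, b).shape)   # (3, 4)

# numpy.broadcast_arrays returns the expanded arrays themselves, which is
# the behavior torch.broadcast_tensors mirrors.
x, y = np.broadcast_arrays(a, b)
print(x.shape, y.shape)           # (3, 4) (3, 4)
```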