
Enable in_dims for vmap frontend api #40717

Closed
zou3519 wants to merge 3 commits into gh/zou3519/269/base from gh/zou3519/269/head

Conversation

Contributor

zou3519 commented Jun 29, 2020

Stack from ghstack:

in_dims specifies which dimension of the input tensors should be
vmapped over. One can also specify None as an in_dim for a particular
input to indicate that we do not map over said input.

We implement in_dims by creating a BatchedTensor with BatchDim equal
to said in_dim. Most of this PR is error checking. in_dims must
satisfy the following:

  • in_dims can be either an int or a Tuple[Optional[int]]. If it is an
    int, we use it as the in_dim for every input.
  • If in_dims is not None at some index idx, then the input at index
    idx MUST be a tensor (vmap can only map over tensors).

jax supports something more general: its in_dims can match the
structure of the inputs to the function (i.e., it is a nested Python
data structure, matching the data structure of the inputs, that
specifies where in the inputs the Tensors to be mapped are and what
their map dims should be). We don't have that infrastructure yet, so we
only support an int or a flat tuple for in_dims.

Test Plan:

  • pytest test/test_vmap.py -v

Differential Revision: D22397914

zou3519 added a commit that referenced this pull request Jun 29, 2020
ghstack-source-id: f77f2dc
Pull Request resolved: #40717

dr-ci Bot commented Jun 29, 2020

💊 CI failures summary and remediations

As of commit 7259721 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

zou3519 requested review from cpuhrsch and ezyang June 29, 2020 21:19
Comment thread torch/_vmap_internals.py
            fn=fn_name, in_dims=in_dims, num_inputs=len(args)))

    if len(args) == 0:
        raise ValueError(NO_INPUTS.format(fn=fn_name))
Contributor

Is there any reason to block the zero-length args case, besides "then vmap doesn't do anything"? I'm thinking of how people have found it useful to do zero-size batches; it may be harmless to have zero-length args (unless it is not?)

Contributor Author

The only two reasons I have are: (1) "then vmap doesn't do anything" and (2) "jax doesn't allow it". I agree that it seems harmless to have zero-length arguments.

It's not too hard to modify this to work, so I'll add it as a follow-up for later (and think more about whether it is actually harmless or not).

Comment thread torch/_vmap_internals.py
EXPECTED_IN_DIMS_TO_BE_INT_OR_TUPLE = (
'vmap({fn}, in_dims={in_dims}, ...): expected `in_dims` to be int or tuple, '
'got: {actual_type}.'
)
Contributor

No action needed on this comment: I personally prefer having messages inline at their use sites, if they're only used once. Makes it easier to see what the error message is and ensure that the format string is up to date :) (also, you can't use f-strings in this style!)

Contributor Author

Did we drop support for Python < 3.6? (I know we dropped support for Python 2, but I didn't realize the < 3.6 part)

I agree with your comment, reading 66 lines of error messages at the top of the file and away from the callsites makes me sad. Will fix in a follow-up.

Comment thread torch/_vmap_internals.py

# Check compatibility of `in_dims` and `args`. More specifically, checks the following:
# Wherever an in_dim is not None, then the corresponding index in args must be
# a Tensor. Furthermore, tensor must have the `in_dim` (0 <= in_dim < tensor.dim())
Contributor

Type signature on this function (and the others) would be very helpful!

Contributor Author

RIP Python 2!!!

Contributor Author

I added type hints to all of the functions :D. We'll have to relax some of these in the future when we support accepting arbitrary nested python data structures, but these do make the code easier to read now.

Comment thread torch/_vmap_internals.py
# Wherever an in_dim is not None, then the corresponding index in args must be
# a Tensor. Furthermore, tensor must have the `in_dim` (0 <= in_dim < tensor.dim())
def _check_args_can_be_mapped_with_in_dims(in_dims_as_tuple, args, fn_name, in_dims):
    for idx, (in_dim, arg) in enumerate(zip(in_dims_as_tuple, args)):
Contributor

If you extend this to work on arbitrary Python collections as opposed to just tuples, zipping here isn't going to work anymore, right? Would we expect in_dims to also have the same "shape" as args, in this case?

Contributor Author

That's correct. Extending this to work on arbitrary Python collections would make it so that we need new error validation code here. Furthermore, we'd expect in_dims to have the same "shape" as args.

zou3519 added a commit that referenced this pull request Jul 6, 2020
ghstack-source-id: 36d06d9
Pull Request resolved: #40717
@facebook-github-bot

@zou3519 merged this pull request in 5d1d8a5.

csarofeen pushed a commit to csarofeen/pytorch that referenced this pull request Jul 7, 2020
Summary:
Pull Request resolved: pytorch#40717

Differential Revision: D22397914

Pulled By: zou3519

fbshipit-source-id: 56d2e14be8b6024e4cde2729eff384da305b4ea3
facebook-github-bot deleted the gh/zou3519/269/head branch July 10, 2020 14:18
csarofeen added a commit to csarofeen/pytorch that referenced this pull request Aug 16, 2020
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026

4 participants