# Enable `in_dims` for vmap frontend api (#40717)

zou3519 wants to merge 3 commits into `gh/zou3519/269/base`.

## Conversation
`in_dims` specifies which dimension of the input tensors should be vmapped over. One can also specify `None` as an `in_dim` for a particular input to indicate that we do not map over said input. We implement `in_dims` by creating a BatchedTensor with BatchDim equal to said `in_dim`.

Most of this PR is error checking. `in_dims` must satisfy the following:
- `in_dims` can be either an int or a Tuple[Optional[int]]. If it is an int, we use it to mean the `in_dim` for every input.
- If `in_dims` is not-None at some index `idx`, then the input at index `idx` MUST be a tensor (vmap can only map over tensors).

jax supports something more generalized: their `in_dims` can match the structure of the `inputs` to the function (i.e., it is a nested python data structure matching the data structure of `inputs`, specifying where in `inputs` the Tensors to be mapped are and what their map dims should be). We don't have the infrastructure yet, so we only support `int` or a flat tuple for `in_dims`.

Test Plan:
- `pytest test/test_vmap.py -v`
```python
            fn=fn_name, in_dims=in_dims, num_inputs=len(args)))

    if len(args) == 0:
        raise ValueError(NO_INPUTS.format(fn=fn_name))
```
Is there any reason to block the zero length args case, besides "then vmap doesn't do anything"? I'm thinking of how people have found it useful to do zero-size batches; it may be harmless to have zero length args (unless it is not?)
The only two reasons I have are: (1) "then vmap doesn't do anything" and (2) "jax doesn't allow it". I agree that it seems harmless to have zero-length arguments.
It's not too hard to modify this to work, so I'll add this as a follow-up for later (and think more about whether it is actually harmless or not).
```python
EXPECTED_IN_DIMS_TO_BE_INT_OR_TUPLE = (
    'vmap({fn}, in_dims={in_dims}, ...): expected `in_dims` to be int or tuple, '
    'got: {actual_type}.'
)
```
No action needed on this comment: I personally prefer having messages inline at their use sites, if they're only used once. Makes it easier to see what the error message is and ensure that the format string is up to date :) (also, you can't use f-strings in this style!)
Did we drop support for Python < 3.6? (I know we dropped support for Python 2, but I didn't realize the < 3.6 part)
I agree with your comment, reading 66 lines of error messages at the top of the file and away from the callsites makes me sad. Will fix in a follow-up.
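For context on the f-string remark: a module-level template must defer interpolation via `str.format`, because an f-string evaluates its placeholders immediately at definition time, when names like `fn` are not yet bound. A small illustration (names are illustrative, not the actual code):

```python
# Module-level template: placeholders are filled in later via .format().
EXPECTED_INT_OR_TUPLE = (
    'vmap({fn}, in_dims={in_dims}, ...): expected `in_dims` to be int or tuple, '
    'got: {actual_type}.'
)

def make_error(fn_name, in_dims):
    # Deferred formatting works because .format() runs at the call site.
    return EXPECTED_INT_OR_TUPLE.format(
        fn=fn_name, in_dims=in_dims, actual_type=type(in_dims))

def make_error_inline(fn_name, in_dims):
    # The inline style the reviewer prefers: an f-string at the use site.
    return (f'vmap({fn_name}, in_dims={in_dims}, ...): expected `in_dims` '
            f'to be int or tuple, got: {type(in_dims)}.')
```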
```python
# Check compatibility of `in_dims` and `args`. More specifically, checks the following:
# Wherever an in_dim is not None, then the corresponding index in args must be
# a Tensor. Furthermore, tensor must have the `in_dim` (0 <= in_dim < tensor.dim())
```
Type signature on this function (and the others) would be very helpful!
I added type hints to all of the functions :D. We'll have to relax some of these in the future when we support accepting arbitrary nested python data structures, but these do make the code easier to read now.
```python
# Wherever an in_dim is not None, then the corresponding index in args must be
# a Tensor. Furthermore, tensor must have the `in_dim` (0 <= in_dim < tensor.dim())
def _check_args_can_be_mapped_with_in_dims(in_dims_as_tuple, args, fn_name, in_dims):
    for idx, (in_dim, arg) in enumerate(zip(in_dims_as_tuple, args)):
```
If you extend this to work on arbitrary Python collections as opposed to just tuples, zipping here isn't going to work anymore, right? Would we expect in-dims to also have the same "shape" as args, in this case?
That's correct. Extending this to work on arbitrary Python collections would make it so that we need new error validation code here. Furthermore, we'd expect in_dims to have the same "shape" as args.
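A sketch (stdlib only, with a hypothetical `tree_map2` helper) of what matching a nested `in_dims` against same-shaped `inputs` could look like, in the spirit of the jax behavior described above:

```python
def tree_map2(fn, in_dims, args):
    # Recurse on the structure of `args`; `in_dims` must mirror it exactly.
    if isinstance(args, (list, tuple)):
        if not isinstance(in_dims, type(args)) or len(in_dims) != len(args):
            raise ValueError('in_dims must match the structure of inputs')
        return type(args)(tree_map2(fn, d, a) for d, a in zip(in_dims, args))
    # Leaf: apply fn to the (in_dim, input) pair.
    return fn(in_dims, args)
```

Under this scheme, a `zip` at each level of the nesting replaces the single flat `zip` in the current code.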
Summary: Pull Request resolved: pytorch#40717

Differential Revision: D22397914
Pulled By: zou3519
fbshipit-source-id: 56d2e14be8b6024e4cde2729eff384da305b4ea3
This reverts commit 5ff9f58.
Stack from ghstack:
- #40717 Enable `in_dims` for vmap frontend api

Test Plan:
- `pytest test/test_vmap.py -v`

Differential Revision: D22397914