[JIT] Support a single TensorList argument anywhere in the argument list#8141
[JIT] Support a single TensorList argument anywhere in the argument list#8141jamesr66a wants to merge 5 commits intopytorch:masterfrom
Conversation
cmake/Modules/FindMKL.cmake
Outdated
| SET(MKL_VERSION) | ||
| ENDIF (MKL_LIBRARIES) | ||
|
|
||
| set(MKL_FIND_REQUIRED Off) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| if sum(arg['simple_type'] == 'TensorList' for arg in arguments) > 1: | ||
| return False | ||
| # omit einsum | ||
| if any(arg['simple_type'] == 'std::string' for arg in arguments): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
| real_inputs = count() | ||
| if has_tensorlist: | ||
| kw_assignments.append('size_t _Arg_idx = 0;') |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| # the first argument, that is then followed by a number of positional args. | ||
| if arg['simple_type'] == 'TensorList': | ||
| arguments.append('peekSlice(stack, 0, varargs_length - {}, varargs_length)'.format(static_inputs)) | ||
| kw_assignments.append('size_t _Arg_idx_{} = _Arg_idx;'.format(i)) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
There was a problem hiding this comment.
I think a tweak to this approach would make it more robust:
Let's say you have
a b c [d e f g] h i # [] is the list
We extract a, b, c, with peek(0, N) where N is the number of arguments. We can adjust N before and after the varargs list is encountered to make peek(i, N) still always extract the ith static argument.
Before tensor_list, set N to (num_static_args - 1 + num_args_in_list)
a b c [d e f g] h i # [] is the list
~~~~~~~~~~~~ <- N
When we handle the varargs_list, then adjust N to be
num_static_args
a b c [d e f g] h i # [] is the list
~~~~~~ <- N
| if has_tensorlist: | ||
| kw_assignments.append('size_t _Arg_idx = 0;') | ||
| for i, arg in enumerate(decl['arguments']): | ||
| # XXX: we currently support only TensorList ops that have a TensorList as |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| elif is_tensor_arg(arg): | ||
| arguments.append('std::move(peek(stack, {}, {}))'.format(next(real_inputs), static_inputs)) | ||
| if has_tensorlist: | ||
| kw_assignments.append('size_t _Arg_idx_{} = _Arg_idx++;'.format(i)) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
No description provided.