🚀 Feature
torch.nonzero should accept an optional size argument so that it produces statically shaped output. This would (1) avoid host-device synchronization and (2) avoid truncating the CUDA/LazyTensor IR graph during tracing.
Motivation
Because the number of nonzero elements is data-dependent, we currently fetch that count from the device and allocate the output tensor on the host using it. To make the shape information available on the host, a host-device synchronization is injected after the integer is fetched. This stalls the pipeline between the Python frontend and the device runtime, and it also breaks the CUDA/LazyTensor IR graph, since we need actual execution to get the concrete shape before tracing can continue.
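A small illustration of the data dependence described above: two inputs with the same shape and dtype yield outputs of different shapes, so the output shape cannot be known without reading the values. The tensors here are made up for the example.

```python
import torch

# Same shape and dtype, different values.
a = torch.tensor([0, 1, 0, 2])
b = torch.tensor([3, 1, 4, 2])

idx_a = torch.nonzero(a)  # two nonzero entries
idx_b = torch.nonzero(b)  # four nonzero entries

# The leading dimension depends on the data, not on the input shape.
print(idx_a.shape)  # torch.Size([2, 1])
print(idx_b.shape)  # torch.Size([4, 1])
```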
Pitch
Similar to the nonzero API in JAX, we add a size argument to torch.nonzero:
If specified, the first size nonzero elements will be returned; if there are fewer nonzero elements than size indicates, the index arrays will be zero-padded.
In addition to zero-padding, we could also return a counts device tensor that gives the number of nonzero elements at runtime, similar to what NVIDIA did to enable CUDA graphs in the MaskRCNN-ResNet50 MLPerf training benchmark v1.0.
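To make the proposed semantics concrete, here is a minimal sketch written with existing PyTorch ops. It is not the proposed implementation: the helper name nonzero_padded is hypothetical, it assumes an input with at least one dimension and no zero-sized dimension, and returning the count clamped to size is one possible choice for the counts tensor. The point is that every intermediate has a shape determined by the input shape and size alone, and the count stays on the device.

```python
import torch

def nonzero_padded(x, size):
    """Hypothetical helper: fixed-shape nonzero with zero padding.

    Returns (indices, count), where indices has shape (size, x.dim())
    and count is a 0-dim tensor on x's device holding the number of
    valid rows, i.e. min(number of nonzeros, size).
    """
    mask = (x != 0).reshape(-1)
    n = mask.numel()
    count = mask.sum()  # stays on device; no .item() call

    # Rank of each nonzero element among all nonzero elements.
    rank = torch.cumsum(mask.to(torch.long), dim=0) - 1
    keep = mask & (rank < size)

    # Kept elements go to slot `rank`; everything else is routed to an
    # extra dummy slot at position `size` that is dropped afterwards.
    slot = torch.where(keep, rank, torch.full_like(rank, size))
    flat_idx = torch.zeros(size + 1, dtype=torch.long, device=x.device)
    flat_idx.scatter_(0, slot, torch.arange(n, device=x.device))
    flat_idx = flat_idx[:size]  # unfilled slots remain 0 (zero padding)

    # Convert flat indices to per-dimension coordinates (row-major).
    coords = []
    for dim in reversed(x.shape):
        coords.append(flat_idx % dim)
        flat_idx = torch.div(flat_idx, dim, rounding_mode="floor")
    indices = torch.stack(coords[::-1], dim=1)

    return indices, torch.clamp(count, max=size)

x = torch.tensor([[0, 3, 0], [5, 0, 7]])
padded, count = nonzero_padded(x, size=5)
truncated, count_truncated = nonzero_padded(x, size=2)
```

With size=5, the first three rows of padded match torch.nonzero(x) and the remaining two rows are zeros. With size=2, only the first two nonzero positions are returned. A caller would consume count on the device (for example as a mask bound) rather than reading it back to the host.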
Alternatives
Add a runtime_shape tensor field to the tensor class, which will require all ATen operators to be aware of the possible divergence of runtime shape (as tensor of integers on device) and statically inferred shape bound (as c10::IntArray).
Additional context
The following operators either directly call at::nonzero or have similar machinery:
index.Tensor
masked_select
_unique2
nms
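As a quick illustration of why these operators share the problem, the following sketch shows boolean-mask indexing, masked_select, and unique all producing outputs whose length depends on the input values. The input tensor is made up for the example, and the sketch says nothing about how each operator is implemented internally.

```python
import torch

x = torch.tensor([4, 0, 4, 7, 0])
mask = x != 0

# Boolean-mask indexing and masked_select return the selected values;
# the result length equals the number of True entries in the mask.
via_index = x[mask]
via_masked_select = torch.masked_select(x, mask)

# The same selection expressed explicitly through nonzero.
via_nonzero = x[torch.nonzero(mask, as_tuple=True)]

# unique: the result length equals the number of distinct values.
uniq = torch.unique(x)
```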
PyTorch-XLA encodes the runtime shape information in the opaque tensor handle and does not truncate the IR graph. However, this is not sound when users inspect output.shape in Python, because that value is merely a statically inferred shape bound, not the runtime shape.
torch.jit.trace does shape erasure to support dynamic shape when shape dynamism does not alter the IR graph.
cc @ezyang @gchanan @zou3519 @ngimel @gmagogsfm @ailzhang @asuhan @JackCaoG @mcarilli