@@ -195,6 +195,32 @@ void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::j
   torch::jit::push(stack, self);
 }
 
+static Tensor safeStack(TensorList tensors) {
+  auto is_defined = [](const Tensor& t) { return t.defined(); };
+  if (std::all_of(tensors.begin(), tensors.end(), is_defined)) {
+    return at::stack(tensors);
+  }
+  // NOTE [vmap through backward and undefined grad]
+  // While vmapping through backward functions (to compute batched grad), it
+  // is possible for the backward function to return an undefined grad for some
+  // grad_input for each example. In that case, we return an undefined grad.
+  //
+  // It is theoretically possible for *some* of the examples to produce an
+  // undefined grad (a kernel could peek at the gradient values and return an
+  // undefined tensor if it determines the gradient is full of zeros). We
+  // could handle this by treating the undefined grad as a zero-filled tensor
+  // of the correct shape while stacking the tensors together. However, I expect
+  // this to happen very rarely (I have not been able to find an example in our
+  // codebase), so we just error out in this case.
+  if (std::none_of(tensors.begin(), tensors.end(), is_defined)) {
+    return Tensor();
+  }
+  TORCH_CHECK(false,
+      "vmap: slow fallback received a mix of undefined and defined tensors ",
+      "as the result of an operation. This is not supported, please file us ",
+      "an issue on github.");
+}
+
 // The general flow of the algorithm is as follows.
 // - First, we figure out which arguments are BatchedTensors and save them
 //   to a vector. We also store a vector of which index of the arguments list
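
A minimal standalone sketch of the three cases the new safeStack helper distinguishes (a hypothetical driver, assuming an ATen build; safeStack is file-local to the fallback, so this restates its contract with plain ATen calls rather than invoking it):

#include <ATen/ATen.h>
#include <vector>

int main() {
  at::Tensor grad = at::ones({3});
  at::Tensor undef;  // a default-constructed Tensor is "undefined"

  // Case 1: every per-example grad is defined -> stacking succeeds and
  // yields the batched grad of shape [2, 3]; safeStack returns at::stack.
  std::vector<at::Tensor> all_defined = {grad, grad};
  at::Tensor batched = at::stack(all_defined);

  // Case 2: every per-example grad is undefined -> safeStack returns an
  // undefined Tensor(), which the fallback pushes through unchanged.
  std::vector<at::Tensor> all_undefined = {undef, undef};

  // Case 3: a mix of defined and undefined grads -> safeStack errors out
  // via TORCH_CHECK rather than zero-filling the undefined shards.
  std::vector<at::Tensor> mixed = {grad, undef};

  (void)batched; (void)all_undefined; (void)mixed;
  return 0;
}
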
@@ -318,7 +344,12 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta
   auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_batches);
   for (int64_t return_idx = 0; return_idx < num_returns; ++return_idx) {
     auto shards = output_shards_chunks[return_idx];
-    auto flat_output = at::stack(shards);
+    auto flat_output = safeStack(shards);
+    // See NOTE [vmap through backward and undefined grad]
+    if (!flat_output.defined()) {
+      torch::jit::push(stack, flat_output);
+      continue;
+    }
     VmapDimVector output_sizes(batch_sizes);
     output_sizes.insert(
         output_sizes.end(),
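
For context on where all-undefined shards come from, a hedged sketch (assuming a libtorch build) of a backward pass that legitimately yields an undefined grad: torch::autograd::grad with allow_unused=true returns a default-constructed Tensor for an input the output does not depend on, and per-example grads like this are what the new !flat_output.defined() branch forwards to the caller.

#include <torch/torch.h>

int main() {
  auto x = torch::ones({3}, torch::requires_grad());
  auto y = torch::ones({3}, torch::requires_grad());
  auto out = (x * x).sum();  // out does not depend on y

  // With allow_unused=true, the grad w.r.t. y comes back undefined
  // instead of raising an error.
  auto grads = torch::autograd::grad(
      /*outputs=*/{out}, /*inputs=*/{x, y},
      /*grad_outputs=*/{}, /*retain_graph=*/false,
      /*create_graph=*/false, /*allow_unused=*/true);

  TORCH_CHECK(grads[0].defined());   // d(out)/dx exists
  TORCH_CHECK(!grads[1].defined());  // d(out)/dy is undefined
  return 0;
}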