@@ -98,23 +98,25 @@ static inline void check_cat_shape_except_dim(const Tensor & first, const Tensor
9898 if (dim == dimension) {
9999 continue ;
100100 }
101- int64_t first_dim_size = first.size (dim) ;
102- int64_t second_dim_size = second.size (dim) ;
101+ int64_t first_dim_size = first.sizes ()[dim] ;
102+ int64_t second_dim_size = second.sizes ()[dim] ;
103103 TORCH_CHECK (first_dim_size == second_dim_size, " Sizes of tensors must match except in dimension " ,
104104 dimension, " . Got " , first_dim_size, " and " , second_dim_size, " in dimension " , dim,
105105 " (The offending index is " , index, " )" );
106106 }
107107}
108108
109+ static bool should_skip (const Tensor& t) {
110+ return t.numel () == 0 && t.dim () == 1 ;
111+ }
112+
109113Tensor & _cat_out_cpu (Tensor& result, TensorList tensors, int64_t dim) {
110114 // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible
111115 // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors
112116 // to be "skipped". We maintain this behavior for backwards compatibility, but only for this specific
113117 // size (i.e. other empty sizes are not skipped).
114- // FIXME: warn if this is the case
115- bool allSkipped = true ;
118+
116119 bool allContiguous = true ;
117- Tensor notSkippedTensor;
118120
119121 // Inputs cannot alias the output tensor
120122 for (int64_t i = 0 ; i < tensors.size (); i++) {
@@ -126,19 +128,23 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
126128 }
127129 at::assert_no_internal_overlap (result);
128130
129- auto should_skip = [](const Tensor& t) { return t.numel () == 0 && t.dim () == 1 ; };
130- for (auto const &tensor : tensors) {
131- if (should_skip (tensor)) {
132- continue ;
131+ const Tensor* pnotSkippedTensor = [](TensorList tensors) -> const Tensor* {
132+ for (auto const &tensor : tensors) {
133+ if (should_skip (tensor)) {
134+ continue ;
135+ }
136+ // we've found a non-empty tensor
137+ return &tensor;
133138 }
134- // we've found a non-empty tensor
135- allSkipped = false ;
136- notSkippedTensor = tensor;
137- break ;
138- }
139- if (allSkipped) {
139+ return nullptr ;
140+ }(tensors) ;
141+
142+ if (!pnotSkippedTensor) {
143+ // FIXME: warn if this is the case -- see comment about skipped
144+ // tensors at top of function.
140145 return result;
141146 }
147+ const Tensor& notSkippedTensor = *pnotSkippedTensor;
142148
143149 TORCH_CHECK (tensors.size () > 0 , " expected a non-empty list of Tensors" );
144150 TORCH_CHECK (dim <= notSkippedTensor.dim (), " dimension " , dim, " out of range" );
@@ -161,7 +167,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
161167 continue ;
162168 }
163169 check_cat_shape_except_dim (notSkippedTensor, tensor, dim, i);
164- cat_dim_size += tensor.size (dim) ;
170+ cat_dim_size += tensor.sizes ()[dim] ;
165171
166172 if (!tensor.is_contiguous (first_tensor_mem_format)) {
167173 allContiguous = false ;
@@ -196,8 +202,8 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
196202 if (reuse_iterator &&
197203 result.is_contiguous (first_tensor_mem_format) &&
198204 no_type_promotion) {
199- auto source_slice = notSkippedTensor;
200- auto slice_dim_size = source_slice.size (dim) ;
205+ const auto & source_slice = notSkippedTensor;
206+ auto slice_dim_size = source_slice.sizes ()[dim] ;
201207 auto result_slice = result.narrow (dim, 0 , slice_dim_size);
202208 auto result_slice_data = result_slice.data_ptr ();
203209 auto result_stride_bytes = result.stride (dim) * elementSize (result.scalar_type ());
@@ -226,7 +232,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
226232 if (should_skip (tensor)) {
227233 continue ;
228234 }
229- auto slice_dim_size = tensor.size (dim) ;
235+ auto slice_dim_size = tensor.sizes ()[dim] ;
230236 auto result_slice = result.narrow (dim, offset, slice_dim_size);
231237
232238 auto iter = TensorIteratorConfig ()
0 commit comments