@@ -134,6 +134,7 @@ Tensor empty_cpu(IntArrayRef size, const TensorOptions& options_, c10::optional<
 
   auto memory_format = options.memory_format_opt().value_or(MemoryFormat::Contiguous);
   tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
+
   return tensor;
 }
 
@@ -342,18 +343,47 @@ Tensor& eye_out_cpu(Tensor& result, int64_t n, int64_t m) {
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ full ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-Tensor full(IntArrayRef size, Scalar fill_value, const TensorOptions& options) {
-  if (options.layout() == kSparse) {
-    AT_ERROR("full(...) is not implemented for sparse layout");
+namespace {
+
+// Performs dtype inference for full
+TensorOptions infer_full_options(
+    Scalar fill_value,
+    const TensorOptions& options) {
+
+  if (!options.has_dtype()) {
+    if (fill_value.isIntegral(true)) {
+      TORCH_WARN_ONCE(
+        "Deprecation warning: In a future PyTorch release torch.full ",
+        "will no longer return tensors of floating dtype by default. ",
+        "Instead, a bool fill_value will return a tensor of torch.bool dtype, ",
+        "and an integral fill_value will return a tensor of torch.long dtype. ",
+        "Set the optional `dtype` or `out` arguments to suppress this warning."
+      );
+    } else if (fill_value.isComplex()) {
+      auto scalar_type = (get_default_dtype() == ScalarType::Double) ?
+        ScalarType::ComplexDouble :
+        ScalarType::ComplexFloat;
+      return options.dtype(scalar_type);
+    }
   }
-  auto result = at::empty(size, options);
+
+  return options;
+}
+
+} // anonymous namespace
+
+Tensor full(IntArrayRef size, Scalar fill_value, const TensorOptions& options) {
+  TORCH_CHECK(options.layout() != kSparse,
+    "full(...) is not implemented for sparse layout");
+
+  auto result = at::empty(size, infer_full_options(fill_value, options));
   return result.fill_(fill_value);
 }
 
 Tensor& full_out(Tensor& result, IntArrayRef size, Scalar fill_value) {
-  if (result.is_sparse()) {
-    AT_ERROR("full(...) is not implemented for sparse layout");
-  }
+  TORCH_CHECK(!result.is_sparse(),
+    "full(...) is not implemented for sparse layout");
+
   result.resize_(size);
   return result.fill_(fill_value);
 }
@@ -404,19 +434,19 @@ Tensor logspace(
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ones ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 Tensor ones(IntArrayRef size, const TensorOptions& options) {
-  return native::full(size, /*fill_value=*/1, options);
+  return native::full(size, /*fill_value=*/1., options);
 }
 
 Tensor& ones_out(Tensor& result, IntArrayRef size) {
-  return native::full_out(result, size, /*fill_value=*/1);
+  return native::full_out(result, size, /*fill_value=*/1.);
 }
 
 Tensor ones_like(
     const Tensor& self,
     const TensorOptions& options,
     c10::optional<c10::MemoryFormat> optional_memory_format) {
   auto result = at::empty_like(self, options, optional_memory_format);
-  return result.fill_(1);
+  return result.fill_(1.);
 }
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ scalar_tensor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
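
The switch from the literal 1 to 1. above matters because of the new inference: a Scalar built from an int reports itself as integral and would trip the deprecation warning when ones forwards to full without a dtype, while a double literal keeps the default float dtype silently. A small standalone sketch (again an illustration, not code from the commit):

    #include <c10/core/Scalar.h>

    int main() {
      c10::Scalar int_one(1);   // integral -> would hit the TORCH_WARN_ONCE branch
      c10::Scalar fp_one(1.);   // floating -> dtype inference leaves options alone
      bool warns = int_one.isIntegral(/*includeBool=*/true);   // true
      bool silent = !fp_one.isIntegral(/*includeBool=*/true);  // true
      return (warns && silent) ? 0 : 1;
    }
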
@@ -746,7 +776,7 @@ Tensor zeros(IntArrayRef size, const TensorOptions& options) {
 
 Tensor& zeros_out(Tensor& result, IntArrayRef size) {
   if (result.is_sparse()) {
-    result.sparse_resize_and_clear_(size, size.size(), 0);
+    result.sparse_resize_and_clear_(size, size.size(), 0.);
     return result;
   } else {
     result.resize_(size);
@@ -960,22 +990,26 @@ Tensor full(
     Scalar fill_value,
     optional<DimnameList> names,
     const TensorOptions& options) {
-  auto result = at::empty(size, names, options);
+
+  TORCH_CHECK(options.layout() != kSparse,
+    "full(...) is not implemented for sparse layout");
+
+  auto result = at::empty(size, names, infer_full_options(fill_value, options));
   return result.fill_(fill_value);
 }
 
 Tensor ones(
     IntArrayRef size,
     optional<DimnameList> names,
     const TensorOptions& options) {
-  return native::full(size, /*fill_value=*/1, names, options);
+  return native::full(size, /*fill_value=*/1., names, options);
 }
 
 Tensor zeros(
     IntArrayRef size,
     optional<DimnameList> names,
     const TensorOptions& options) {
-  return native::full(size, /*fill_value=*/0, names, options);
+  return native::full(size, /*fill_value=*/0., names, options);
 }
 
 Tensor randn(