Skip to content

Commit f9a0abf

Browse files
ezyang authored and facebook-github-bot committed
Fix code review from #48659 and #48116 (#48731)
Summary: Pull Request resolved: #48731 Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: bhosmer Differential Revision: D25278034 Pulled By: ezyang fbshipit-source-id: 73652311b48d8d80c06e9385b7ff18ef3a158ae8
1 parent d6f9e85 commit f9a0abf

2 files changed

Lines changed: 12 additions & 1 deletion

File tree

aten/src/ATen/native/TensorIterator.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1126,6 +1126,9 @@ bool TensorIteratorBase::fast_set_up(const TensorIteratorConfig& config) {
11261126
{
11271127
for (int i = 0; i < num_outputs_; i++){
11281128
auto& op = operands_[i];
1129+
if (!op.tensor.defined()) {
1130+
TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
1131+
}
11291132
set_output(i, shape_, {}, op.options().memory_format(MemoryFormat::Contiguous), names_);
11301133
}
11311134
break;
@@ -1134,6 +1137,9 @@ bool TensorIteratorBase::fast_set_up(const TensorIteratorConfig& config) {
11341137
{
11351138
for (int i = 0; i < num_outputs_; i++){
11361139
auto& op = operands_[i];
1140+
if (!op.tensor.defined()) {
1141+
TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
1142+
}
11371143
set_output(i, shape_, {}, op.options().memory_format(MemoryFormat::ChannelsLast), names_);
11381144
}
11391145
break;
@@ -1148,6 +1154,9 @@ bool TensorIteratorBase::fast_set_up(const TensorIteratorConfig& config) {
11481154
TORCH_CHECK(i_defined >= 0, "Can not find a defined tensor when fast allocating memory to outputs");
11491155
for (int i = 0; i < num_outputs_; i++){
11501156
auto& op = operands_[i];
1157+
if (!op.tensor.defined()) {
1158+
TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
1159+
}
11511160
set_output(i, shape_, operands_[i_defined].tensor.strides(), op.options(), names_);
11521161
}
11531162
break;
@@ -1275,7 +1284,6 @@ void TensorIterator::set_output(int64_t output_idx, IntArrayRef sizes, IntArrayR
12751284
auto& op = operands_[output_idx];
12761285
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
12771286
if (!op.tensor.defined()) {
1278-
TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", output_idx);
12791287
if (strides.empty()) {
12801288
op.tensor = at::empty(sizes, options);
12811289
} else {

test/test_torch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2550,6 +2550,7 @@ def test_empty_meta(self):
25502550
y = torch.empty_meta(2 ** 20)
25512551
z = x + y
25522552
self.assertEqual(z.size(), (2 ** 20, 2 ** 20))
2553+
self.assertRaises(RuntimeError, lambda: z[0][0].item())
25532554

25542555
def test_upsample_nearest1d_meta(self):
25552556
# TODO: this is not a sustainable way of testing meta functions,
@@ -2560,12 +2561,14 @@ def test_upsample_nearest1d_meta(self):
25602561
x = torch.empty_meta(2 * 10 ** 8, 3, 2 * 10 ** 8)
25612562
z = torch.nn.functional.interpolate(x, scale_factor=2)
25622563
self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
2564+
self.assertRaises(RuntimeError, lambda: z[0][0][0].item())
25632565

25642566
# interpolate doesn't seem to support out=
25652567
# (not sure why passing None here doesn't work? How strange...)
25662568
z = torch.empty_meta(0)
25672569
torch._C._nn.upsample_nearest1d(x, (4 * 10 ** 8,), 2, out=z)
25682570
self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
2571+
self.assertRaises(RuntimeError, lambda: z[0][0][0].item())
25692572

25702573
def test_normal_shape(self):
25712574
warned = False

0 commit comments

Comments (0)