Skip to content

Commit d7b57c4

Browse files
zou3519pytorchmergebot
authored andcommitted
Fix tensor.data access under inference_mode and compile (#134878)
Fixes #134798 In the regular Tensor case, when you call Tensor.data, there's a check for if inference mode is active. If it is active, then we don't set the version counter. We replicate this check for Tensor Subclasses (the bug was we were trying to set the version counter on a FakeTensor in inference_mode). Test Plan: - new test Pull Request resolved: #134878 Approved by: https://github.com/bdhirsh
1 parent 0d193a0 commit d7b57c4

2 files changed

Lines changed: 14 additions & 1 deletion

File tree

c10/core/TensorImpl.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,9 @@ c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach_core(
509509
r = (pyobj_slot_.load_pyobj_interpreter())->detach(this);
510510
}
511511
if (r) {
512-
r->set_version_counter(std::forward<VariableVersion>(version_counter));
512+
if (!r->is_inference()) {
513+
r->set_version_counter(std::forward<VariableVersion>(version_counter));
514+
}
513515
r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
514516
return r;
515517
}

test/dynamo/test_misc.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2000,6 +2000,17 @@ def fn(cfg, x, y):
20002000
self.assertEqual(cnts.frame_count, 1)
20012001
self.assertEqual(cnts.op_count, 2)
20022002

2003+
def test_data_access_in_inference_mode(self):
2004+
@torch.compile(fullgraph=True)
2005+
def f(x):
2006+
y = x.data
2007+
return y
2008+
2009+
with torch.inference_mode():
2010+
x = torch.randn(3)
2011+
y = f(x)
2012+
self.assertEqual(y, x)
2013+
20032014
def test_dataclass_fields(self):
20042015
@dataclasses.dataclass
20052016
class MyDataClass:

0 commit comments

Comments
 (0)