Move _storage_Use_Count to be generic #155451

guangyey wants to merge 8 commits into gh/guangyey/154/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155451
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures)
As of commit a525407 with merge base 3040ca6:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
albanD
left a comment
This is a weird API to use but sure why not
s[2:7] = 1
self.assertEqual(s, storage_type(l))

@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
Hi @albanD, I don't understand why this case fails on Dynamo. The reproducer is simple:
# demo.py
import torch

@torch._dynamo.optimize("eager")
def get_ref():
    a = torch.randn(10)
    ref = torch._C._storage_Use_Count(a.untyped_storage()._cdata)
    print("ref is ", ref)

explanation = torch._dynamo.explain(get_ref)()
print(explanation.graph_break_count, explanation.graph_count)

ref is expected to be 1, but an unreasonable value is returned:

>$ python demo.py
ref is 672385573
0 1

I checked the traced bytecode:
V0610 01:45:32.118000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch []
V0610 01:45:32.118000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0] [__trace_bytecode] TRACE LOAD_ATTR randn [LazyVariableTracker()]
V0610 01:45:32.119000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0] [__trace_bytecode] TRACE LOAD_CONST 10 [LazyVariableTracker()]
V0610 01:45:32.119000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 1 [LazyVariableTracker(), ConstantVariable(int: 10)]
V0610 01:45:32.206000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0] [__trace_bytecode] TRACE STORE_FAST a [TensorVariable()]
V0610 01:45:32.206000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch []
V0610 01:45:32.206000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0] [__trace_bytecode] TRACE LOAD_ATTR _C [LazyVariableTracker()]
V0610 01:45:32.206000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0] [__trace_bytecode] TRACE LOAD_ATTR _storage_Use_Count [LazyVariableTracker()]
V0610 01:45:32.207000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0] [__trace_bytecode] TRACE LOAD_FAST a [LazyVariableTracker()]
V0610 01:45:32.207000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0] [__trace_bytecode] TRACE LOAD_ATTR untyped_storage [LazyVariableTracker(), TensorVariable()]
V0610 01:45:32.207000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 0 [LazyVariableTracker(), GetAttrVariable(TensorVariable(), untyped_storage)]
V0610 01:45:32.207000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0] [__trace_bytecode] TRACE LOAD_ATTR _cdata [LazyVariableTracker(), UntypedStorageVariable()]
V0610 01:45:32.207000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 1 [LazyVariableTracker(), GetAttrVariable(UntypedStorageVariable(), _cdata)]
V0610 01:45:32.209000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0_1] [__trace_bytecode] TRACE LOAD_GLOBAL torch []
V0610 01:45:32.209000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0_1] [__trace_bytecode] TRACE LOAD_ATTR randn [LazyVariableTracker()]
V0610 01:45:32.209000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0_1] [__trace_bytecode] TRACE LOAD_CONST 10 [LazyVariableTracker()]
V0610 01:45:32.209000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0_1] [__trace_bytecode] TRACE CALL_FUNCTION 1 [LazyVariableTracker(), ConstantVariable(int: 10)]
V0610 01:45:32.210000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0_1] [__trace_bytecode] TRACE STORE_FAST a [TensorVariable()]
V0610 01:45:32.210000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0_1] [__trace_bytecode] TRACE LOAD_GLOBAL torch []
V0610 01:45:32.210000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0_1] [__trace_bytecode] TRACE LOAD_ATTR _C [LazyVariableTracker()]
V0610 01:45:32.210000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0_1] [__trace_bytecode] TRACE LOAD_ATTR _storage_Use_Count [LazyVariableTracker()]
V0610 01:45:32.210000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0_1] [__trace_bytecode] TRACE LOAD_FAST a [LazyVariableTracker()]
V0610 01:45:32.211000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0_1] [__trace_bytecode] TRACE LOAD_ATTR untyped_storage [LazyVariableTracker(), TensorVariable()]
V0610 01:45:32.211000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0_1] [__trace_bytecode] TRACE CALL_FUNCTION 0 [LazyVariableTracker(), GetAttrVariable(TensorVariable(), untyped_storage)]
V0610 01:45:32.211000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0_1] [__trace_bytecode] TRACE LOAD_ATTR _cdata [LazyVariableTracker(), UntypedStorageVariable()]
V0610 01:45:32.211000 3774389 torch/_dynamo/symbolic_convert.py:1256] [0/0_1] [__trace_bytecode] TRACE CALL_FUNCTION 1 [LazyVariableTracker(), GetAttrVariable(UntypedStorageVariable(), _cdata)]
V0610 01:45:32.235000 3774389 torch/_dynamo/symbolic_convert.py:1256] [1/0] [__trace_bytecode] TRACE LOAD_FAST ___stack0 []
V0610 01:45:32.235000 3774389 torch/_dynamo/symbolic_convert.py:1256] [1/0] [__trace_bytecode] TRACE JUMP_ABSOLUTE 30 [LazyVariableTracker()]
V0610 01:45:32.235000 3774389 torch/_dynamo/symbolic_convert.py:1256] [1/0] [__trace_bytecode] TRACE STORE_FAST prev_cf [LazyVariableTracker()]
V0610 01:45:32.235000 3774389 torch/_dynamo/symbolic_convert.py:1256] [1/0] [__trace_bytecode] TRACE LOAD_GLOBAL print []
V0610 01:45:32.235000 3774389 torch/_dynamo/symbolic_convert.py:1256] [1/0] [__trace_bytecode] TRACE LOAD_FAST prev_cf [LazyVariableTracker()]
V0610 01:45:32.235000 3774389 torch/_dynamo/symbolic_convert.py:1256] [1/0] [__trace_bytecode] TRACE CALL_FUNCTION 1 [LazyVariableTracker(), ConstantVariable(int: 2096351460)]
V0610 01:45:32.236000 3774389 torch/_dynamo/symbolic_convert.py:1256] [1/0_1] [__trace_bytecode] TRACE LOAD_FAST ___stack0 []
V0610 01:45:32.237000 3774389 torch/_dynamo/symbolic_convert.py:1256] [1/0_1] [__trace_bytecode] TRACE JUMP_ABSOLUTE 30 [LazyVariableTracker()]
V0610 01:45:32.237000 3774389 torch/_dynamo/symbolic_convert.py:1256] [1/0_1] [__trace_bytecode] TRACE STORE_FAST prev_cf [LazyVariableTracker()]
V0610 01:45:32.237000 3774389 torch/_dynamo/symbolic_convert.py:1256] [1/0_1] [__trace_bytecode] TRACE LOAD_GLOBAL print []
V0610 01:45:32.237000 3774389 torch/_dynamo/symbolic_convert.py:1256] [1/0_1] [__trace_bytecode] TRACE LOAD_FAST prev_cf [LazyVariableTracker()]
V0610 01:45:32.237000 3774389 torch/_dynamo/symbolic_convert.py:1256] [1/0_1] [__trace_bytecode] TRACE CALL_FUNCTION 1 [LazyVariableTracker(), ConstantVariable(int: 2096351460)]
V0610 01:45:32.256000 3774389 torch/_dynamo/symbolic_convert.py:1256] [2/0] [__trace_bytecode] TRACE LOAD_FAST ___stack0 []
V0610 01:45:32.256000 3774389 torch/_dynamo/symbolic_convert.py:1256] [2/0] [__trace_bytecode] TRACE JUMP_ABSOLUTE 38 [LazyVariableTracker()]
V0610 01:45:32.256000 3774389 torch/_dynamo/symbolic_convert.py:1256] [2/0] [__trace_bytecode] TRACE POP_TOP None [LazyVariableTracker()]
V0610 01:45:32.256000 3774389 torch/_dynamo/symbolic_convert.py:1256] [2/0] [__trace_bytecode] TRACE LOAD_CONST None []
V0610 01:45:32.256000 3774389 torch/_dynamo/symbolic_convert.py:1256] [2/0] [__trace_bytecode] TRACE RETURN_VALUE None [ConstantVariable(NoneType: None)]

The bytecode looks good except that some extra frame/block ids exist ([0/0_1], [1/0], [1/0_1], [2/0]). Dynamo reports a graph break count of 0, so I don't understand why the extra frames/blocks exist.
Is it reasonable to skip this UT on Dynamo, or is this a bug in Dynamo?
Yeah, definitely a bug in Dynamo. You can skip it and report it here.
My guess is that we have the cdata as a constant during tracing and thus we try to read the use count of an object that doesn't exist anymore.
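As a rough illustration of that guess, here is a minimal sketch (not from this PR, and the expected value of 1 is only an assumption based on the reproducer above) of a workaround: query the use count outside the compiled region so `_cdata` is not baked into the trace as a stale constant.

```python
import torch

@torch.compile(backend="eager")
def make_tensor():
    # Only the tensor construction is compiled; nothing storage-related
    # is captured in the trace.
    return torch.randn(10)

a = make_tensor()
# Query the use count eagerly, after compilation, so Dynamo does not
# constant-fold the _cdata pointer value.
ref = torch._C._storage_Use_Count(a.untyped_storage()._cdata)
print("ref is", ref)  # expected 1 per the reproducer above
```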
Thanks, I will report it.
@pytorchbot merge

Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Heads up: the test that was added is failing in the debug build. I've disabled it in #156731, but please take a look when you have time and decide what to do with it.
Thanks for the heads-up. It looks like there's a bug in this API; I'm currently investigating the issue.
# Motivation

#155451 decoupled `torch._C._storage_Use_Count` from CUDA and introduced a corresponding unit test:
https://github.com/pytorch/pytorch/blob/815545f2dd6ade563cb1263f8bb7813f355edb2e/test/test_torch.py#L257-L262

However, this test fails when PyTorch is built with debug assertions enabled. @clee2000 disabled this UT in #156731.

The root cause is that `_cdata` is obtained from an `intrusive_ptr`, not a `weak_intrusive_ptr`. As a result, calling `c10::weak_intrusive_ptr::use_count` on it triggers the internal assertion:
https://github.com/pytorch/pytorch/blob/815545f2dd6ade563cb1263f8bb7813f355edb2e/c10/util/intrusive_ptr.h#L912-L917

For example:

```python
a = torch.randn(10, device=device)  # refcount=1, weakcount=1
prev_cf = torch._C._storage_Use_Count(a.untyped_storage()._cdata)  # violates the assertion
```

This violates the expected invariant inside `weak_intrusive_ptr::use_count`, which assumes the pointer was originally constructed from a valid `weak_intrusive_ptr`. In fact, `storage_impl` is obtained from an `intrusive_ptr`:
https://github.com/pytorch/pytorch/blob/815545f2dd6ade563cb1263f8bb7813f355edb2e/torch/csrc/Module.cpp#L2105-L2109

# Solution

Use `c10::intrusive_ptr::use_count` instead.

Pull Request resolved: #157694
Approved by: https://github.com/albanD
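For reference, a small sketch of the behavior the fixed API is expected to show once it reports the `intrusive_ptr` use count (the values are assumptions based on the unit test referenced above, not output from this PR):

```python
import torch

a = torch.randn(10)
cdata = a.untyped_storage()._cdata
# Only `a` owns the storage at this point.
print(torch._C._storage_Use_Count(cdata))  # expected: 1

b = a.view(2, 5)  # a view shares the same StorageImpl
print(torch._C._storage_Use_Count(cdata))  # expected: 2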
Stack from ghstack (oldest at bottom):
Motivation
`torch._C._storage_Use_Count` should be a generic API that is not aware of the device type. It is also used in https://github.com/pytorch/torchtune/blob/337cd7c53d7006e2330b2f0b248d48ec5180b6cc/torchtune/training/_activation_offloading.py#L323 to do some memory optimization, as sketched below.
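As a rough illustration of that kind of use case, here is a hedged sketch (illustrative names, not torchtune's actual code) of a device-agnostic check that a storage is uniquely owned:

```python
import torch

def storage_is_uniquely_owned(t: torch.Tensor) -> bool:
    # With a generic _storage_Use_Count, this check works the same way
    # on CPU, CUDA, XPU, and other backends.
    return torch._C._storage_Use_Count(t.untyped_storage()._cdata) == 1

x = torch.randn(4)
print(storage_is_uniquely_owned(x))  # True: only `x` owns the storage
y = x.view(2, 2)
print(storage_is_uniquely_owned(x))  # False: the view shares the storage
```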