Change FakeTensor constructor to use _make_subclass#77970
Change FakeTensor constructor to use _make_subclass#77970eellison wants to merge 15 commits intogh/eellison/293/basefrom
_make_subclass#77970Conversation
[ghstack-poisoned]
🔗 Helpful links
✅ No Failures (0 Pending)As of commit 67a6a93 (more details on the Dr. CI page): Expand to see more💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Please report bugs/suggestions to the (internal) Dr. CI Users group. |
…om_strides" Add an overload for `new` to handle setting the custom strides/custom device property. [ghstack-poisoned]
…om_strides" Add an overload for `new` to handle setting the custom strides/custom device property. [ghstack-poisoned]
…om_strides" Add an overload for `new` to handle setting the custom strides/custom device property. [ghstack-poisoned]
…om_strides" Add an overload for `new` to handle setting the custom strides/custom device property. [ghstack-poisoned]
| @property | ||
| def device(self): | ||
| return self.fake_device | ||
| # TODO: resolve error in default __repr__ |
There was a problem hiding this comment.
Haven't really looked into it yet, but this is the callstack and the func that is invoked:
aten.add.Tensor
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
aten._reshape_alias.default
prim.device.default
prim.device.default
prim.device.default
aten.abs.default
prim.device.default
prim.device.default
prim.device.default
F
======================================================================
FAIL: test_zero_dim (__main__.FakeTensorTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/scratch/eellison/pytorch/test/test_fake_tensor.py", line 31, in test_zero_dim
print(x)
File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/_tensor.py", line 338, in __repr__
return torch._tensor_str._str(self, tensor_contents=tensor_contents)
File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/_tensor_str.py", line 481, in _str
return _str_intern(self, tensor_contents=tensor_contents)
File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/_tensor_str.py", line 447, in _str_intern
tensor_str = _tensor_str(self, indent)
File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/_tensor_str.py", line 270, in _tensor_str
formatter = _Formatter(get_summarized_data(self) if summarize else self)
File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/_tensor_str.py", line 103, in __init__
nonzero_finite_vals = torch.masked_select(tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0))
File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 98, in __torch_dispatch__
return tree_map(partial(wrap, device=common_device), r)
File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/utils/_pytree.py", line 179, in tree_map
return tree_unflatten([fn(i) for i in flat_args], spec)
File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/utils/_pytree.py", line 179, in <listcomp>
return tree_unflatten([fn(i) for i in flat_args], spec)
File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 72, in wrap
return FakeTensor(e, device)
File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 39, in __init__
assert elem.device.type == "meta"
AssertionError
There was a problem hiding this comment.
Oh you need to add a special case for fake tensor in printing the same way we have a short circuit for meta printing. The meta printing short circuit no longer works because your device isn't meta anymore
|
I'm a bit iffy on this. Why not extend |
|
The chain of calls is
make_subclass does the parsing of schemas ( So it's really only that could be shared without a lot of refactoring, since Unless there's some easy refactoring I'm missing or larger refactoring I should do doesn't seem like a lot to share ? |
…om_strides" Add an overload for `new` to handle setting the custom strides/custom device property. [ghstack-poisoned]
|
It's not about refactoring to share; it's related to the earlier question which is why you needed to use BTW, the underlying motivation here is that |
…om_strides" Add an overload for `new` to handle setting the custom strides/custom device property. [ghstack-poisoned]
|
From @samdow
as to why is required but I updated to just call |
…om_strides" Add an overload for `new` to handle setting the custom strides/custom device property. [ghstack-poisoned]
_make_subclass
Add an overload for `new` to handle setting the custom strides/custom device property. [ghstack-poisoned]
|
@eellison has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
Yeah we also saw later that the PR mentioned (which is also mentioned in the subclass zoo) is no longer relevant and now closed without being merged. To @eellison's point, it might be worth to update the subclass zoo so that it isn't too out of date, especially if someone else comes in and tries to extend BaseTensor (or use BaseTensor) |
Add an overload for `new` to handle setting the custom strides/custom device property. Differential Revision: [D36618466](https://our.internmc.facebook.com/intern/diff/D36618466) [ghstack-poisoned]
|
@eellison has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
Add an overload for `new` to handle setting the custom strides/custom device property. Differential Revision: [D36618466](https://our.internmc.facebook.com/intern/diff/D36618466) [ghstack-poisoned]
Add an overload for `new` to handle setting the custom strides/custom device property. Differential Revision: [D36618466](https://our.internmc.facebook.com/intern/diff/D36618466) [ghstack-poisoned]
Add an overload for `new` to handle setting the custom strides/custom device property. Differential Revision: [D36618466](https://our.internmc.facebook.com/intern/diff/D36618466) [ghstack-poisoned]
Add an overload for `new` to handle setting the custom strides/custom device property. Differential Revision: [D36618466](https://our.internmc.facebook.com/intern/diff/D36618466) [ghstack-poisoned]
Add an overload for `new` to handle setting the custom strides/custom device property. [ghstack-poisoned]
|
@pytorchbot merge this please |
|
Hey @eellison. |
…77970) Summary: Pull Request resolved: #77970 Approved by: https://github.com/Chillee Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/98e08169867e1d3fcf2a660bbe37f65db0e3528a Reviewed By: seemethere Differential Revision: D36784759 Pulled By: seemethere fbshipit-source-id: 71b177873c941c25b98795e80d00380d9360f5b3
Stack from ghstack (oldest at bottom):
_make_subclass#77970Add an overload for
newto handle setting the custom strides/custom device property.