
Change FakeTensor constructor to use _make_subclass#77970

Closed
eellison wants to merge 15 commits into gh/eellison/293/base from gh/eellison/293/head

Conversation

@facebook-github-bot
Contributor

facebook-github-bot commented May 20, 2022

🔗 Helpful links

✅ No Failures (0 Pending)

As of commit 67a6a93 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


…om_strides"


Add an overload for `new` to handle setting the custom strides/custom device property.

[ghstack-poisoned]
@eellison eellison requested a review from ezyang May 20, 2022 16:17
Elias Ellison added 3 commits May 20, 2022 11:23
    @property
    def device(self):
        return self.fake_device

    # TODO: resolve error in default __repr__
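The snippet above can be put in context with a minimal sketch of the constructor pattern this PR moves toward (illustrative only, not the actual PyTorch implementation; the class name and `fake_device` attribute are taken from the surrounding discussion): the wrapper's storage lives on the meta device, construction goes through `_make_subclass`, and a `device` property reports a user-chosen "fake" device.

```python
import torch

class FakeTensorSketch(torch.Tensor):
    # Illustrative sketch: wrap a meta tensor and report a fake device.
    @staticmethod
    def __new__(cls, elem, device):
        # elem must live on the meta device; _make_subclass re-wraps it
        # as an instance of `cls` without copying any data.
        assert elem.device.type == "meta"
        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
        r.fake_device = torch.device(device)
        return r

    @property
    def device(self):
        # Shadows the C-level device accessor, as in the diff above.
        return self.fake_device

t = FakeTensorSketch(torch.empty(2, 2, device="meta"), "cpu")
```

Note that because `device` no longer reports `meta`, data-dependent paths such as the default `__repr__` can misbehave, which is exactly the error discussed below.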
Contributor


What's the error?

Contributor Author


I haven't really looked into it yet, but this is the call stack and the function that is invoked:

aten.add.Tensor
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
prim.device.default
aten._reshape_alias.default
prim.device.default
prim.device.default
prim.device.default
aten.abs.default
prim.device.default
prim.device.default
prim.device.default
F
======================================================================
FAIL: test_zero_dim (__main__.FakeTensorTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/scratch/eellison/pytorch/test/test_fake_tensor.py", line 31, in test_zero_dim
    print(x)
  File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/_tensor.py", line 338, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
  File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/_tensor_str.py", line 481, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
  File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/_tensor_str.py", line 447, in _str_intern
    tensor_str = _tensor_str(self, indent)
  File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/_tensor_str.py", line 270, in _tensor_str
    formatter = _Formatter(get_summarized_data(self) if summarize else self)
  File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/_tensor_str.py", line 103, in __init__
    nonzero_finite_vals = torch.masked_select(tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0))
  File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 98, in __torch_dispatch__
    return tree_map(partial(wrap, device=common_device), r)
  File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/utils/_pytree.py", line 179, in tree_map
    return tree_unflatten([fn(i) for i in flat_args], spec)
  File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/utils/_pytree.py", line 179, in <listcomp>
    return tree_unflatten([fn(i) for i in flat_args], spec)
  File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 72, in wrap
    return FakeTensor(e, device)
  File "/private/home/eellison/anaconda3/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 39, in __init__
    assert elem.device.type == "meta"
AssertionError

Contributor


Oh, you need to add a special case for fake tensors in printing, the same way we have a short circuit for meta printing. The meta printing short circuit no longer works because your device isn't meta anymore.
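A hedged sketch of what such a short circuit could look like (the helper name `describe_tensor` is hypothetical; the real logic lives in `torch/_tensor_str.py`):

```python
import torch

def describe_tensor(t):
    # Hypothetical helper illustrating the suggested short circuit:
    # tensors without real storage (meta tensors, and by extension fake
    # tensors carrying a `fake_device` attribute) cannot have their data
    # formatted, so return a summary before any data-dependent ops run.
    if t.device.type == "meta" or hasattr(t, "fake_device"):
        return f"tensor(..., device='{t.device}', size={tuple(t.shape)})"
    return repr(t)

print(describe_tensor(torch.empty(2, 3, device="meta")))
```

The `hasattr` check matters because, as noted above, a fake tensor's `device` property no longer reports `meta`, so the existing meta-device check alone can't catch it.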

@ezyang
Contributor

ezyang commented May 21, 2022

I'm a bit iffy on this. Why not extend `_make_subclass` and then call it directly from `__new__`?

@eellison
Contributor Author

The chain of calls is
`__new__` -> `THPVariable_pynew` -> `base_tensor_ctor` -> `legacy_tensor_generic_ctor_new`.

`legacy_tensor_generic_ctor_new` is what does the parsing of new schemas and where I added the logic here. Once it returns the new Tensor to `THPVariable_pynew`, `THPVariable_NewWithVar` does the wrapping of the new tensor into its potential Tensor subclass.

`_make_subclass` does the parsing of schemas (`_make_subclass(PyObject* cls, Tensor data, bool require_grad=False, *, bool dispatch_strides=False)`), the creation of the new Tensor, and the wrapping of the new tensor into its potential Tensor subclass.

So it's really only

    if (r.toBool(1)) {
      alias.unsafeGetTensorImpl()->set_sizes_strides_policy(c10::TensorImpl::SizesStridesPolicy::CustomStrides);
    }
    if (r.toBool(2)) {
      alias.unsafeGetTensorImpl()->set_custom_device(true);
    }
    return alias;

that could be shared without a lot of refactoring, since `_make_subclass` both does the parsing of schemas and the wrapping of tensors into their subclasses...

Unless there's some easy refactoring I'm missing, or a larger refactoring I should do, there doesn't seem to be a lot to share?

@ezyang
Contributor

ezyang commented May 23, 2022

It's not about refactoring to share; it's related to the earlier question, which is why you needed to use `BaseTensor.__new__` as opposed to defining it directly and calling `_make_subclass` inside `__new__`, so you wouldn't exercise the `__new__` codepath at all.

BTW, the underlying motivation here is that `__new__` is exercisable by anyone who does a `Tensor(...)` call, so I'd like to avoid sprouting API surface here.

@eellison
Contributor Author

From @samdow

So typically if you pass requires_grad to a tensor constructor, that means you want the subclass that you're creating to require gradient, not the tensor that it's wrapping. It looks like from #73727 that it accepts a single tensor, so `super().__new__()` calls that, and if the tensor requires gradient, it will wrap it like normal and won't put requires_grad on the wrapper tensor

as to why

    def __new__(cls, elem, *, requires_grad=None):
        if requires_grad is None:
            return super().__new__(cls, elem)
        else:
            return cls._make_subclass(cls, elem, requires_grad)

is required

but I updated it to just call `_make_subclass` and it seems to work for now. Maybe something that will need to be revisited later.
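A minimal sketch of that simplified constructor (the class name is illustrative, not the actual FakeTensor code): route construction through `_make_subclass` unconditionally instead of branching between `super().__new__` and `_make_subclass` on `requires_grad`.

```python
import torch

class WrapperSketch(torch.Tensor):
    # Illustrative subclass: always construct via _make_subclass,
    # inheriting requires_grad from the wrapped tensor unless the
    # caller overrides it explicitly.
    @staticmethod
    def __new__(cls, elem, *, requires_grad=None):
        rg = elem.requires_grad if requires_grad is None else requires_grad
        return torch.Tensor._make_subclass(cls, elem, rg)

w = WrapperSketch(torch.randn(3))
g = WrapperSketch(torch.randn(3), requires_grad=True)
```

This keeps the wrapper's `requires_grad` on the subclass instance itself, which is the behavior the quoted discussion of #73727 was concerned with.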

@eellison eellison requested a review from ezyang May 24, 2022 00:00
@eellison eellison changed the title Extend __new__ on subclasses to set custom_device and custom_strides Change FakeTensor constructor to use _make_subclass May 24, 2022
Collaborator

@Chillee Chillee left a comment


LGTM.

@eellison
Contributor Author

@eellison has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@samdow
Contributor

samdow commented May 24, 2022

but I updated to just call _make_subclass and it seems to work for now. Maybe something that will need to be revisited later

Yeah, we also saw later that the PR mentioned (which is also mentioned in the subclass zoo) is no longer relevant and was closed without being merged. To @eellison's point, it might be worth updating the subclass zoo so that it isn't too out of date, especially if someone else comes in and tries to extend BaseTensor (or use BaseTensor)

Add an overload for `new` to handle setting the custom strides/custom device property.

Differential Revision: [D36618466](https://our.internmc.facebook.com/intern/diff/D36618466)

[ghstack-poisoned]
@eellison
Contributor Author

@eellison has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Elias Ellison added 3 commits May 24, 2022 15:59
@eellison eellison mentioned this pull request May 31, 2022
@eellison
Contributor Author

@pytorchbot merge this please

@github-actions
Contributor

Hey @eellison.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Jun 1, 2022
…77970)

Summary:
Pull Request resolved: #77970

Approved by: https://github.com/Chillee

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/98e08169867e1d3fcf2a660bbe37f65db0e3528a

Reviewed By: seemethere

Differential Revision: D36784759

Pulled By: seemethere

fbshipit-source-id: 71b177873c941c25b98795e80d00380d9360f5b3
@facebook-github-bot facebook-github-bot deleted the gh/eellison/293/head branch June 4, 2022 14:17
