
[wip][lazy] A Lazy Tensor Implementation #25753

Closed
bwasti wants to merge 322 commits into pytorch:master from bwasti:lazy_tensor

Conversation

@bwasti
Contributor

@bwasti bwasti commented Sep 6, 2019

Lazy tensors create a lazy execution environment for PyTorch. Instead of immediately running operators, lazy tensors defer execution until the last possible moment (or when the user converts the tensor to eager).

By deferring execution, the backend is able to transparently view a larger chunk of the program. It can then optimize the execution plan or compile better code than what is possible in an immediate execution setting.

Laziness may hurt the debugging experience, as users will need to disable the feature to get the exact lines of failure in their Python code.

Design

The general design requires hijacking the dispatch mechanism for all operators, shimming these operators to return Lazy Tensors and storing a record of their invocation in the output Tensor. When we need to execute the operators, we lower the recorded "history" stored within the Tensor to PyTorch's JIT IR and run that. Autograd is set up in the same way as TorchScript.
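The recording scheme described above can be sketched in a few lines of Python (all names here are illustrative stand-ins, not the PR's actual API): each operator call returns a lazy tensor that stores the op and its inputs, and converting to eager walks the recorded "history" and executes it.

```python
# Hypothetical sketch of the lazy-recording design, not the PR's real code.
class LazyTensor:
    def __init__(self, op=None, inputs=(), value=None):
        self.op = op          # recorded operator name; None for leaf tensors
        self.inputs = inputs  # recorded input LazyTensors
        self.value = value    # concrete payload, only for leaves

    def to_eager(self):
        # Leaves (the "schema is null" case discussed below) just return
        # their payload; interior nodes recursively execute their inputs.
        if self.op is None:
            return self.value
        args = [t.to_eager() for t in self.inputs]
        return {"add": lambda a, b: a + b,
                "mul": lambda a, b: a * b}[self.op](*args)

def add(a, b): return LazyTensor("add", (a, b))
def mul(a, b): return LazyTensor("mul", (a, b))

x = LazyTensor(value=2)
y = LazyTensor(value=3)
z = mul(add(x, y), y)   # nothing executes yet; only the graph is recorded
print(z.to_eager())     # the recorded history is lowered and run here -> 15
```

In the actual PR the recorded history is lowered to PyTorch's JIT IR rather than interpreted directly, but the shape of the idea is the same.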

Status

This diff is not in a landable state, but it provides a fully working lazy tensor implementation for a large subset of PyTorch operators. I'm putting it up for initial design comments and to motivate the conversation surrounding dispatch with concrete code and needs.

The change is blocked by the development of a stable multi-dispatch mechanism.

The change is also missing a couple of things that will need to exist before landing:

  • non-CPU laziness (not too hard to add)
  • thorough testing (full models etc)

cc @nikitaved @pearu @cpuhrsch @IvanYashchuk @yf225 @glaringlee @zou3519 @ezyang @bhosmer @smessmer @ljk53 @bdhirsh @albanD @mruberry @jbschlosser @ngimel

@bwasti bwasti requested a review from apaszke as a code owner September 6, 2019 05:27
@pytorchbot pytorchbot added oncall: jit Add this issue/PR to JIT oncall triage queue module: internals Related to internal abstractions in c10 and ATen module: operators labels Sep 6, 2019
@vadimkantorov
Contributor

Would this enable custom user kernel codegen plug-ins? Like KeOps?

@soumith
Collaborator

soumith commented Sep 6, 2019

@vadimkantorov it could if one writes a KeOps-pytorch-jit plugin similar to how https://github.com/pytorch/tvm is a jit plugin

Contributor

Repeating from what was discussed in realtime: these changes will conflict with the post-multi dispatch implementation at

void* getOp(TensorTypeId tid) const {

Contributor

The recommendation:

Currently, the dispatcher logic in #25653 looks like:

op = dispatch_table[op_id][valid_keys[0]]
if not op:
  # fallback logic...
return op(args)

We adjust fallback handling to now be:

op = dispatch_table[op_id][valid_keys[0]]
if not op and default_dispatch_table[valid_keys[0]]:
  op = BoxedOpFunctor<FuncType>(default_dispatch_table[valid_keys[0]])
if not op:
  # fallback logic...
return op(args)

There are two pieces:

  1. The default dispatch table contains registrations of functions default_dispatch(Operator op, Stack&& ivalues) which get passed in what operator they need to process, and arguments in boxed form. The expectation is a backend like lazy would write this function once to work polymorphically over everything. Operator is some type that gives you the schema string, and also a way to redispatch on the operator.
  2. The boxed op functor BoxedOpFunctor unpacks FuncType into its boxed form and then calls the default dispatch function. It is templated here, as is done in this PR.
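As a rough model of the two pieces above (the table names, keys, and op ids here are illustrative stand-ins, not the real dispatcher types), the lookup plus boxed backend fallback might behave like:

```python
# Illustrative model of per-op lookup with a per-key boxed fallback.
dispatch_table = {"aten::add": {"CPU": lambda stack: stack[0] + stack[1]}}

def lazy_default_dispatch(op_id, stack):
    # One polymorphic boxed function for a whole backend: it receives the
    # operator it must process plus boxed arguments; here it just records.
    return ("deferred", op_id, tuple(stack))

default_dispatch_table = {"Lazy": lazy_default_dispatch}

def call(op_id, key, stack):
    op = dispatch_table.get(op_id, {}).get(key)
    if op is None and key in default_dispatch_table:
        # The BoxedOpFunctor role: close over op_id so the default kernel
        # knows which operator it is handling.
        fallback = default_dispatch_table[key]
        op = lambda s: fallback(op_id, s)
    if op is None:
        raise RuntimeError(f"no kernel for {op_id} on {key}")  # fallback logic
    return op(stack)

print(call("aten::add", "CPU", [1, 2]))   # -> 3
print(call("aten::add", "Lazy", [1, 2]))  # -> ('deferred', 'aten::add', (1, 2))
```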

Contributor

How much of this is actually changed, versus just code motion? (Can we split this off into a diff before this in the stack?)

Contributor

Furthermore, I would not expect our executors to change at all with this patch. Why did they need to change? This part of the PR is very hard to review because it mixes a ton of code motion with some changes.

Contributor Author

yea, it'll be split out in any later (non-WIP) diffs

Contributor

I know in earlier conversation, @zdevito has objected to API creep here. I personally think these functions are probably useful as a debugging mechanism, but we probably shouldn't make them methods (and give them suitably private sounding names).

Contributor

This doesn't make sense to me, there is already a device_opt_ on the parent struct...

Contributor

Ditto here

Contributor

It looks like in some cases, schema_ == nullptr. When this is the case, what does the lazy tensor semantically represent? It kind of sounds like you have an invariant like "if schema is null, then inps_.size() == 1". Is this true?

Contributor

@ezyang ezyang Sep 6, 2019

Based on further reading, I think this invariant holds? It would be clearer if there were a more semantic function we use to test for this case, rather than manually always poking to check if schema_ is null; maybe something like:

optional<Tensor> constant_opt();

Contributor

@ezyang ezyang Sep 6, 2019

I know in this particular case it is safe, but unchecked static casts always make me nervous (probably because I'm afraid someone will copy-paste this code to a place where the precondition that a tensor is lazy is NOT satisfied).

Contributor

This looks highly questionable. Can't we just impose the invariant that the result of to_eager() is always a Variable? When can it not be a Variable?

Contributor Author

I empirically found this to fail; I'll dig into what is going on here.

Contributor

Inferring that the dtype/device of the result is the same as the first argument will not work for many operations. (I guess lazy tensor is going to need this information, somehow, too.)

Contributor

Err... a std::map? Really?

Contributor

(is this for determinism?)

Contributor Author

will swap to unordered, oversight on my part

Contributor

In this struct, I don't see any place for storing materialized results after you run to_eager on a graph. This seems to imply to me that if I do something like:

x = ... some lazy computation ...
x.to_eager()
x.to_eager()

I will end up redoing the entirety of lazy computation twice. Is that true?

Contributor Author

yep, that's true
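The caching being discussed here can be sketched as follows (names are hypothetical, not the PR's API): store the materialized result on the first to_eager() so a second call returns the cached value instead of re-running the recorded graph.

```python
# Hypothetical memoization sketch for repeated to_eager() calls.
class LazyNode:
    def __init__(self, fn):
        self.fn = fn        # the deferred computation (stands in for a graph)
        self.cached = None  # materialized result, filled on first run
        self.runs = 0       # execution counter, to demonstrate the caching

    def to_eager(self):
        if self.cached is None:
            self.runs += 1
            self.cached = self.fn()
        return self.cached

x = LazyNode(lambda: sum(range(1000)))
x.to_eager()
x.to_eager()
print(x.runs)  # -> 1: the computation ran once, not twice
```

One design wrinkle this glosses over: a cache like this must be invalidated if any input to the graph is mutated in place.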

Contributor

Didn't you add a t.is_lazy() specifically to not have to type all this out? XD

Contributor

I've been thinking about whether, in the "type set" world, lazy tensors should also have dispatch keys for, e.g., variable and cpu/cuda tensor.

On the side of yes: tensor type sets are currently used for instanceof tests, and in transparent lazy usage, you want lazy cpu tensors to report as cpu tensors

On the side of no: tensor type sets are used to make claims about the structure of the TensorImpl; but a lazy cpu tensor "is not a" cpu tensor.

@zdevito Maybe this means we really should organize one of these notions some other way on Tensor.

Contributor

Dead variable?

Contributor

Question for the design of boxed dispatch key fallback: in your prototype, you save const char* schema. However, tracing the usages of it, it seems that what you really care about in the end is the symbol corresponding to the function, so you can make a JIT op for it. (1) Is this true? (2) Are there other uses of, e.g., FunctionSchema or similar that you anticipate needing?

Contributor

No actual change here right?

Contributor

@apaszke apaszke left a comment

This is quite interesting. I've been thinking quite a bit about how the JIT and autograd should relate in the future, and this seems like a step in a direction that I've been considering. There's one big issue though: torch::jit::Graph and all other JIT abstractions are not designed to be manipulated in the hot path. There's a whole lot of pointer chasing that makes it very convenient to implement many algorithms, but would completely kill our performance. Are we planning to reevaluate that implementation and possibly change how we represent things?

@ezyang
Contributor

ezyang commented Sep 10, 2019

Are we planning to reevaluate that implementation and possibly change how we represent things?

Yes, an alternate IR representation for speed is on the table. JIT IR is being used here because it's easy, and there are plenty of other problems to solve even without taking on IR design to start.

Contributor

@zdevito zdevito left a comment

I think this needs more documentation and explanation before I can give an accurate review. In particular I'd like to see:

  • An overview of the lifetime of a lazy tensor, and which code pieces introduce laziness and remove it.
  • Documentation of the machinery: there are things like make_lazy and its counterpart to make things eager. What are they? Can we put this machinery in its own files so it can be understood to be distinct from the standard pieces of the ATen core?
  • An explanation of why details of the graph executor got exposed. I can't find the code that uses them, and I do not believe it should have had to change.


Contributor

These are implementation details of the executor. They should not be exposed.

Contributor Author

I just want to use DifferentiableGraphOp, which seems to pull all of this in

Contributor

But why do you need that?

Contributor

I can't follow what is going on here. Can you document it? Why is this working with raw char* schema strings?

@jeffreyksmithjr
Contributor

Request/suggestion: we could use some sort of lazy tag on GH, either as a module:lazy variant or just a plain old lazy label.

This is required for boxed backend fallback kernels (e.g. lazy, AMP) because they need to know which op was actually called.

Differential Revision: [D18282746](https://our.internmc.facebook.com/intern/diff/D18282746/)

[ghstack-poisoned]
This makes for a nicer API, especially in backend fallback kernels, which get an OperatorHandle instance and can directly call these methods on it.

Differential Revision: [D18357424](https://our.internmc.facebook.com/intern/diff/D18357424/)

[ghstack-poisoned]
…nel API"

This argument is needed by boxing wrappers so they're able to get a pointer to the corresponding unboxed kernel and call into it.
But if a kernel is registered in a boxed way, we don't need it and should hide this from the API.
This is especially needed for the backend fallback API where users would only be left wondering why this argument is there and what it does.
Also, hiding it allows us to potentially totally remove it in a future refactoring if we find some way to do so.

Differential Revision: [D18361991](https://our.internmc.facebook.com/intern/diff/D18361991/)

[ghstack-poisoned]
Remove callUnboxedOnly() and instead use metaprogramming to figure out if an operator can use a boxed fallback or not.
This enables boxed fallback for ops in native_functions.yaml even if they don't have `use_c10_dispatcher: full` set, as long as they're in the range of supported types.

Differential Revision: [D18462653](https://our.internmc.facebook.com/intern/diff/D18462653/)

[ghstack-poisoned]
This PR re-introduces backend_fallback_test.cpp, which was previously called boxed_fallback_test.cpp and showed how to use the backend fallback API.

Differential Revision: [D18462654](https://our.internmc.facebook.com/intern/diff/D18462654/)

[ghstack-poisoned]
@bwasti
Contributor Author

bwasti commented Nov 16, 2019

rebased onto unblocking PR:
#29682

we're back in business!

smessmer and others added 2 commits November 16, 2019 16:23
@mcarilli
Collaborator

mcarilli commented Nov 18, 2019

@bwasti I'm trying to implement autocasting (automatic casting of inputs to the most performant precision for a given op). We can talk on slack if you want more context but I'm also interposing a layer on the dispatch path (for a new "AutocastTensorId") that casts inputs to the desired type then runs the op. There are a few points I'm curious about, loosely organized in two topics:

1. Will autocasting play well with LazyTensor?

I think autocasting is compatible with your approach, as long as AutocastTensorId is given higher priority than LazyTensorId. Any casts to half or float will redispatch to subsequent dispatch layers, including the LazyTensor layer, so the graph that LazyTensor records will include all the casts, and jit will have the freedom to fuse them as it sees fit.

Currently, I give AmpTensorId higher priority than VariableTensorId, to ensure that inputs saved for backward by a given op are post-cast (saving inputs after a cast to FP16 provides significant memory savings). It appears you're giving LazyTensorId lower priority than VariableTensorId (i.e., routing through the LazyTensor hijack occurs after VariableTensorId's autograd history recording). So as-is, AmpTensorId will have higher priority than LazyTensorId, and I think autocasting will work with LazyTensors, but I don't want to accidentally paint myself into a corner. Please let me know if my approach might break something.

2. What's your planned API? Can I draw inspiration from it?

I'm also trying to figure out the python API for autocasting. Any information I can gather on other APIs that alter/control how ops are dispatched is valuable. Right now, we're leaning towards exposing autocasting via a context manager that sets AmpTensorId to be globally included, but not baking AmpTensorId into the tensors themselves. Based on your tests, it seems you're pursuing the other approach: laziness is carried as a tensor attribute, and the user-facing control points are tensor.to_lazy() and tensor.to_eager(). Is that correct? Is it complete, or are there other APIs you're considering? Also, what happens with c = a + b where a is a LazyTensor and b is not? My guess based on the usual dispatch logic is that it would route through LazyTensor and c would be a lazy tensor as well, since the dispatch ID is computed as a superset, but I'm wondering if there's any special-casing you considered.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@bwasti has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@bwasti
Contributor Author

bwasti commented Nov 19, 2019

  1. Will autocasting play well with LazyTensor?
    I give AmpTensorId higher priority than VariableTensorId

Yep, LazyTensor is conceptually just a hack to inject a lazy execution model into PyTorch. It can really function at any level, but below autograd makes the most sense imo. Seems like it'll play well with automatic casting :)

  1. What's your planned API? Can I draw inspiration from it?
    tensor.to_lazy() and tensor.to_eager()

Yep, that's pretty much it. Ideally there wouldn't be any user interaction at the tensor level, so I don't think I'll be adding anything to the API.

c = a + b where a is a LazyTensor and b is not

I'm not sure this is well handled, but ideally b would be implicitly "cast" to LazyTensor.

@mcarilli
Collaborator

mcarilli commented Nov 20, 2019

@bwasti Thanks! To elaborate a bit on what I said earlier, the current default behavior is to compute an op's TensorTypeSet by taking the superset (inclusive or) of the TensorTypeSets of any participating Tensors, then folding in (inclusive or) the thread-local tensor-independent included TensorTypeSet, then subtracting out the thread-local excluded TensorTypeSet. Maybe I'm repeating stuff you already know, but I expect c = a + b (where a is a LazyTensor and b is not) will route through the LazyTensor hijacking and produce c as another LazyTensor. I was wondering if you had additional code to explicitly alter this default behavior.
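The default TensorTypeSet computation described above can be modeled with plain set arithmetic (the key names here are stand-ins for the real dispatch key enum):

```python
# Illustrative model: superset (inclusive or) of all input tensors' key
# sets, fold in the thread-local included set, subtract the excluded set.
def compute_dispatch_set(tensor_sets, tls_included, tls_excluded):
    s = set()
    for ts in tensor_sets:
        s |= ts              # superset of participating tensors' keys
    s |= tls_included        # fold in thread-local included keys
    s -= tls_excluded        # subtract thread-local excluded keys
    return s

a = {"Lazy", "Variable", "CPU"}  # a is a LazyTensor
b = {"Variable", "CPU"}          # b is a plain tensor
keys = compute_dispatch_set([a, b], set(), set())
print("Lazy" in keys)  # -> True: c = a + b routes through the lazy layer
```

Under this default behavior, c = a + b picks up the Lazy key from a alone, which is why c comes out lazy without any special-casing.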

The to_lazy()/to_eager() interface shows us an interesting alternative to a context manager. We could mimic that with to_amp() to assign autocastability, and explicit calls to float() or half() would remove autocastability.

In your case, there are a couple details I'm additionally curious about:
If I'm reading the code correctly, a.to_lazy() stashes a as an input to the graph. Does it also deep-copy the memory as it does so? And is a itself altered by a.to_lazy(), or does a.to_lazy() simply return a new "Tensor" (lightweight handle without any actual associated memory) that can be used to build the graph further?

Finally, what's the default behavior of print(a) or a.item() if a is a lazy tensor? Does that automatically trigger a to_eager(), or will it error and request that you call to_eager manually?

Comment on lines -50 to -53
if (!tensors.empty()) {
TORCH_INTERNAL_ASSERT(
torch::autograd::compute_requires_grad(tensors),
"Received tensors do not require grad, addRecvRpcBackward should not be called");
Contributor

Wondering if these changes in distributed are intended?

std::vector<char>(pickledPythonUDF.begin(), pickledPythonUDF.end()),
tensors);
return sendMessageWithAutograd(
agent, dst, std::move(*pythonCall).toMessage());
Contributor

same as above

@bwasti
Contributor Author

bwasti commented Dec 3, 2019

due to an internal build requirement -- LazyTensor impl has moved to this PR, which is exported from internal code:

#30674


Labels

cla signed module: cpp Related to C++ API module: cpp-extensions Related to torch.utils.cpp_extension module: cuda Related to torch.cuda, and CUDA support in general module: internals Related to internal abstractions in c10 and ATen module: lazy module: nn Related to torch.nn module: sparse Related to torch.sparse oncall: jit Add this issue/PR to JIT oncall triage queue
