Conversation
```cpp
XLA_FN_TRACK(3);
const auto name = c10::toString(op.operator_name());

// Manually applying the XLA_COUNTER macro.
```
So, after staring at some failing C++ tests for a while, I finally realized that some tests were failing because of my usage of the XLA_COUNTER macro here :(
It's defined in the XLA repo here, and it defines a local static counter that's unique to the call site rather than to the name you pass in. That means it silently does the wrong thing if you call the macro with different operator names from the same piece of source code: if I call xla_cpu_fallback() once with add and then with mul, the counter for add gets incremented twice.
Maybe it's worth a patch to XLA to add a different version of the macro? For now, I just hardcoded what the macro does here, with a global map of counters, one for every op that the CPU fallback is called with.
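For reference, a minimal sketch of that hand-rolled approach: key the counter on the runtime operator name instead of relying on a per-call-site static. The map and function names here are illustrative, and the xla_client metrics API usage is an assumption, not the exact code in this PR.

```cpp
#include <string>
#include <unordered_map>

#include "tensorflow/compiler/xla/xla_client/metrics.h"

// Sketch: replicate what XLA_COUNTER does, but look the counter up by the
// runtime operator name rather than caching one counter per call site.
std::unordered_map<std::string, ::xla::metrics::Counter*> _fallback_counters;

void IncrementFallbackCounter(const std::string& name) {
  auto it = _fallback_counters.find(name);
  if (it == _fallback_counters.end()) {
    // Leak the counter deliberately, matching XLA_COUNTER's
    // static-lifetime behavior.
    it = _fallback_counters.emplace(name, new ::xla::metrics::Counter(name))
             .first;
  }
  it->second->AddValue(1);
}
```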
Nice catch! XLA_COUNTER is designed to be used with different names; I think it is fine to leave it as it is.
JackCaoG left a comment:
Mostly LGTM, some minor comments.
torch_xla/csrc/aten_cpu_fallback.cpp (outdated)

```cpp
namespace torch_xla {

std::unordered_map<std::string, ::xla::metrics::Counter*>
```
I think we should make this map static since we don't expect code outside of this file to access it. wdyt?
good catch, static sounds good
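i.e., something like the following (the variable name is illustrative):

```cpp
// static gives the map internal linkage; only this file touches it.
static std::unordered_map<std::string, ::xla::metrics::Counter*>
    _cpu_fallback_counters;
```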
BTW, reading the boxing doc you shared, I have a question: are all pytorch/xla ops boxed kernels?
Nope, all of the pytorch/xla kernels (lowerings) are unboxed kernels, since each kernel is specialized for a specific operator. The main advantage of a boxed kernel is that you can write it once and it's supposed to work for all operators. It does that by having a very specific schema: every boxed kernel takes a c10::OperatorHandle and a torch::jit::Stack* instead of typed arguments. This CPU fallback is actually one of the first main usages of a boxed fallback kernel, but we do have a couple of others: we have a boxed fallback for batching in-tree, and there are some ongoing features being developed that use boxed fallbacks: conjugation of complex tensors, and dispatching to Python.
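To make that schema concrete, here is a minimal sketch of the shape a boxed fallback kernel takes (the function name is illustrative; at::native::cpu_fallback is the core kernel this PR calls into):

```cpp
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/native/CPUFallback.h>

// A boxed kernel receives the operator handle plus a stack of boxed IValue
// arguments, so one definition can serve every operator.
void example_boxed_fallback(const c10::OperatorHandle& op,
                            torch::jit::Stack* stack) {
  // Op-agnostic logging can key off op.operator_name() here.
  // Delegate to the generic CPU fallback implemented in PyTorch core.
  at::native::cpu_fallback(op, stack);
}
```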
FYI, if https://github.com/pytorch/xla/pull/2936/files#diff-5e65c3c1d847191cb691d1874732e971f09fa1aad7a980a555c3b0504a5b6470R2454 merges first (seems like it might, since the pytorch PR is ready), you will need to fix the fallback call here.
```cpp
}
}

// Call the actual boxed CPU fallback.
```
It would be nice to note here which device the tensors/ivalues are on. (I assume it's still XLA?)
Yup, they should all be XLA (although technically, if we had a meta function that allowed mixed-device inputs, they could be mixed).
I could tack it onto TF_VLOG(3) << ivalue.toTensor().toString(); if you think that would be useful. Is TF_VLOG(3) the right macro to use for general-purpose XLA logging?
This PR updates pytorch/xla to use a boxed kernel to implement the CPU fallback, rather than relying on code generation (see the PyTorch wiki on boxing/unboxing).
Summary
In the corresponding PyTorch-side PR, I rewrote the CPU fallbacks for XLA to use a single boxed kernel, instead of code-generating individual CPU fallback kernels for every operator.
This lets us kill a bunch of codegen logic in PyTorch and simplifies a lot of the remaining codegen, but it means that pytorch/XLA needs to do slightly more work to get access to CPU fallback kernels. I added some convenience helper functions in PyTorch core to keep the amount of extra work minimal.
Registering the CPU Fallback kernel
In torch_xla/csrc/aten_cpu_fallback.cpp, I added a boxed CPU fallback kernel that has XLA-specific logging, to preserve the same logging behavior that we had before. The kernel is just a function with signature void (const c10::OperatorHandle&, torch::jit::Stack*) that logs some information using XLA macros and then calls into the actual CPU fallback kernel that's implemented in PyTorch core (at::native::cpu_fallback). I then register that fallback to the dispatcher under the XLA key.
It technically would have been possible to have the codegen do all of that for you (the boxed kernel logging + dispatcher registration), but the logging is all XLA-specific and seems more reasonable to write directly in pytorch/XLA.
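For reference, a minimal sketch of that registration, assuming the kernel is named xla_cpu_fallback:

```cpp
#include <torch/library.h>

// Register the boxed kernel as the catch-all fallback for every operator
// under the XLA dispatch key ("_" means all operator namespaces).
TORCH_LIBRARY_IMPL(_, XLA, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&xla_cpu_fallback>());
}
```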
Calling the CPU Fallback
There are also a bunch of places where pytorch/XLA has to explicitly call the CPU fallback, depending on e.g. the op's input shapes. When each operator's CPU fallback was code-generated, we used to call the fallback like this:
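Roughly like the following sketch, where AtenXlaTypeDefault and the add overload are hypothetical examples of a codegen'd per-op fallback entry point, not the exact names:

```cpp
// Hypothetical example of the old, code-generated per-op fallback call.
return AtenXlaTypeDefault::add(self, other, alpha);
```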
Now, we need to call into the boxed kernel. I added a convenience helper function to make calling into it easier, which looks like this:
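A sketch of what the call looks like, assuming the helper is at::native::call_fallback_fn from the PyTorch-side PR; the operator (add.Tensor) and its arguments are just illustrative:

```cpp
#include <ATen/native/CPUFallback.h>

// Route add.Tensor through the boxed XLA CPU fallback via the new helper.
return at::native::call_fallback_fn<
    &xla_cpu_fallback, ATEN_OP2(add, Tensor)>::call(self, other, alpha);
```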
Where xla_cpu_fallback is the boxed fallback kernel with XLA-specific logging, and ATEN_OP2 is a new macro that provides the helper function with all of the metadata it needs to infer some extra information and call the boxed fallback.
Performance
It's also worth calling out perf, since the boxed fallback is a little slower than the unboxed, code-generated CPU fallback kernels. I put a more detailed analysis at the bottom of the description of pytorch/pytorch#58065, but the boxed fallback looks like it's on the order of 10-20% slower. I'm hoping that this isn't a huge concern, since we probably want to write XLA lowerings for operators that are perf-critical.