Conversation
```cpp
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
XLANativeFunctions::_embedding_bag_forward_only(
    const at::Tensor& weight, const at::Tensor& indices,
    const at::Tensor& offsets, bool scale_grad_by_freq, int64_t mode,
    bool sparse, const c10::optional<at::Tensor>& per_sample_weights,
    bool include_last_offset, int64_t padding_idx) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
  if (mode == 1 || scale_grad_by_freq || sparse || padding_idx != -1) {
    return at::native::call_fallback_fn<
        &xla_cpu_fallback,
        ATEN_OP(_embedding_bag_forward_only)>::call(weight, indices, offsets,
                                                    scale_grad_by_freq, mode,
                                                    sparse, per_sample_weights,
                                                    include_last_offset,
                                                    padding_idx);
  }
```
@bhavya01 Is there a reason only _embedding_bag_forward_only was lowered, instead of also _embedding_bag? And what about the fallback condition: is there a specific reason we are not lowering those cases, too?
If I recall correctly, the code for the rest of _embedding_bag was overly complicated, so we deemed it out of scope for this PR. @bhavya01, please correct me if I'm wrong.
That's right! We still need to lower _embedding_bag.
Hey @ysiraichi, would lowering _embedding_bag entirely be something that you need?
Well, it would be nice to have that. I won't be working on it right now, since I have higher-priority things to do. Anyway, I have this draft branch that can help us lower and maintain composite operations. It mainly does 2 things:
- Allow us to write lowerings for composite operations in Python
- Check for implemented decompositions (PyTorch + PyTorch/XLA) whenever we hit the fallback function; if one is found, use it to decompose the operation into (possibly already lowered) operations
  - At the moment, PyTorch decompositions are only used in dynamo
  - This would allow us to use PyTorch decompositions on non-dynamo experiments
  - It would also allow us to use PyTorch/XLA-specific decompositions on both dynamo and non-dynamo experiments
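The fallback-time decomposition lookup described above could work roughly as follows. This is a simplified, self-contained sketch: the registry, the `register_decomposition` decorator, and the `dispatch` function are illustrative names of ours, not the actual API of the draft branch, and plain numbers stand in for tensors.

```python
# Illustrative registry mapping an op name to a Python decomposition.
# In the scheme described above, this would be populated from PyTorch's
# decomposition table plus PyTorch/XLA-specific entries.
DECOMPOSITIONS = {}

def register_decomposition(op_name):
    """Register a Python decomposition for the given op name."""
    def wrap(fn):
        DECOMPOSITIONS[op_name] = fn
        return fn
    return wrap

@register_decomposition("aten::addcmul")
def addcmul_decomp(self, t1, t2, value=1):
    # addcmul decomposes into elementwise ops that already have lowerings.
    return self + value * t1 * t2

def dispatch(op_name, cpu_fallback, *args, **kwargs):
    # On a missing lowering, try a registered decomposition before
    # paying for the CPU fallback round-trip.
    decomp = DECOMPOSITIONS.get(op_name)
    if decomp is not None:
        return decomp(*args, **kwargs)
    return cpu_fallback(*args, **kwargs)

# addcmul(1, 2, 3) = 1 + 1 * 2 * 3, computed via the decomposition:
assert dispatch("aten::addcmul", lambda *a: None, 1.0, 2.0, 3.0) == 7.0
# An op with no registered decomposition still hits the fallback:
assert dispatch("aten::missing", lambda x: x * 2, 5) == 10
```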
I won't be working on this PR for a while, so if anyone wants to take it over, I don't mind. If it gets merged, we would probably have an easier time lowering operations.