
Lower embedding bag forward only#6951

Merged
bhavya01 merged 2 commits into master from embeddingbag
Apr 22, 2024

Conversation

@bhavya01
Collaborator

No description provided.

@bhavya01 bhavya01 requested a review from wonjoo-wj April 22, 2024 18:57
@bhavya01 bhavya01 closed this Apr 22, 2024
@bhavya01 bhavya01 reopened this Apr 22, 2024
Collaborator

@wonjoo-wj wonjoo-wj left a comment


LGTM!

@bhavya01 bhavya01 merged commit 46919a4 into master Apr 22, 2024
@bhavya01 bhavya01 deleted the embeddingbag branch April 22, 2024 22:28
Comment on lines +1293 to +1308
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
XLANativeFunctions::_embedding_bag_forward_only(
const at::Tensor& weight, const at::Tensor& indices,
const at::Tensor& offsets, bool scale_grad_by_freq, int64_t mode,
bool sparse, const c10::optional<at::Tensor>& per_sample_weights,
bool include_last_offset, int64_t padding_idx) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if (mode == 1 || scale_grad_by_freq || sparse || padding_idx != -1) {
return at::native::call_fallback_fn<
&xla_cpu_fallback,
ATEN_OP(_embedding_bag_forward_only)>::call(weight, indices, offsets,
scale_grad_by_freq, mode,
sparse, per_sample_weights,
include_last_offset,
padding_idx);
}
Collaborator


@bhavya01 Is there a reason we lower only _embedding_bag_forward_only instead of also lowering _embedding_bag? And is there a specific reason we are not lowering the configurations covered by the fallback condition, too?
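For context, the guard in the C++ snippet above sends certain configurations to the CPU fallback. A minimal Python sketch of that predicate (the helper name is hypothetical; the integer mode encoding 0=sum, 1=mean, 2=max follows PyTorch's embedding-bag convention):

```python
def needs_cpu_fallback(mode: int,
                       scale_grad_by_freq: bool,
                       sparse: bool,
                       padding_idx: int) -> bool:
    """Hypothetical mirror of the guard in the C++ snippet above:
    mode == 1 || scale_grad_by_freq || sparse || padding_idx != -1
    """
    return mode == 1 or scale_grad_by_freq or sparse or padding_idx != -1

# 'sum' mode (0) with default options stays on the new XLA lowering:
print(needs_cpu_fallback(0, False, False, -1))  # False
# 'mean' mode (1) is routed to the CPU fallback:
print(needs_cpu_fallback(1, False, False, -1))  # True
```

So with this PR, only sum/max-mode bags without scale_grad_by_freq, sparse gradients, or a padding index use the new lowering.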

Collaborator


If I recall correctly, the lowering for the rest of _embedding_bag was overly complicated, so we deemed it out of scope for this PR. @bhavya01, please correct me if I'm wrong.

Collaborator Author


That's right! We still need to lower _embedding_bag.

Collaborator


Hey @ysiraichi, would lowering _embedding_bag entirely be something that you need?

Collaborator


Well, it would be nice to have that. I won't be working on it right now, since I have higher-priority things to do. Anyway, I have this draft branch that could help us lower and maintain composite operations. It mainly does 2 things:

  • Allow us to write lowerings for composite operations in Python
  • Check for implemented decompositions (PyTorch + PyTorch/XLA) whenever we hit the fallback function; if one is found, use it to decompose the operation into operations that may already be lowered
    • At the moment, PyTorch decompositions are only used in dynamo
    • This would allow us to use PyTorch decompositions in non-dynamo execution as well
    • It would also allow us to use PyTorch/XLA-specific decompositions in both dynamo and non-dynamo execution

I won't be working on this PR for a while, so if anyone wants to take over it, I don't mind. If this PR gets merged, we would probably have an easier time lowering operations.
