
Add inner / outer op #348

Merged
awni merged 8 commits into ml-explore:main from dc-dc-dc:outer
Jan 7, 2024

Conversation

Contributor

@dc-dc-dc dc-dc-dc commented Jan 3, 2024

Proposed changes

Adds inner / outer op

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Member

@awni awni left a comment


This is great, thank you!

I left a few minor comments. Also, I'm wondering why we don't do type promotion in any of these ops (including tensordot). If you just route things to matmul or other pre-existing binary ops, they will handle the type promotion, so you could remove the type checking.
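To illustrate the point about delegating to pre-existing binary ops, here is a small sketch using NumPy as an analogy for the promotion behavior being discussed (the exact mlx promotion table is an assumption here):

```python
import numpy as np

# Binary ops like multiply promote mixed dtypes on their own, so an op
# built on top of them needs no explicit dtype check of its own.
a = np.ones(3, dtype=np.float16)
b = np.ones(3, dtype=np.float32)
out = np.multiply(a, b)
print(out.dtype)  # float32 -- promotion happened inside multiply
```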

mlx/ops.cpp Outdated
Comment on lines +2901 to +2911
if (a.ndim() != 1 || b.ndim() != 1) {
throw std::invalid_argument("[outer] a and b must be 1-dimensional.");
}
Member


Let's follow numpy and flatten if not 1d? https://numpy.org/doc/stable/reference/generated/numpy.outer.html
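For reference, `numpy.outer` ravels non-1-D inputs rather than rejecting them; a quick check of that behavior:

```python
import numpy as np

# np.outer flattens both inputs first, so 2-D arguments are accepted
# and the result has shape (a.size, b.size).
a = np.arange(4).reshape(2, 2)
b = np.ones((2, 2))
print(np.outer(a, b).shape)  # (4, 4)
```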

mlx/ops.cpp Outdated
Comment on lines +2909 to +2919
if (a.dtype() != b.dtype()) {
throw std::invalid_argument("[inner] a and b must have the same dtype.");
}
Member


Must they have the same type? Why not promote them? (Also, now that I see it, why don't we promote in tensordot?)

Contributor Author


Yeah, and since the ops being called handle this, we can just remove the checks altogether and pass through. Will remove from tensordot as well.
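NumPy's `inner`, for comparison, promotes mixed dtypes instead of raising (shown here as an analogy; that mlx follows the same pattern once the checks are removed is an assumption):

```python
import numpy as np

# Mixed int16 / float32 inputs are promoted rather than rejected.
a = np.array([1, 2, 3], dtype=np.int16)
b = np.ones(3, dtype=np.float32)
out = np.inner(a, b)
print(out.dtype)  # float32
```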

Member

@awni awni left a comment


Looks great!! Minor comment. Please address, then we can merge!

Thanks for adding this.

mlx/ops.cpp Outdated
Comment on lines +2906 to +2909
if (a.ndim() > 0) {
t_a = flatten(a, s);
}
return multiply(reshape(t_a, {t_a.shape(0), 1}, s), flatten(b, s), s);
Member


Suggested change
if (a.ndim() > 0) {
t_a = flatten(a, s);
}
return multiply(reshape(t_a, {t_a.shape(0), 1}, s), flatten(b, s), s);
return multiply(reshape(a, {a.size(), 1}, s), flatten(b, s), s);
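A quick sanity check of the suggested one-liner's semantics, sketched in NumPy (assuming the mlx `reshape`/`flatten`/`multiply` calls map one-to-one to their NumPy counterparts):

```python
import numpy as np

# outer(a, b) == multiply(reshape(a, (a.size, 1)), ravel(b)) via broadcasting
a = np.arange(3)
b = np.arange(4)
out = a.reshape(a.size, 1) * b.ravel()
assert np.array_equal(out, np.outer(a, b))
print(out.shape)  # (3, 4)
```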

Contributor Author


Had to do a static cast for this to compile.

Member

@awni awni left a comment


Thanks!! Looks great!

@awni awni merged commit 449b437 into ml-explore:main Jan 7, 2024
@dc-dc-dc dc-dc-dc deleted the outer branch January 7, 2024 17:09