Conversation
awni
left a comment
This is great, thank you!
I left a few minor comments. Also, I'm wondering why we don't do type promotion in any of these ops (including tensordot). If you just route everything to matmul / other pre-existing binary ops, they will handle the type promotion, so you could remove the type checking.
mlx/ops.cpp
Outdated
```cpp
if (a.ndim() != 1 || b.ndim() != 1) {
  throw std::invalid_argument("[outer] a and b must be 1-dimensional.");
}
```
Let's follow numpy and flatten if not 1d? https://numpy.org/doc/stable/reference/generated/numpy.outer.html
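The numpy behavior being suggested can be sketched in plain C++. This is a hypothetical illustration (`outer_flat` is not an MLX function; `std::vector` stands in for MLX arrays): because both inputs are flattened first, the result always has `a.size() * b.size()` elements, whatever the original shapes were.

```cpp
#include <vector>

// Sketch of numpy.outer semantics: both inputs are treated as
// flattened 1-D sequences, so the output has a.size() * b.size()
// elements in row-major order regardless of the inputs' shapes.
std::vector<double> outer_flat(const std::vector<double>& a,
                               const std::vector<double>& b) {
  std::vector<double> out;
  out.reserve(a.size() * b.size());
  for (double x : a) {     // x runs over the flattened a
    for (double y : b) {   // y runs over the flattened b
      out.push_back(x * y);
    }
  }
  return out;
}
```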
mlx/ops.cpp
Outdated
```cpp
if (a.dtype() != b.dtype()) {
  throw std::invalid_argument("[inner] a and b must have the same dtype.");
}
```
Must they have the same type? Why not promote them? (Also, now that I see it, why don't we promote in tensordot?)
Yeah, and since the ops being called handle this, we can just remove the checks altogether and pass it on. Will remove from tensordot as well.
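The point being made here, that the delegated-to binary op can own the promotion so callers need no dtype checks, can be sketched generically in plain C++. This is an illustrative analogy, not MLX code: `std::common_type` stands in for MLX's dtype promotion, and `multiply_elem` is a hypothetical name.

```cpp
#include <type_traits>

// Sketch: the binary op itself promotes its result type
// (std::common_type here plays the role of dtype promotion),
// so wrappers like inner/outer can forward mixed-type inputs
// without any explicit same-dtype checks of their own.
template <typename A, typename B>
auto multiply_elem(A a, B b) -> std::common_type_t<A, B> {
  using R = std::common_type_t<A, B>;
  return static_cast<R>(a) * static_cast<R>(b);
}
```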
awni
left a comment
Looks great!! One minor comment. Please address it, then we can merge!
Thanks for adding this.
mlx/ops.cpp
Outdated
```cpp
if (a.ndim() > 0) {
  t_a = flatten(a, s);
}
return multiply(reshape(t_a, {t_a.shape(0), 1}, s), flatten(b, s), s);
```
Suggested change:

```diff
- if (a.ndim() > 0) {
-   t_a = flatten(a, s);
- }
- return multiply(reshape(t_a, {t_a.shape(0), 1}, s), flatten(b, s), s);
+ return multiply(reshape(a, {a.size(), 1}, s), flatten(b, s), s);
```
Had to do a static_cast for this to compile.
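The cast mentioned above is presumably because `size()` returns an unsigned `size_t` while shapes are vectors of `int`, and C++ forbids narrowing conversions in brace-initialization. A minimal self-contained sketch of that situation (`column_shape` is a hypothetical helper, not MLX API):

```cpp
#include <cstddef>
#include <vector>

// Sketch: building a {n, 1} column shape from a size_t length.
// Brace-initializing std::vector<int> directly with a size_t is a
// narrowing-conversion error, hence the explicit static_cast.
std::vector<int> column_shape(std::size_t n) {
  // std::vector<int> shape{n, 1};  // would not compile: narrowing
  return {static_cast<int>(n), 1};
}
```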
Proposed changes
Adds inner / outer op
Checklist
Put an `x` in the boxes that apply.

- [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes