Add RoPE array-offset overload (prep for continuous batching)#305
Add RoPE array-offset overload (prep for continuous batching)#305davidkoski merged 8 commits intoml-explore:mainfrom
Conversation
Ahh, missed that one. Likely Ronan will do that a different way as he has some automated tools for keeping the API in sync. |
|
@davidkoski Ah, that explains the lack of newlines in MLX-c :) Happy to wait until it's exposed using the automated tools. Is there any ETA? |
Not sure, but I can ping Ronan. |
|
@davidkoski I think this is ready to go. |
Yeah, I think this can probably fit in after #319, which should pick up the new mlx/mlx-c dependencies. |
|
Sounds good! |
|
@ronaldmannak I think this can be rebased -- I just merged the v0.30.1 update. I forgot and already cut a tag but I can cut a new one once this is in. |
|
Also, take a look at: Should this have an array offset method? Will it apply to all of the RoPE variants in mlx-swift-lm? (See ml-explore/mlx-swift-lm#29) |
|
Hi David, sorry for the slow reply. I had to spend some time following the thread through the code :) It looks like Possible ways forward:
Either way, to minimize churn in the big mlx-swift-lm #29 change set, I’m inclined to update this PR after #29 has merged (unless you’d prefer to align it sooner). What direction do you prefer, and should this apply across all RoPE variants in mlx-swift-lm? |
I was thinking about that too -- I think maybe a separate protocol is a better idea. That way we don't force all layers to implement it (or end up in a situation where they can't). Timing-wise, that sounds good. I hope to get all that merged early next week. |
|
Early next week sounds good. I’ll wait for the changes to land. I’m not sure whether or not every implementation could or couldn't realistically support the offset: MLXArray overload, so introducing a separate protocol (for layers that can handle per-sequence offsets) seems like the safest option indeed. |
|
@davidkoski I've added the protocol. I think it's good to go |
|
hit a swift-format item |
|
@davidkoski yep sorry, fixed. |
davidkoski
left a comment
There was a problem hiding this comment.
Looks great, thank you!
Proposed changes
Add array offset support to
MLXFast.RoPEto match the Python MLX API.Dependencies
Motivation
The Python
mlx.core.fast.ropefunction acceptsoffsetas either an int or array.This enables several important use cases:
Changes
MLXFast.RoPE(..., offset: MLXArray, ...)overloadMLXNN.RoPE.callAsFunction(_:offset:)overloadExample
Checklist
Put an
xin the boxes that apply.pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes