Added MLXCustomFunction + MLXClosure#301
Conversation
| import MLX | ||
|
|
||
| @main | ||
| struct GridSampleExample { |
There was a problem hiding this comment.
These examples might make good tests -- both to show an example of use and to exercise it on the CI system
Source/MLX/MLXClosure.swift
Outdated
| import Foundation | ||
|
|
||
| // MARK: - MLXClosure wrapper | ||
| public struct MLXClosure { |
There was a problem hiding this comment.
I don't think we want MLXClosure as a public type. Look at how compile works (e.g. CompiledFunction to encapsulate the state) and a signature like this:
public func compile(
inputs: [any Updatable] = [], outputs: [any Updatable] = [], shapeless: Bool = false,
_ f: @escaping ([MLXArray]) -> [MLXArray]
) -> @Sendable ([MLXArray]) -> [MLXArray] {
I think this should be something similar -- it will hold the mlx_closure and look like a callable function.
There was a problem hiding this comment.
Agreed, migrated everything inside the MLXCustomFunction
Source/MLX/MLXCustomFunction.swift
Outdated
|
|
||
| // MARK: - Example: bridging helpers | ||
| // Convert a C mlx_vector_array pointer to a Swift array of MLXArray | ||
| func mlx_vector_array_to_swift(_ v: UnsafePointer<mlx_vector_array>?) -> [MLXArray] { |
There was a problem hiding this comment.
See mlx_vector_array_values
Source/MLX/MLXCustomFunction.swift
Outdated
| } | ||
|
|
||
| // Convert Swift array of MLXArray back to C mlx_vector_array | ||
| func swift_to_mlx_vector_array(_ arrays: [MLXArray]) -> mlx_vector_array { |
There was a problem hiding this comment.
See new_mlx_vector_array
Source/MLX/MLXCustomFunction.swift
Outdated
|
|
||
| // MARK: - Result Builder | ||
| @resultBuilder | ||
| public enum MLXCustomFunctionBuilder { |
There was a problem hiding this comment.
This looks like an interesting approach -- I wonder if it would make sense to make a function (something like customFunction()) that takes a @MLXCustomFunctionBuilder argument and encapsulate the use of this.
Something like:
let f = customFunction {
Forward { ... }
VJP { ... }
}I am not sure if that is the exact syntax but perhaps something like that is doable.
There was a problem hiding this comment.
Oh, that's great, just moved to use this syntax.
|
Looks really interesting! |
…ied customFunction
|
@davidkoski thank you for the feedback, just pushed the changes, let me know if anything else stands out. |
davidkoski
left a comment
There was a problem hiding this comment.
Very cool, thank you!
Proposed changes
Implemented custom function and VJP functionality in Swift.
Based on python implementation.
Examples implemented follow Python structure too.
Feedback and suggestions are welcome.
Addresses
Checklist
Put an
xin the boxes that apply.pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes