synchronize MLXNN code with python implementation #340

Merged
davidkoski merged 9 commits into main from gaps on Jan 22, 2026

@davidkoski
Collaborator

Proposed changes

It looks like some changes made on the Python side after the initial port were missed on the Swift side.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@davidkoski requested a review from awni on January 16, 2026 at 21:17
  _ f: @Sendable @escaping (MLXArray, MLXArray, MLXArray) -> MLXArray
  )
- -> (MLXArray, MLXArray, MLXArray) -> MLXArray
+ -> @Sendable (MLXArray, MLXArray, MLXArray) -> MLXArray
Collaborator Author

Noticed that @Sendable was missing on the return type while working on this; it now matches the other compile() implementations.
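As a rough illustration of the signature change, here is a minimal sketch of a wrapper that accepts a `@Sendable` closure and also returns one. The `passthrough` function and the use of `Double` instead of `MLXArray` are hypothetical stand-ins so the snippet compiles without MLX; it does not reproduce what `compile()` actually does.

```swift
// Hypothetical stand-in: takes a @Sendable closure and hands back a
// closure that is itself marked @Sendable, mirroring the shape of the
// updated signature. Double is used in place of MLXArray.
func passthrough(
    _ f: @Sendable @escaping (Double, Double, Double) -> Double
) -> @Sendable (Double, Double, Double) -> Double {
    // The returned closure only captures `f`, which is Sendable.
    return { a, b, c in f(a, b, c) }
}

let mulAdd = passthrough { a, b, c in a * b + c }
```

With the return type marked `@Sendable`, the result can be passed across concurrency domains without the compiler treating it as a non-sendable value.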

/// - Parameters:
/// - x: input array
/// - lambda: lambda value
public func softshrink(_ x: MLXArray, lambda: Float = 0.5) -> MLXArray {
Collaborator Author

Several missing activations
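For reference, soft shrinkage is usually defined as `x - lambda` when `x > lambda`, `x + lambda` when `x < -lambda`, and `0` otherwise. A scalar sketch of that definition follows; `softshrinkScalar` is a hypothetical helper for illustration, whereas the function in the diff operates on `MLXArray`.

```swift
// Scalar reference for soft shrinkage (illustrative only).
// Values inside [-lambda, lambda] collapse to zero; values outside
// are pulled toward zero by lambda.
func softshrinkScalar(_ x: Float, lambda: Float = 0.5) -> Float {
    if x > lambda { return x - lambda }
    if x < -lambda { return x + lambda }
    return 0
}

let softshrinkInputs: [Float] = [-2.0, -0.25, 0.0, 0.25, 2.0]
let softshrinkOutputs = softshrinkInputs.map { softshrinkScalar($0) }
```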

/// ### See Also
/// - <doc:activations>
/// - ``softmin(_:axis:)``
open class Softmin: Module, UnaryLayer {
Collaborator Author

And their layers
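As a reminder of what the `Softmin` layer computes, softmin is conventionally softmax applied to the negated input. The sketch below works on a plain `[Float]` along a single axis; `softminReference` is a hypothetical name and this is not the MLX implementation.

```swift
import Foundation

// Reference softmin over one axis: softmin(x) = softmax(-x).
// Subtracting the minimum first keeps every exponent <= 0, which
// avoids overflow without changing the result.
func softminReference(_ x: [Float]) -> [Float] {
    guard let lowest = x.min() else { return [] }
    let exps = x.map { exp(-($0 - lowest)) }
    let total = exps.reduce(0, +)
    return exps.map { $0 / total }
}

let softminProbs = softminReference([1, 2, 3])
```

The smallest input receives the largest weight, and the outputs sum to 1.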

- [outputChannels, kernelSize.first, kernelSize.second, kernelSize.third, inputChannels])
+ [
+ outputChannels, kernelSize.first, kernelSize.second, kernelSize.third,
+ inputChannels / groups,
Collaborator Author

inputChannels should be divided by groups. Curiously, the Swift version has groups for all three convolutions, but Python only has it for 1d and 2d 🤷

Member

We should fix that in Python.
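To make the shape fix concrete, here is a small sketch of the weight shape for a grouped 3-D convolution, following the layout shown in the diff. The helper `conv3dWeightShape` and the tuple-based `kernelSize` are hypothetical; only the final shape expression comes from the change itself.

```swift
// Weight shape for a grouped 3-D convolution, following the layout in
// the diff: [out, k1, k2, k3, in / groups]. Each group only sees
// inputChannels / groups input channels, so the last axis shrinks.
func conv3dWeightShape(
    inputChannels: Int, outputChannels: Int,
    kernelSize: (Int, Int, Int), groups: Int = 1
) -> [Int] {
    precondition(
        inputChannels % groups == 0,
        "inputChannels must be divisible by groups")
    return [
        outputChannels, kernelSize.0, kernelSize.1, kernelSize.2,
        inputChannels / groups,
    ]
}

let groupedShape = conv3dWeightShape(
    inputChannels: 8, outputChannels: 16, kernelSize: (3, 3, 3), groups: 4)
```

With `groups: 1` the shape is unchanged from the previous behavior, so the fix only affects grouped convolutions.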

kernelSize: Int,
stride: Int = 1,
padding: Int = 0,
outputPadding: Int = 0,
Collaborator Author

New parameter on the transposed convolutions
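The effect of `outputPadding` is easiest to see in the output-size arithmetic. The sketch below uses the usual transposed-convolution size formula with dilation fixed at 1; that MLX follows exactly this formula is an assumption here, and `transposedOutputSize` is a hypothetical helper.

```swift
// Output length along one spatial axis of a transposed convolution,
// using the commonly documented formula (dilation assumed to be 1):
//   (input - 1) * stride - 2 * padding + (kernelSize - 1) + outputPadding + 1
func transposedOutputSize(
    input: Int, kernelSize: Int, stride: Int = 1,
    padding: Int = 0, outputPadding: Int = 0
) -> Int {
    (input - 1) * stride - 2 * padding + (kernelSize - 1) + outputPadding + 1
}

let withoutExtra = transposedOutputSize(input: 4, kernelSize: 3, stride: 2, padding: 1)
let withExtra = transposedOutputSize(
    input: 4, kernelSize: 3, stride: 2, padding: 1, outputPadding: 1)
```

With a stride greater than 1, several input sizes map to the same output size in the forward convolution; `outputPadding` selects which of them the transposed convolution reproduces.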

  public init(embeddingCount: Int, dimensions: Int) {
  let scale = sqrt(1 / Float(dimensions))
- self.weight = MLXRandom.normal([embeddingCount, dimensions]) * scale
+ self.weight = MLXRandom.normal([embeddingCount, dimensions], scale: scale)
Collaborator Author

This should produce the same result, but it now matches the Python implementation.

/// - affine: if `true` adds a trainable `weight`
/// - bias: if `true` adds a trainable `bias`
public init(
dimensions: Int, eps: Float = 1e-5, affine: Bool = true, bias: Bool = true
Collaborator Author

On the Python side, a bias flag was split out from affine.
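The excerpt does not name the layer, so the following is only a sketch under the assumption that it is a layer-norm style normalization. It shows why separate flags are useful: the scale (`weight`) and shift (`bias`) can be supplied independently. `normalizeReference` and its optional-array parameters are hypothetical.

```swift
// Normalizes x to zero mean and unit variance, then optionally applies
// an elementwise scale (weight) and shift (bias). Passing a weight
// without a bias corresponds to affine = true, bias = false.
func normalizeReference(
    _ x: [Float], eps: Float = 1e-5,
    weight: [Float]? = nil, bias: [Float]? = nil
) -> [Float] {
    let n = Float(x.count)
    let mean = x.reduce(0, +) / n
    let variance = x.map { ($0 - mean) * ($0 - mean) }.reduce(0, +) / n
    let invStd = 1 / (variance + eps).squareRoot()
    return x.indices.map { (i: Int) -> Float in
        var y = (x[i] - mean) * invStd
        if let w = weight { y *= w[i] }
        if let b = bias { y += b[i] }
        return y
    }
}

let scaledOnly = normalizeReference([1, 2, 3], weight: [2, 2, 2])
let scaledAndShifted = normalizeReference(
    [1, 2, 3], weight: [2, 2, 2], bias: [1, 1, 1])
```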

public let kernelSize: [Int]
public let stride: [Int]
public let padding: [Int]
public let paddingValue: Float
Collaborator Author

Add missing padding/paddingValue for base implementation
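To show how `padding` and `paddingValue` interact in a pooling layer, here is a 1-D max-pooling sketch on a plain `[Float]`. The helper `maxPool1dReference` is hypothetical, and the default `paddingValue` of negative infinity is an assumption that suits max pooling (the padding never wins the max); it is not taken from the PR.

```swift
// 1-D max pooling with explicit padding. The input is padded on both
// sides with paddingValue, then a window of kernelSize slides over it
// in steps of stride, keeping the maximum of each window.
func maxPool1dReference(
    _ x: [Float], kernelSize: Int, stride: Int,
    padding: Int = 0, paddingValue: Float = -.infinity
) -> [Float] {
    let pad = [Float](repeating: paddingValue, count: padding)
    let padded = pad + x + pad
    guard padded.count >= kernelSize else { return [] }
    // `Swift.stride` is qualified because the parameter shadows it.
    return Swift.stride(from: 0, through: padded.count - kernelSize, by: stride)
        .map { start in padded[start..<(start + kernelSize)].max()! }
}

let pooled = maxPool1dReference([1, 3, 2, 4], kernelSize: 2, stride: 2, padding: 1)
```

For average pooling a `paddingValue` of 0 would be the more natural choice, which is one reason the value is stored separately from `padding`.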

}

/// Applies 3-dimensional max pooling.
open class MaxPool3d: Pool {
Collaborator Author

Add missing pool 3d layers

- private func nearestIndices(dimension: Int, scale: Float, dim: Int, ndim: Int) -> MLXArray {
- scaledIndices(dimension: dimension, scale: scale, alignCorners: true, dim: dim, ndim: ndim)
- .asType(.int32)
+ private func nearestIndices(dimension N: Int, scale: Float, dim: Int, ndim: Int) -> MLXArray {
Collaborator Author

Match the python implementation, specifically the 0.5 offset below
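The body of the new function is not shown in this excerpt, so the following only illustrates one common way a 0.5 offset appears in nearest-neighbor index computation (the half-pixel convention). It is not necessarily the exact expression used in the PR or in the Python implementation, and `nearestSourceIndices` is a hypothetical helper working on plain integers.

```swift
// Half-pixel nearest-neighbor mapping: output sample i is treated as
// centered at i + 0.5, mapped back into input coordinates by dividing
// by the scale, and floored to pick the source index.
func nearestSourceIndices(dimension n: Int, scale: Float) -> [Int] {
    let m = Int(scale * Float(n))  // number of output samples
    return (0..<m).map { (i: Int) -> Int in
        let source = (Float(i) + 0.5) / scale
        // Clamp to guard against rounding at the upper edge.
        return min(Int(source.rounded(.down)), n - 1)
    }
}

let upsampledIndices = nearestSourceIndices(dimension: 3, scale: 2)
```

Without the offset, an align-corners style mapping stretches the first and last input samples to the output edges, which picks different source indices for non-integer scales.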

Member

@awni left a comment

Awesome. That's a good diff! Thanks David and Claude ;)

@davidkoski merged commit 0a6df65 into main on Jan 22, 2026
7 checks passed
@davidkoski deleted the gaps branch on January 22, 2026 at 19:22