A from-scratch implementation of linear regression in Rust, built while working through *Grokking Machine Learning* by Luis Serrano.
I wanted to internalize how gradient descent actually works rather than just calling `sklearn.fit()`, so I implemented the three foundational approaches the book covers for fitting a line to data.
All three approaches follow the same basic loop:
- Pick a point from the dataset
- Predict its y value using the current line
- Adjust the line to reduce the error
- Repeat for many epochs
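The loop above can be sketched as follows. The `(slope, intercept)` pair and the `trick` closure signature are illustrative assumptions, not the repo's actual API, and points are visited in a fixed cycle rather than sampled randomly to keep the sketch dependency-free:

```rust
// Sketch of the shared training loop (illustrative names, not the repo's API).
// `trick` takes the current line and one point, and returns the adjusted line.
fn train(
    points: &[(f64, f64)],
    mut slope: f64,
    mut intercept: f64,
    epochs: usize,
    trick: impl Fn(f64, f64, f64, f64) -> (f64, f64),
) -> (f64, f64) {
    for epoch in 0..epochs {
        // 1. Pick a point from the dataset (cycled here instead of random).
        let (x, y) = points[epoch % points.len()];
        // 2-3. The trick predicts y for this point and returns the adjusted line.
        let (s, b) = trick(slope, intercept, x, y);
        slope = s;
        intercept = b;
    }
    (slope, intercept)
}
```

With a few noiseless points and a proportional update, this loop recovers the generating line; the three tricks below differ only in the closure passed in.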
How each algorithm adjusts the line is what differentiates them:
**Simple trick.** The most naive approach: if the prediction is too high, nudge the intercept down; if too low, nudge it up. The slope is adjusted in the direction that would help at the current point, but always by a fixed amount, with no consideration of how wrong the prediction was.
Converges, but slowly compared to the other two methods.
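As a sketch, assuming a plain `(slope, intercept)` pair and a fixed step size `eta` (names and signature are illustrative, not the repo's actual API):

```rust
// Simple trick: fixed-size nudges, ignoring the error magnitude.
fn simple_trick(slope: f64, intercept: f64, x: f64, y: f64, eta: f64) -> (f64, f64) {
    let predicted = slope * x + intercept;
    if predicted > y {
        // Too high: nudge the intercept down, and rotate the line away from
        // the point (which way depends on the sign of x).
        (if x > 0.0 { slope - eta } else { slope + eta }, intercept - eta)
    } else {
        // Too low: nudge everything the other way.
        (if x > 0.0 { slope + eta } else { slope - eta }, intercept + eta)
    }
}
```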
**Square trick.** Adjusts the line proportionally to the error: large errors produce large corrections, small errors produce small corrections. This is gradient descent on the Mean Squared Error loss function.
Converges quickly and reliably.
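The proportional update is one stochastic-gradient step on the squared error of a single point. Again a sketch over an assumed `(slope, intercept)` pair:

```rust
// Square trick: step scales with the error (SGD on squared error).
fn square_trick(slope: f64, intercept: f64, x: f64, y: f64, eta: f64) -> (f64, f64) {
    let error = y - (slope * x + intercept);
    // d/d(slope) of error^2 is -2*x*error; the constant 2 folds into eta.
    (slope + eta * x * error, intercept + eta * error)
}
```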
**Absolute trick.** Adjusts the line in the direction of the error, but by a fixed step size regardless of the error magnitude. This is gradient descent on the Mean Absolute Error loss function.
More robust to outliers than the square trick, but converges more slowly.
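A sketch of the fixed-size update, under the same illustrative `(slope, intercept)` representation; only the sign of the error matters, which is why outliers cannot drag the line far in any single step:

```rust
// Absolute trick: step direction follows the error's sign, not its size.
fn absolute_trick(slope: f64, intercept: f64, x: f64, y: f64, eta: f64) -> (f64, f64) {
    let direction = (y - (slope * x + intercept)).signum();
    (slope + eta * x * direction, intercept + eta * direction)
}
```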
```text
==========================================
Simple Trick (1000 epochs)
Starting Line(5.6214, 97.0906)
Targeting Line(66.6462, -89.1658)
Epoch 0: Line(5.6214, 96.0906)
Epoch 100: Line(5.6214, -3.9094)
Epoch 200: Line(33.7134, -75.8174)
Epoch 300: Line(66.6394, -89.1714)
...
Final Line(66.6394, -89.1714) (should be Line(66.6462, -89.1658))
Simple Trick Root Mean Square Error: 0.031
==========================================
Square Trick (1000 epochs)
Starting Line(5.6214, 97.0906)
Targeting Line(66.6462, -89.1658)
Epoch 0: Line(35.4475, -19.1042) <-- big jump
Epoch 100: Line(66.6451, -89.1666)
...
Final Line(66.6451, -89.1666) (should be Line(66.6462, -89.1658))
Square Trick Root Mean Square Error: 0.030
==========================================
Absolute Trick (1000 epochs)
Starting Line(5.6214, 97.0906)
Targeting Line(66.6462, -89.1658)
Epoch 0: Line(5.6224, 96.0906)
Epoch 100: Line(5.7224, -3.9094)
Epoch 200: Line(28.0584, -74.9094)
Epoch 300: Line(66.6345, -89.1694)
...
Final Line(66.6352, -89.1674) (should be Line(66.6462, -89.1658))
Absolute Trick Root Mean Square Error: 0.030
```
All three converge to similar accuracy. The square trick makes large corrections early (note the jump at epoch 0) and settles quickly. The simple and absolute tricks take smaller fixed steps, so they need more epochs to arrive at the same place.
```text
cargo run
```

Generates a random target line, distributes noisy points along it, then trains each algorithm from the same random starting line and compares results.