linear-regression

A from-scratch implementation of linear regression in Rust, built while working through Grokking Machine Learning by Luis Serrano.

I wanted to internalize how gradient descent actually works rather than just calling sklearn.fit(), so I implemented the three foundational approaches the book covers for fitting a line to data.

The Algorithms

All three approaches follow the same basic loop (sketched in code after the list):

  1. Pick a point from the dataset
  2. Predict its y value using the current line
  3. Adjust the line to reduce the error
  4. Repeat for many epochs
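In Rust that loop might look like the sketch below. The `Line` struct, the `train` signature, and the learning rate `eta` are assumptions for illustration, not the repository's actual API; random point selection uses the `rand` crate.

```rust
use rand::Rng;

// Hypothetical stand-in for whatever the crate actually defines.
struct Line {
    slope: f64,
    intercept: f64,
}

impl Line {
    fn predict(&self, x: f64) -> f64 {
        self.slope * x + self.intercept
    }
}

// One trick = one per-point update rule; the loop itself never changes.
fn train(
    line: &mut Line,
    points: &[(f64, f64)],
    epochs: usize,
    eta: f64,
    update: fn(&mut Line, f64, f64, f64),
) {
    let mut rng = rand::thread_rng();
    for _ in 0..epochs {
        // 1. Pick a point from the dataset.
        let (x, y) = points[rng.gen_range(0..points.len())];
        // 2. + 3. Predict and adjust, inside the chosen trick.
        update(line, x, y, eta);
    }
}
```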

How each algorithm adjusts the line is what differentiates them:

Simple Trick

The most naive approach: if the prediction is too high, nudge the intercept down; if too low, nudge it up. The slope adjusts in the direction that would help at the current point, but always by a fixed amount. No consideration of how wrong the prediction was.

Converges, but slowly compared to the other methods.
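A minimal sketch of what that update could look like, reusing the hypothetical `Line` struct from the loop sketch above:

```rust
fn simple_trick(line: &mut Line, x: f64, y: f64, eta: f64) {
    // Only the direction of the error matters here, never its size.
    let dir = if line.predict(x) < y { 1.0 } else { -1.0 };
    line.intercept += dir * eta;
    // Raising the slope raises the prediction only where x is positive,
    // so the slope nudge also follows the sign of x.
    line.slope += dir * eta * x.signum();
}
```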

Square Trick (Gradient Descent on MSE)

Adjusts the line proportionally to the error. Large errors produce large corrections; small errors produce small corrections. This is gradient descent on the Mean Squared Error loss function.

Converges quickly and reliably.
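As a sketch, again assuming the `Line` struct above, the per-point update is the gradient step on squared error at a single point:

```rust
fn square_trick(line: &mut Line, x: f64, y: f64, eta: f64) {
    // For error = y - (mx + b), the negative gradient of error^2 is
    // (2 * error * x, 2 * error); the constant 2 folds into eta.
    let error = y - line.predict(x);
    line.slope += eta * error * x;
    line.intercept += eta * error;
}
```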

Absolute Trick (Gradient Descent on MAE)

Adjusts the line in the direction of the error, but by a fixed step size regardless of the error magnitude. This is gradient descent on the Mean Absolute Error loss function.

More robust to outliers than the square trick, but converges more slowly.
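A sketch under the same assumptions; compared with the square trick, the raw error is replaced by its sign:

```rust
fn absolute_trick(line: &mut Line, x: f64, y: f64, eta: f64) {
    // Only the sign of the error feeds the update, so every step has the
    // same magnitude no matter how wrong the prediction is.
    let dir = (y - line.predict(x)).signum();
    line.slope += eta * dir * x;
    line.intercept += eta * dir;
}
```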

Sample Output

==========================================
Simple Trick (1000 epochs)
Starting Line(5.6214, 97.0906)
Targeting Line(66.6462, -89.1658)
Epoch 0: Line(5.6214, 96.0906)
Epoch 100: Line(5.6214, -3.9094)
Epoch 200: Line(33.7134, -75.8174)
Epoch 300: Line(66.6394, -89.1714)
...
Final Line(66.6394, -89.1714) (should be Line(66.6462, -89.1658))
Simple Trick Root Mean Square Error: 0.031

==========================================
Square Trick (1000 epochs)
Starting Line(5.6214, 97.0906)
Targeting Line(66.6462, -89.1658)
Epoch 0: Line(35.4475, -19.1042)   <-- big jump
Epoch 100: Line(66.6451, -89.1666)
...
Final Line(66.6451, -89.1666) (should be Line(66.6462, -89.1658))
Square Trick Root Mean Square Error: 0.030

==========================================
Absolute Trick (1000 epochs)
Starting Line(5.6214, 97.0906)
Targeting Line(66.6462, -89.1658)
Epoch 0: Line(5.6224, 96.0906)
Epoch 100: Line(5.7224, -3.9094)
Epoch 200: Line(28.0584, -74.9094)
Epoch 300: Line(66.6345, -89.1694)
...
Final Line(66.6352, -89.1674) (should be Line(66.6462, -89.1658))
Absolute Trick Root Mean Square Error: 0.030

All three converge to similar accuracy. The square trick makes large corrections early (note the jump at epoch 0) and settles quickly. The simple and absolute tricks take smaller fixed steps, so they need more epochs to arrive at the same place.

Running It

cargo run

Generates a random target line, distributes noisy points along it, then trains each algorithm from the same random starting line and compares results.
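The data setup could be as simple as the following sketch; `noisy_points`, its parameters, and the sampling range are invented for illustration, with the `Line` struct from earlier:

```rust
use rand::Rng;

// Sample n points along the target line, perturbing y by uniform noise.
fn noisy_points(target: &Line, n: usize, noise: f64) -> Vec<(f64, f64)> {
    let mut rng = rand::thread_rng();
    (0..n)
        .map(|_| {
            let x = rng.gen_range(-10.0..10.0);
            let y = target.predict(x) + rng.gen_range(-noise..noise);
            (x, y)
        })
        .collect()
}
```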
