3D PCA Scatterplot Graph with Python

As part of the series of tutorials on PCA with Python, we will learn how to plot a 3D PCA graph (scatter plot) on the Iris Dataset with Python, Scikit-learn and Matplotlib.

What is a 3D PCA Scatter Plot?

A 3D PCA (Principal Component Analysis) scatter plot is a PCA visualization that shows the distribution of data points in a 3D space after reducing a dataset to 3 PCA features.

How to Plot a 3D PCA Graph in Python?

To plot a 3D PCA scatter plot in Python, first reduce the dataset to 3 principal components. Then use matplotlib to generate a three-dimensional scatter plot from the reduced data.


    Here are the detailed steps to plot a 3D PCA scatter plot in Python:

    1. Load the required Python Libraries
    2. Load your Dataset
    3. Scale the Data and Reduce it to 3 Features Using PCA
    4. Set up a 3D plotting environment
    5. Assign PCA Features to their own Axes of the Scatter Plot
    6. Plot the 3D PCA Graph using scatter3D
    7. Interpret the 3D PCA Scatterplot

    1. Loading the Required Python Libraries

    import matplotlib.pyplot as plt 
    import numpy as np
    import pandas as pd 
    from sklearn import datasets
    from sklearn.preprocessing import StandardScaler
    from sklearn.decomposition import PCA
    plt.style.use('default')
    

    2. Loading the Iris Dataset in Python

    To start, let’s load the Iris dataset in Python.

    # load features and targets separately
    iris = datasets.load_iris()
    X = iris.data
    y = iris.target
    

    We will use this data to plot the 3D PCA graph with Python.

    3. Scale and Reduce the Number of Features Using PCA

    Next, preprocess the data and use PCA to reduce the dataset to 3 features. Scale the data before applying PCA, and set n_components to 3. To learn what this means, follow our tutorial on PCA with Python.

    # Data Scaling
    x_scaled = StandardScaler().fit_transform(X)
    
    # Reduce from 4 to 3 features with PCA
    pca = PCA(n_components=3)
    
    # Fit and transform data
    pca_features = pca.fit_transform(x_scaled)
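
To check how much information each of the 3 components carries, one option is to inspect the explained_variance_ratio_ attribute of the fitted PCA object. The block below is a minimal sketch that repeats the scaling and reduction steps so it runs on its own; the variable names ratios and cumulative are illustrative and not part of the tutorial's code.

```python
import numpy as np
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Same preprocessing as in the tutorial
X = datasets.load_iris().data
x_scaled = StandardScaler().fit_transform(X)

pca = PCA(n_components=3)
pca_features = pca.fit_transform(x_scaled)

# Fraction of the total variance captured by each principal component
ratios = pca.explained_variance_ratio_
cumulative = np.cumsum(ratios)

for i, (r, c) in enumerate(zip(ratios, cumulative), start=1):
    print(f"PC{i}: {r:.1%} of variance (cumulative {c:.1%})")
```

Components are sorted by decreasing variance, so a small value for the third component suggests that it adds little to the plot.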
    

    4. Set up a 3D Plotting Environment in Matplotlib

    Set up a 3D plotting environment in matplotlib using plt.axes(projection='3d').

    ax = plt.axes(projection='3d')

    Let’s see an example by plotting our selected features into a 3D graph.

    # Prepare 3D graph
    fig = plt.figure()
    ax = plt.axes(projection='3d')
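
As an alternative to plt.axes, a 3D axes can also be created from the figure itself with fig.add_subplot(projection='3d'), and the camera angle can be adjusted with ax.view_init. The helper below is a sketch only: the function name make_3d_axes and the default angles are assumptions chosen for illustration.

```python
import matplotlib.pyplot as plt


def make_3d_axes(elev=20, azim=-60):
    """Create a figure with a single 3D axes and set the camera angle."""
    fig = plt.figure(figsize=(6, 5))
    # Attach the 3D axes to this specific figure
    ax = fig.add_subplot(projection='3d')
    # elev: elevation angle in degrees, azim: azimuth angle in degrees
    ax.view_init(elev=elev, azim=azim)
    return fig, ax


fig, ax = make_3d_axes(elev=25, azim=45)
```

Trying a few elev and azim values can make it easier to see whether the points lie close to a flat plane.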
    

    5. Assign PCA Features to their own Axes of the Scatter Plot

    Before we can plot the data, we need to set up the data for the x, y and z axes of the 3D scatter plot. Each PCA feature (principal component) will be on its own axis.

    # Assign each PCA feature to its own axis
    xdata = pca_features[:,0]
    ydata = pca_features[:,1]
    zdata = pca_features[:,2]
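
The three column selections can also be written as a single unpacking of the transposed array. The sketch below uses a small placeholder array in place of the real pca_features so that it runs on its own; the values are not PCA output.

```python
import numpy as np

# Stand-in for the (n_samples, 3) array returned by pca.fit_transform;
# the values are placeholders, not real PCA output.
pca_features = np.arange(12, dtype=float).reshape(4, 3)

# Transposing gives shape (3, n_samples), so unpacking yields one array per component
xdata, ydata, zdata = pca_features.T
```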
    

    6. Plot the 3D PCA Graph using scatter3D

    To plot the 3D PCA graph in Python, use ax.scatter3D with the x, y and z data as its arguments, mapping each PCA feature to its own axis in the scatter plot.

    # Draw the 3D scatter plot, colored by the value of the third component
    ax.scatter3D(xdata, ydata, zdata, c=zdata, cmap='viridis')
    
    # Set the title of the graph
    plt.title('3D PCA Scatter of Iris')
    
    # Use the same evenly spaced ticks on the x, y and z axes
    ticks = np.linspace(-3, 3, num=5)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_zticks(ticks)
    
    # Label the axes with the principal components they show
    # (the axes are PCA features, not the original Iris measurements)
    ax.set_xlabel('PC1', rotation=150)
    ax.set_ylabel('PC2')
    ax.set_zlabel('PC3', rotation=60)
    plt.show()
    
    3D scatterplot of Iris dataset
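
The plot above colors each point by its value on the third component. As a variation, the sketch below colors the points by Iris species using the targets loaded earlier, which can make the clusters easier to read. The function name plot_pca_by_species is an assumption for illustration, and the block repeats the earlier steps so it is self-contained.

```python
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler


def plot_pca_by_species():
    """Draw a 3D PCA scatter plot with one color and legend entry per species."""
    iris = datasets.load_iris()
    x_scaled = StandardScaler().fit_transform(iris.data)
    pca_features = PCA(n_components=3).fit_transform(x_scaled)

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')

    # One scatter call per species so that each gets its own legend entry
    for label, name in enumerate(iris.target_names):
        mask = iris.target == label
        ax.scatter(
            pca_features[mask, 0],
            pca_features[mask, 1],
            pca_features[mask, 2],
            label=str(name),
        )

    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_zlabel('PC3')
    ax.legend()
    return fig, ax


fig, ax = plot_pca_by_species()
```

Call plt.show() afterwards to display the figure.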

    7. Interpret the 3D PCA Scatterplot

    Note that the axes of this graph are principal components, not the original Iris measurements. When plotting the 3D graph, it becomes clear that the data varies much less along the third principal component than along the first two, so the points form an almost flat 2D plane inside the 3D graph. This suggests that the intrinsic dimension of the data is essentially 2 rather than 4.

    Reducing these 3 PCA features to 2 would make any downstream model faster and the visualization easier to read, without losing much information.
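
One way to check this is to fit PCA with 2 components, look at how much variance is retained, and draw an ordinary 2D scatter plot. This is a minimal sketch; the names pca_2d, features_2d and retained are illustrative.

```python
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

iris = datasets.load_iris()
x_scaled = StandardScaler().fit_transform(iris.data)

# Keep only the first two principal components
pca_2d = PCA(n_components=2)
features_2d = pca_2d.fit_transform(x_scaled)

# Share of the total variance that survives the reduction to 2 components
retained = pca_2d.explained_variance_ratio_.sum()
print(f"Variance retained by 2 components: {retained:.1%}")

# Regular 2D scatter plot, colored by species
fig, ax = plt.subplots()
ax.scatter(features_2d[:, 0], features_2d[:, 1], c=iris.target, cmap='viridis')
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_title('2D PCA Scatter of Iris')
```

Call plt.show() afterwards to display the figure.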

    Next Steps

    After plotting a 3D PCA Scatterplot, it is interesting to learn how to plot a 3D PCA Biplot.

    Full Code

    import matplotlib.pyplot as plt 
    import numpy as np
    import pandas as pd 
    from sklearn import datasets
    from sklearn.preprocessing import StandardScaler
    from sklearn.decomposition import PCA
    plt.style.use('default')
    
    # load features and targets separately
    iris = datasets.load_iris()
    X = iris.data
    y = iris.target
    
    # Data Scaling
    x_scaled = StandardScaler().fit_transform(X)
    
    # Dimension reduction: from 4 to 3 features with PCA
    pca = PCA(n_components=3)
    pca_features = pca.fit_transform(x_scaled)
     
    
    # Prepare 3D graph
    fig = plt.figure()
    ax = plt.axes(projection='3d')
    
    # Assign each PCA feature to its own axis
    xdata = pca_features[:,0]
    ydata = pca_features[:,1]
    zdata = pca_features[:,2]
    
    # Draw the 3D scatter plot, colored by the value of the third component
    ax.scatter3D(xdata, ydata, zdata, c=zdata, cmap='viridis')
    
    # Set the title of the graph
    plt.title('3D PCA Scatter of Iris')
    
    # Use the same evenly spaced ticks on the x, y and z axes
    ticks = np.linspace(-3, 3, num=5)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_zticks(ticks)
    
    # Label the axes with the principal components they show
    # (the axes are PCA features, not the original Iris measurements)
    ax.set_xlabel('PC1', rotation=150)
    ax.set_ylabel('PC2')
    ax.set_zlabel('PC3', rotation=60)
    plt.show()
    