GPU friendly approximate matrix exp and log

Aug 1, 2024   #Numerical approximation 

Matrix exp and log calculations typically involve an eigen-decomposition, which is not very GPU friendly. For matrices with “well behaved” eigenvalues, the approaches in this post seem to work sufficiently well without using an eigen-decomposition or a matrix inverse.

status: DRAFT

Rigorous numerical analysis is pending, but in tests on matrix sizes up to 1024x1024, both approaches work well as long as the eigenvalues won’t cause the functions to blow up – eigenvalues small enough not to blow up \(e^\lambda\) in the case of \(\exp({\bf M})\), and eigenvalues with magnitude sufficiently greater than 0 in the case of \(\log({\bf M})\). Both algorithms described here work for complex matrices too, as long as those conditions are met.


Matrix exponentiation

The idea behind this is simple - solve the following matrix differential equation

\[ \frac{d{\bf X}}{dt} = {\bf M}{\bf X} \]

… which for a constant matrix \({\bf M}\) has the solution –

\[ {\bf X}(t) = {e^{{\bf M}t}}{\bf X}(0) \]

Here, \({\bf M}\) is a square matrix of dimensions \(n \times n\) and so is \({\bf X}\). So if we start with \({\bf X}(0) = {\bf 1}\) (the identity matrix) at \(t = 0\) and integrate till \(t = 1\), then we’ll have \({\bf X}(1) = e^{\bf M}\).

First off, this can also be calculated using a straightforward Taylor series approximation as shown below (code in Julia). Our constraint here is that we only want to use matrix multiplication and addition/scaling in our calculations so they can be GPU-ified easily.

using LinearAlgebra, Metal, BenchmarkTools

# A simple relative approximation error where m2 is the ground truth
# calculated using eigen decomposition on the CPU.
function relerr(m1,m2)
    sqrt(sum(abs2, m1 - m2) / sum(abs2, m2))
end

function exp_taylor(M :: T;eps=0.1) where {T <: AbstractArray}
    N = size(M)[1]
    # Start with identity matrix. Doing T() ensures that
    # if the input matrix is a GPU array, then the identity
    # matrix is also on the GPU.
    eM = T(Diagonal(ones(Float32, N)))

    # Sum terms of the Taylor expansion until the delta
    # is below some set epsilon.
    Mk = M
    k = 1.0f0
    while sum(abs2, Mk) > eps * eps * N*N
        # Do two terms within loop to reduce the cost of the
        # termination condition check.
        eM += Mk
        k += 1.0f0
        Mk = M * Mk / k

        eM += Mk
        k += 1.0f0
        Mk = M * Mk / k
    end
    
    return eM
end
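To sanity-check the idea, here is a condensed CPU-only version of the same series summation on a tiny matrix, compared against LinearAlgebra’s exp as ground truth. `exp_taylor_small` is just an illustrative name, and a fixed term count stands in for the epsilon-based check –

```julia
using LinearAlgebra

# Condensed Taylor summation: I + A + A^2/2! + ... for `nterms` terms.
function exp_taylor_small(A; nterms = 20)
    eA   = Matrix{Float32}(I, size(A)...)
    term = Matrix{Float32}(I, size(A)...)
    for k in 1:nterms
        term = term * A / k   # next Taylor term A^k / k!
        eA += term
    end
    return eA
end

A = 0.1f0 * Float32[1 2; 3 4]   # small eigenvalues, safe for the series
err = norm(exp_taylor_small(A) - exp(A)) / norm(exp(A))   # should be tiny
```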

I used random matrices as follows -

M = rand(Float32, (1024,1024)) .- 0.5f0
Mc = rand(ComplexF32, (1024,1024)) .- (0.5f0 + im * 0.5f0);
mM = MtlArray(M)
mMc = MtlArray(Mc)

The baseline CPU version of exp(M) in Julia benchmarks to a median timing of about 102ms on an M1 MacBook Air -

CPU version of exp(M)

For a relative error of \(2\times 10^{-4}\), exp_taylor(mM) benchmarks at about the same on the GPU using Metal, where mM = MtlArray(M) is evaluated before benchmarking.

GPU version of exp_taylor(mM)

The approach using numerical integration is given below. It expands the matrix using the Taylor series for one small step \({\bf X}(1/2^n)\) and then repeatedly squares the result to get the final value \({\bf X}(1)\).

function exp_deq(M :: T;steps=4,prec=6) where {T <: AbstractArray}
    N = size(M)[1]
    # Start with identity matrix. Doing T() ensures that
    # if the input matrix is a GPU array, then the identity
    # matrix is also on the GPU.
    eM = T(Diagonal(ones(Float32,N)))

    # Integrate in small "time steps" of f
    # Determine one time step using `prec` terms
    # of the Taylor series for matrix exponentiation.
    f = Float32(1 / 2^steps)
    fM = f * M
    Macc = fM
    MM = fM
    for k in 2:prec
        MM *= Float32(1/k) * fM
        Macc += MM
    end

    # Repeatedly square the approximate value for one
    # time step to get the final value.
    eM += Macc
    for i in 1:steps
        eM *= eM
    end
    return eM
end
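As a quick CPU-side check of the scale-then-square scheme, here is a condensed standalone version (`exp_sq_small` is an illustrative name), again verified against LinearAlgebra’s exp –

```julia
using LinearAlgebra

# Approximate exp(A / 2^steps) with a short Taylor series,
# then square the result `steps` times to recover exp(A).
function exp_sq_small(A; steps = 4, prec = 6)
    fA  = Float32(1 / 2^steps) * A
    acc = fA
    AA  = fA
    for k in 2:prec
        AA = AA * fA / k   # next Taylor term of exp(fA)
        acc += AA
    end
    eA = Matrix{Float32}(I, size(A)...) + acc   # ≈ exp(fA)
    for _ in 1:steps
        eA = eA * eA
    end
    return eA
end

A = Float32[0.3 -0.2; 0.1 0.4]
err = norm(exp_sq_small(A) - exp(A)) / norm(exp(A))   # should be tiny
```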

This version benchmarks to about 14ms on the GPU for an error of \(4 \times 10^{-5}\), which is a significant improvement in both accuracy and speed.

GPU version of exp_deq(mM)

Matrix logarithm

Approximate matrix logarithms are useful during entropy calculations of quantum systems. But the logarithm is a tricky beast. I didn’t read too widely about existing algorithms, but the ones I read pointed to approximating \(\log({\bf M})\) by first using an efficient square-rooting algorithm to calculate \(\log({\bf M}^{1/2^n})\) – since \({\bf M}^{1/2^n}\) will be close to \({\bf 1}\), permitting a Taylor series expansion – and then multiplying the result by \(2^n\). The numerical behaviour of this algorithm is known. However, the square-rooting operation requires the calculation of a matrix inverse, which is better avoided on the GPU too.

It then seemed like an interesting problem to take on, even if I ended up with a solution that others have already worked out. If it has been worked out like this before, I’d love to see it posted on a blog like this too :)

So we take the approach of trying to construct a differential equation whose solution will give us \(\log({\bf M})\).

Consider the following function –

\[ {\bf Y}(t) = \log({\bf 1} + t({\bf M} - {\bf 1})) \]

where \(t\) is a scalar and \({\bf 1}\) is the identity matrix of the same dimensions as \({\bf M}\). We want \({\bf Y}(1)\). Rewriting it, we have

\[ \begin{array}{rcl} e^{{\bf Y}(t)} & = & {\bf 1} + t({\bf M} - {\bf 1}) \end{array} \]

Differentiating w.r.t. \(t\), we get –

\[ \begin{array}{rcl} e^{{\bf Y}(t)} {{\bf Y}'(t)} & = & {\bf M} - {\bf 1} \\ \Rightarrow {\bf Y}'(t) & = & e^{-{\bf Y}(t)}({\bf M} - {\bf 1}) \end{array} \]

If we now take \({\bf Z}(t) = e^{-{\bf Y}(t)}\) and differentiate it, we can rewrite that as –

\[ \begin{array}{rcl} {\bf Y}'(t) & = & {\bf Z}(t)({\bf M} - {\bf 1}) \\ {\bf Z}'(t) & = & - {\bf Z}(t) {\bf Y}'(t) \end{array} \]

We now have a pair of coupled differential equations that we can integrate numerically in lock step over \(t \in [0,1]\) to get \({\bf Y}(1)\) and \({\bf Z}(1)\).
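Even a plain forward-Euler pass over this pair recovers \(\log({\bf M})\) on a small test matrix. Here is a rough standalone sketch – `log_euler` is an illustrative name and the step size is just a guess –

```julia
using LinearAlgebra

# Forward-Euler integration of Y'(t) = Z(t)(M - I), Z'(t) = -Z(t)Y'(t)
# over t in [0, 1].
function log_euler(M; da = 0.01f0)
    Y  = zeros(Float32, size(M)...)
    Z  = Matrix{Float32}(I, size(M)...)
    M1 = M - I
    for _ in 1:round(Int, 1 / da)
        dY = da * Z * M1
        Z -= Z * dY   # dZ = -Z * Y'(t) * da, using the old Z
        Y += dY
    end
    return Y, Z
end

M = Float32[1.2 0.1; 0.05 0.9]   # eigenvalues comfortably away from 0
Y, Z = log_euler(M)
# Y ≈ log(M) and, as a byproduct, Z ≈ inv(M)
```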

To get a better approximation step (central difference), we can use an estimate of the second derivative as well –

\[ \begin{array}{rcl} {\bf Y}''(t) & = & {\bf Z}'(t)({\bf M} - {\bf 1}) \\ {\bf Z}''(t) & = & - {\bf Z}'(t) {\bf Y}'(t) - {\bf Z}(t) {\bf Y}''(t) \end{array} \]

That gives us the following program –

function log_approx2(M :: T; da = 0.2f0) where {T <: AbstractArray}
    N = size(M)[1]
    Y = T(zeros(Float32, (N,N)))
    I = T(Diagonal(ones(Float32, N)))
    M1 = M - I
    Z = I
    for t in 0.0f0:da:(1.0f0-da)
        dY = da * Z * M1
        dZ = -Z * dY
        d2Y = dZ * M1 * da
        d2Z = - dZ * dY - Z * d2Y
        Y += dY + 0.5f0 * d2Y
        Z += dZ + 0.5f0 * d2Z
    end
    return Y, Z     
end

Construct a random matrix with positive eigenvalues sufficiently far from zero like this -

function rpmat(N)
    M = rand(Float32, (N,N)) .- 0.5f0
    e = Diagonal(0.5f0 .+ rand(Float32, N))
    M * e * inv(M)
end

pM = rpmat(1024)
mpM = MtlArray(pM)

The baseline CPU version of log(pM) benchmarks to a median 750ms.

CPU matrix logarithm

To a relative error of \(5\times 10^{-3}\), the approximate matrix logarithm (log_approx2(mpM)) benchmarks to a median of 43ms on the GPU (of an M1 MacBook Air), which looks pretty good and usable for some calculations.

GPU approx matrix logarithm

A nice side benefit of log_approx2 is that it gives us the matrix logarithm and the matrix inverse in one go, since \({\bf Z}(1) = e^{-{\bf Y}(1)} = {\bf M}^{-1}\). Yay! … though I don’t know where getting them both together is of use. And we’ve only used ordinary matrix operations – addition, scalar multiplication and matrix multiplication – in the process.

Enjoy! … and let me know if this is useful, already implemented in some library, etc.

Addendum: The pair of equations we’re integrating aren’t really “coupled” very much. While we have \({\bf Y}'(t) = {\bf Z}(t)({\bf M} - {\bf 1})\), \({\bf Z}\)’s equation is \({\bf Z}'(t) = - {\bf Z}(t){\bf Y}'(t)\), so the second equation can be expanded to \({\bf Z}'(t) = - {\bf Z}(t){\bf Z}(t)({\bf M} - {\bf 1})\), which is independent of \({\bf Y}\) or \({\bf Y}'(t)\). \({\bf Z}\)’s equation essentially solves for \(({\bf 1} + t({\bf M} - {\bf 1}))^{-1}\) incrementally and we use those incremental values to solve for \({\bf Y}(t)\).
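A tiny sketch of that decoupled view: integrating \({\bf Z}\)’s equation alone with forward Euler should track \(({\bf 1} + t({\bf M} - {\bf 1}))^{-1}\) and land on \({\bf M}^{-1}\) at \(t = 1\). `inv_euler` is an illustrative name –

```julia
using LinearAlgebra

# Forward-Euler integration of Z'(t) = -Z(t)Z(t)(M - I) on its own.
function inv_euler(M; da = 0.005f0)
    Z  = Matrix{Float32}(I, size(M)...)
    M1 = M - I
    for _ in 1:round(Int, 1 / da)
        Z -= da * Z * Z * M1
    end
    return Z
end

M = Float32[1.3 0.2; 0.1 0.8]
err = norm(inv_euler(M) - inv(M)) / norm(inv(M))   # should be small
```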

Improving the approximation

The integration for calculating \(\log({\bf M})\) can be improved for better accuracy using Runge-Kutta methods, though with slightly slower performance. Below is the same log_approx function implemented using RK4 –

function log_approx_rk4(M :: T; da = 0.2f0) where {T <: AbstractArray}
    N = size(M)[1]
    Y = T(zeros(Float32, (N,N)))
    I = T(Diagonal(ones(Float32, N)))
    M1 = M - I
    Z = I
    for t in da:da:1.0f0
        Zk1 = - Z * Z * M1
        Z1  = Z + 0.5f0 * da * Zk1
        Zk2 = - Z1 * Z1 * M1
        Z2  = Z + 0.5f0 * da * Zk2
        Zk3 = - Z2 * Z2 * M1
        Z3  = Z + da * Zk3
        Zk4 = - Z3 * Z3 * M1
        dZ  = Float32(1/6) * (Zk1 + 2 * Zk2 + 2 * Zk3 + Zk4)

        Yk1 = Z * M1
        Yk2 = Z1 * M1
        Yk3 = Z2 * M1
        Yk4 = Z3 * M1
        dY  = Float32(1/6) * (Yk1 + 2 * Yk2 + 2 * Yk3 + Yk4)

        Y   += da * dY
        Z   += da * dZ
    end
    return Y, Z        
end

For a similar matrix, log_approx_rk4 gives a relative error of \(1.9 \times 10^{-5}\) at a median runtime of 115ms, which is still about 6x the speed of the CPU reference implementation. For matrices with smaller eigenvalues, we need to reduce the “time step” for acceptable accuracy, and the RK4 version performs better there too.
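One way to convince yourself the RK4 stages are wired correctly is to run the same scheme on a scalar \(m\), where the equations reduce to \(y'(t) = z(t)(m-1)\) and \(z'(t) = -z(t)^2(m-1)\) and should recover \(\log(m)\) and \(1/m\) at \(t = 1\) –

```julia
# Scalar sanity check of the RK4 wiring; `log_rk4_scalar` is an
# illustrative name.
function log_rk4_scalar(m; da = 0.1)
    y, z = 0.0, 1.0
    m1 = m - 1
    for _ in 1:round(Int, 1 / da)
        zk1 = -z^2  * m1; z1 = z + 0.5 * da * zk1
        zk2 = -z1^2 * m1; z2 = z + 0.5 * da * zk2
        zk3 = -z2^2 * m1; z3 = z + da * zk3
        zk4 = -z3^2 * m1
        yk1 = z  * m1
        yk2 = z1 * m1
        yk3 = z2 * m1
        yk4 = z3 * m1
        y += da / 6 * (yk1 + 2 * yk2 + 2 * yk3 + yk4)
        z += da / 6 * (zk1 + 2 * zk2 + 2 * zk3 + zk4)
    end
    return y, z
end

y, z = log_rk4_scalar(1.5)
# y ≈ log(1.5), z ≈ 1/1.5
```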

Runge Kutta method

Clarification on derivative step

We wrote earlier that when we differentiate \(e^{{\bf Y}(t)}\) w.r.t. \(t\), we get \(e^{{\bf Y}(t)}{\bf Y}'(t)\). This works because \({\bf Y}(t)\) and \({\bf Y}'(t)\) commute, as shown below -

\[ \begin{array}{rcl} e^{{\bf Y}(t)} & = & {\bf 1} + t({\bf M}-{\bf 1}) \\ \Rightarrow \frac{d}{dt}(e^{{\bf Y}(t)}) & = & {\bf M}-{\bf 1} \\ \Rightarrow \lim\limits_{dt \to 0}{\frac{e^{{\bf Y}(t+dt)} - e^{{\bf Y}(t)}}{dt}} & = & {\bf M} - {\bf 1} \end{array} \]

So the argument hinges on \(e^{{\bf Y}(t) + {\bf Y}'(t)dt} = e^{{\bf Y}(t)}e^{{\bf Y}'(t)dt} = e^{{\bf Y}'(t)dt}e^{{\bf Y}(t)}\)

… which is true if and only if –

\[ \begin{array}{rcl} ({\bf 1}+t({\bf M}-{\bf 1}))({\bf 1} + {\bf Y}'(t)dt) & = & ({\bf 1} + {\bf Y}'(t)dt)({\bf 1}+t({\bf M}-{\bf 1})) \end{array} \]

… which holds (for \(t > 0\)) if and only if –

\[ \begin{array}{rcll} {\bf M}{\bf Y}'(t) & = & {\bf Y}'(t){\bf M} & ({\mbox{for } t > 0}) \end{array} \]

Since \(\log({\bf 1} + t({\bf M}-{\bf 1}))\) can be Taylor expanded in \(t({\bf M} - {\bf 1})\), where the scalar \(t\) is the only variable, the derivative w.r.t. \(t\) is still a power series in \(({\bf M} - {\bf 1})\) and therefore commutes with \({\bf M}\).
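This commutation claim is easy to spot-check numerically: any polynomial in \(({\bf M} - {\bf 1})\) (standing in for the power series) commutes with \({\bf M}\) up to floating point roundoff –

```julia
using LinearAlgebra

# A polynomial in (M - I) as a stand-in for Y'(t).
M = rand(Float32, 4, 4)
P = (M - I)^3 + 2.0f0 * (M - I)^2 + 0.5f0 * (M - I)
comm = norm(M * P - P * M)   # zero up to roundoff
```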

So we have \(\frac{d}{dt}(e^{{\bf Y}(t)}) = e^{{\bf Y}(t)}{\bf Y}'(t) = {\bf Y}'(t)e^{{\bf Y}(t)}\).

Consequently, we see the same metrics turn up when we replace the differential equations with \({\bf Y}'(t) = ({\bf M} - {\bf 1}){\bf Z}(t)\) and \({\bf Z}'(t) = - {\bf Y}'(t) {\bf Z}(t)\).