
use batched gemm from mkl on torch.bmm when mkl is available #11365

Closed
mingfeima wants to merge 1 commit into pytorch:master from mingfeima:bmm

Conversation

@mingfeima
Collaborator

This PR uses MKL's batched gemm for torch.bmm when MKL is available. The current logic for torch.bmm runs batch_size iterations of a single gemm. From a performance point of view this is fine when each gemm is large enough; in many cases, however, the individual gemms are relatively small and the per-iteration approach is not efficient.

One such scenario is the globalAttention calculation in NMT, where
mat1: N * 1 * T
mat2: N * T * H
with N the batch size, T the number of timesteps, and H the hidden size.
Here each individual gemm is relatively small. MKL provides batched gemm APIs, which are beneficial when dealing with many small gemms.
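
For illustration, here is a minimal standalone sketch (a hypothetical helper, not the code added by this PR) of how MKL's cblas_sgemm_batch can replace a loop of per-matrix cblas_sgemm calls over a contiguous batch; the function name, row-major layout, and single-group setup are assumptions made for this example.

// Illustrative sketch only (hypothetical helper, not from this PR):
// multiply N row-major matrices A[i] (M x K) by B[i] (K x P) in one MKL call
// instead of N separate cblas_sgemm calls.
#include <mkl.h>
#include <vector>

void bmm_mkl_sketch(const float* A, const float* B, float* C,
                    MKL_INT N, MKL_INT M, MKL_INT K, MKL_INT P) {
    // Single "group": all N gemms share the same shape (M x K) * (K x P).
    CBLAS_TRANSPOSE transa = CblasNoTrans, transb = CblasNoTrans;
    float alpha = 1.0f, beta = 0.0f;
    MKL_INT lda = K, ldb = P, ldc = P;
    MKL_INT group_size = N;

    // Per-matrix pointers into the contiguous batched buffers.
    std::vector<const float*> a_array(N), b_array(N);
    std::vector<float*> c_array(N);
    for (MKL_INT i = 0; i < N; ++i) {
        a_array[i] = A + i * M * K;
        b_array[i] = B + i * K * P;
        c_array[i] = C + i * M * P;
    }

    // One call covers the whole batch, replacing the batch_size loop of gemms.
    cblas_sgemm_batch(CblasRowMajor, &transa, &transb,
                      &M, &P, &K, &alpha,
                      a_array.data(), &lda,
                      b_array.data(), &ldb,
                      &beta, c_array.data(), &ldc,
                      /*group_count=*/1, &group_size);
}

With the NMT shapes above this would be invoked with M = 1, K = T, and P = H.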

The following script was used to benchmark and test this PR. On a Xeon Skylake 8180 (2 sockets * 28 cores), each torch.bmm call takes 0.81 ms without the PR and 0.45 ms with it.

import torch
from time import time

N = 128
T = 30
H = 500
count = 1000

def bench_bmm():
    # time `count` iterations of batched matmul (N, 1, T) x (N, T, H)
    mat1 = torch.randn(N, 1, T)
    mat2 = torch.randn(N, T, H)

    tstart = time()
    for i in range(count):
        res = torch.bmm(mat1, mat2)
    tend = time()
    t = (tend-tstart)/count*1000
    flops = N*1*T*H*2 / t / 1000000
    print("run torch.bmm:")
    print("total time     : %.2f s" % (tend-tstart))
    print("each iteration : %.2f ms %.2f GFlops" % (t, flops))

def test_bmm(trans_A=False, trans_B=False):
    # compare torch.bmm against per-matrix matmul, optionally with transposed inputs
    I = 10
    print("testing torch.bmm mat1: %s, mat2 %s" % ("T" if trans_A else "N", "T" if trans_B else "N"))
    mat1 = torch.randn(N, T, I).transpose(1, 2) if trans_A else torch.randn(N, I, T)
    mat2 = torch.randn(N, H, T).transpose(1, 2) if trans_B else torch.randn(N, T, H)
    mat1_ = mat1.clone()
    mat2_ = mat2.clone()

    res = torch.bmm(mat1, mat2)
    res_ = torch.Tensor(N, I, H)
    for i in range(N):
        res_[i] = mat1_[i].matmul(mat2_[i])

    for ii in range(N):
        for jj in range(I):
            for kk in range(H):
                val1 = res[ii][jj][kk]
                val2 = res_[ii][jj][kk]
                if abs(val1 - val2) >= 1e-5:
                    print(val1, val2, "not equal, case FAIL")
                    return

    print("PASS")

bench_bmm()
test_bmm(trans_A=False, trans_B=False)
test_bmm(trans_A=True, trans_B=False)
test_bmm(trans_A=False, trans_B=True)
test_bmm(trans_A=True, trans_B=True)
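
At these sizes each iteration performs N * 1 * T * H * 2 = 3.84 MFLOP, so 0.81 ms and 0.45 ms correspond to roughly 4.7 and 8.5 GFlops respectively.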

@t-vi
Collaborator

t-vi commented Sep 7, 2018

Hi,
for the same reason as you, I have updated #11292, but it doesn't do MKL yet.
Would it be OK if I merge the MKL bits of your patch into that PR?

Best regards

Thomas

P.S.: Also, I think that it would be best to do this for baddbmm/bmm in one go.


@ssnl
Collaborator

ssnl commented Sep 7, 2018

Very nice. Thanks a lot! Although it seems to have some bugs currently (see CI failures).

@mingfeima
Collaborator Author

Closing this, as it has been folded into #11292.

