
Workaround for MAGMA accessing illegal memory in batched cholesky #50957

Closed
heitorschueroff wants to merge 3 commits into gh/heitorschueroff/34/base from gh/heitorschueroff/34/head

Conversation

heitorschueroff (Contributor) commented Jan 22, 2021

Stack from ghstack:

MAGMA has an off-by-one error in its batched cholesky implementation that causes illegal memory access for certain inputs. The workaround implemented in this PR is to pad the buffer passed to MAGMA with one extra element.
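To illustrate the idea (a hypothetical Python sketch only; the actual change lives in the C++ MAGMA dispatch in aten/src/ATen/native/cuda/BatchLinearAlgebra.cu, and the names `buf`/`work` below are made up for the example):

```
import torch

# Make a small batch of positive-definite matrices to stand in for the real input.
a = torch.randn(8, 16, 16, device='cuda')
a = a @ a.transpose(-1, -2) + 16 * torch.eye(16, device='cuda')

# Allocate one element more than the matrices need and view the leading
# a.numel() elements as the working copy. The spare trailing element gives
# MAGMA's off-by-one read a valid address to touch.
buf = torch.empty(a.numel() + 1, dtype=a.dtype, device=a.device)
work = buf[:-1].view_as(a).copy_(a)
```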

Benchmark
I ran the script below both before and after this PR and got similar results.

Script

import torch
from torch.utils import benchmark

DTYPE = torch.float32
BATCHSIZE = 512 * 512
MATRIXSIZE = 16

a = torch.eye(MATRIXSIZE, device='cuda', dtype=DTYPE)

t0 = benchmark.Timer(
    stmt='torch.cholesky(a)',
    globals={'a': a},
    label='Single'
)

t1 = benchmark.Timer(
    stmt='torch.cholesky(a)',
    globals={'a': a.expand(BATCHSIZE, -1, -1)},
    label='Batched'
)

print(t0.timeit(100))
print(t1.timeit(100))

Results before

<torch.utils.benchmark.utils.common.Measurement object at 0x7faf9bc63400>
Single
  2.08 ms
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7faf9bc63400>
Batched
  7.68 ms
  1 measurement, 100 runs , 1 thread

Results after

<torch.utils.benchmark.utils.common.Measurement object at 0x7faf9bc63400>
Single
  2.10 ms
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7faf9bc63400>
Batched
  7.56 ms
  1 measurement, 100 runs , 1 thread

Fixes #41394, #26996, #48996

See also #42666, #26789

TODO

  • Benchmark to check for perf regressions

Differential Revision: D26050978

facebook-github-bot (Contributor) commented Jan 22, 2021

💊 CI failures summary and remediations

As of commit 4b4b0f7 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

This comment was automatically generated by Dr. CI.

heitorschueroff added a commit that referenced this pull request Jan 22, 2021
ngimel (Collaborator) commented Jan 23, 2021

Awesome, thanks for fixing!

codecov bot commented Jan 23, 2021

Codecov Report

Merging #50957 (4b4b0f7) into gh/heitorschueroff/34/base (a6257b2) will decrease coverage by 0.00%.
The diff coverage is n/a.

@@                      Coverage Diff                       @@
##           gh/heitorschueroff/34/base   #50957      +/-   ##
==============================================================
- Coverage                       80.91%   80.91%   -0.01%     
==============================================================
  Files                            1926     1926              
  Lines                          210014   210014              
==============================================================
- Hits                           169942   169941       -1     
- Misses                          40072    40073       +1     

Comment thread on aten/src/ATen/native/cuda/BatchLinearAlgebra.cu (diff context: the old out-of-place transpose above, the new in-place call below):

return self_working_copy.transpose(-1, -2);
} else {
return self_working_copy;

result.transpose_(-1, -2);
Collaborator:
What's the impact of changing this to be an in-place transpose?

heitorschueroff (Contributor Author):
One less Tensor view created!? The main reason I changed it was stylistic: since result holds a reference to a new Tensor created with either at::empty or at::clone, the in-place transpose does not affect the input tensor.
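For context, a minimal illustration (not from this PR) of the difference between the two transpose variants being discussed:

```
import torch

t = torch.arange(6.).reshape(2, 3)
v = t.transpose(-1, -2)   # out-of-place: returns a separate view object, t itself is unchanged
print(t.shape, v.shape)   # torch.Size([2, 3]) torch.Size([3, 2])

t.transpose_(-1, -2)      # in-place: swaps t's own sizes/strides, no extra view object
print(t.shape)            # torch.Size([3, 2])
```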

mruberry (Collaborator) commented:
This looks great, nice work @heitorschueroff! cc @ptrblck and @xwang233.

I made one comment suggestion and have a question about a change of transpose variant.

heitorschueroff added a commit that referenced this pull request Jan 25, 2021
mruberry self-requested a review January 25, 2021 19:13

mruberry (Collaborator) left a comment:
Cool!

ptrblck (Collaborator) commented Jan 25, 2021

Thanks for the PR and the discussion on Slack!

Were you somehow able to verify that the data/results won't be corrupted by this change?
I might be too paranoid, but I am afraid of silent errors that could produce wrong results if we are not only reading out of bounds but also processing those values.

I've spent some time digging into the code and think the first issue is this illegal access:

__global__ void spotf2_smlpin_fixwidth_kernel_batched(int m,
        float **dA_array, int ai, int aj, int lda,
        int localstep, int gbstep, magma_int_t *info_array, const int batchCount)
{
    const int batchid = blockIdx.z * blockDim.y + threadIdx.y;
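    // Note: dA_array is indexed with batchid on the next line, before the batchCount bounds check below runs.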
    float *dA = dA_array[batchid] + aj * lda + ai;
    if (batchid >= batchCount) return;

dA_array has a valid length of 4, while batchid is only checked after it has already been used to index dA_array, so I think swapping the two statements should solve this first issue.

(cuda-gdb) print dA_array[0]
$36 = (@generic float * @generic) 0x7ffeef5ff9c0
(cuda-gdb) print dA_array[1]
$37 = (@generic float * @generic) 0x7ffeef5ffb50
(cuda-gdb) print dA_array[2]
$38 = (@generic float * @generic) 0x7ffeef5ffce0
(cuda-gdb) print dA_array[3]
$39 = (@generic float * @generic) 0x7ffeef5ffe70
(cuda-gdb) print dA_array[4]
$40 = (@generic float * @generic) 0xfffffff389
(cuda-gdb) print dA_array[0] @ 4
$44 = {0x7ffeef5ff9c0, 0x7ffeef5ffb50, 0x7ffeef5ffce0, 0x7ffeef5ffe70}
(cuda-gdb) print dA_array[0] @ 5
Error: Failed to read generic memory at address 0x7ffeef9fffe0 on device 0 sm 0 warp 3 lane 0, error=CUDBG_ERROR_MEMORY_MAPPING_FAILED(0x9).

However, once this is fixed, another IMA (illegal memory access) is triggered:

(cuda-gdb) bt
#0  0x00005555a86e9a20 in _INTERNAL_61_tmpxft_00007a08_00000000_11_spotf2_kernels_compute_80_cpp1_ii_7b34eafc::sgemm_v20_1_anywidth_device (m=2, n=2,
    k=8, A0=0xff00000000, lda=10, sC=0x7fff260027d0, sB=0x7fffffffff00)
    at /opt/conda/conda-bld/magma-cuda110_1611350834581/work/magmablas/spotf2_devicesfunc.cuh:323
#1  0x00005555a86e9a20 in _INTERNAL_61_tmpxft_00007a08_00000000_11_spotf2_kernels_compute_80_cpp1_ii_7b34eafc::sgemm_v20_1_anywidth_device (m=8, n=0,
    k=0, A0=0x7fff28fffb00, lda=-1807745072, sC=0x7ffeee6aaa00, sB=0x7ffe943ffe90)
    at /opt/conda/conda-bld/magma-cuda110_1611350834581/work/magmablas/spotf2_devicesfunc.cuh:323
#2  0x00005555a86de7c0 in _INTERNAL_61_tmpxft_00007a08_00000000_11_spotf2_kernels_compute_80_cpp1_ii_7b34eafc::spotf2_smlpout_anywidth_device (
    m=<optimized out>, n=0, A0=0x0, A=0x0, lda=<optimized out>, localstep=<optimized out>, gbstep=<optimized out>, info=<optimized out>)
    at /opt/conda/conda-bld/magma-cuda110_1611350834581/work/magmablas/spotf2_devicesfunc.cuh:450
#3  0x00005555a86fdb70 in spotf2_smlpout_anywidth_kernel_batched<<<(1,1,1),(2,8,1)>>> (m=2, n=2, dA_array=0x7ffeff7fffe0, ai=0, aj=0, lda=10,
    localstep=8, gbstep=0, info_array=0x7ffeff4ffff0, batchCount=4)
    at /opt/conda/conda-bld/magma-cuda110_1611350834581/work/magmablas/spotf2_kernels.cu:135
(cuda-gdb) info registers
R0             0x5c                92
R1             0xfffa40            16775744
R2             0x7fff              32767
R3             0x2                 2
R4             0x94400000          -1807745024
R5             0x7ffe              32766

Where an LDG (a load from global memory) is executed on [R4].

Based on the $pc instructions, the most likely candidate is A0 in:

static inline __device__ void sgemm_v20_1_anywidth_device(int m, int n, int k,
        const float* __restrict__ A0, int lda,
        float *sC, float  *sB)
{
[...]
        // prefetch next block. Azzam
        #ifdef ENABLE_COND5
        if (tx < m )  
        {      
        #endif
            #pragma unroll
            for (int i=0; i < POTF2_NB; i++)
            {
                rp[i] = A0[min(bound_A, tx + (i+(iter+POTF2_NB)) * lda)]; // min(bound,xxx) is to avoid reading out of bound
            }
        #ifdef ENABLE_COND5
        }
        #endif

I'm happy to dig further into it if that makes sense.

@heitorschueroff merged this pull request in a7cf04e.

mruberry (Collaborator) commented Jan 25, 2021

> Were you somehow able to verify that the data/results won't be corrupted by this change?
> I might be too paranoid, but I am afraid of silent errors that could produce wrong results if we are not only reading out of bounds but also processing those values.

We have existing tests for batched cholesky that this PR re-enabled. The function was always reading these values and I don't think we've had reports of incorrect or corrupted values. One possible fix would be to change this PR's empty call to a zeros call, so the memory being read is zero, although the effect of that change is unclear.

I don't want to discourage you from further debugging into MAGMA, but our thinking is that this fix is a definite improvement, resolving all known segfaults with no harm to correctness, and we should wait for the community to report new issues if there are any. We also want to begin pursuing a cuSOLVER implementation of batched cholesky.

For example, if we identify the issue with MAGMA and submit a fix for it, that fix will only appear in newer versions of MAGMA and this workaround will still be required. If batched cholesky uses cuSOLVER when the CUDA toolkit version is sufficiently high, then the set of users who have a fixed MAGMA version but are not using cuSOLVER might be very small.

ngimel (Collaborator) commented Jan 26, 2021

@ptrblck that's a very valid concern. Do you think one way to alleviate it would be to fill the padding element with NaN and run a battery of random tests? If that value is indeed read and processed, it would pollute the results. I don't want to leave the NaN filling in permanently, as that would be an additional and unneeded perf penalty, but I also don't think the existing tests in the test suite are enough to be reasonably sure that correct results are produced.

ptrblck (Collaborator) commented Jan 26, 2021

@ngimel Yes, I think that's a very good idea.
I can fill the result tensor with NaNs and run tests using multiples of 512 for the batch size first (as that was hitting the IMA), then continue with random shapes for a reasonable number of iterations (let me know if you have specific shapes in mind).
If NaNs are detected in the result, I'll continue debugging.

ngimel (Collaborator) commented Jan 26, 2021

You can use the Fuzzer to help with input generation: https://github.com/pytorch/pytorch/blob/47f0bda3ef8d196e0fa81a1749814dd75ffb1692/torch/utils/benchmark/examples/sparse/fuzzer.py. For example, you can teach it to generate multiples of 512 half the time, and random sizes distributed according to some distribution for the rest.
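For illustration, a rough sketch of how the Fuzzer could be wired up here (the parameter names, probabilities, and the symmetrization step are assumptions made for the example, not a tested configuration):

```
import torch
from torch.utils import benchmark

# Weight multiples of 512 with half the probability mass; a few other batch
# sizes share the rest (illustrative numbers only).
batch_dist = {512 * k: 0.5 / 8 for k in range(1, 9)}
batch_dist.update({b: 0.5 / 4 for b in (1, 7, 100, 1000)})

fuzzer = benchmark.Fuzzer(
    parameters=[
        benchmark.FuzzedParameter("batch", distribution=batch_dist),
        benchmark.FuzzedParameter("n", minval=4, maxval=50, distribution="uniform"),
    ],
    tensors=[
        benchmark.FuzzedTensor("a", size=("batch", "n", "n"),
                               probability_contiguous=1.0,
                               dtype=torch.float32, cuda=True),
    ],
    seed=0,
)

for tensors, _, params in fuzzer.take(100):
    a, n = tensors["a"], params["n"]
    # Symmetrize and shift the diagonal so the input is positive definite.
    a = a @ a.transpose(-1, -2) + n * torch.eye(n, device=a.device)
    assert torch.isfinite(torch.cholesky(a)).all()
```

A dict distribution makes it easy to weight a fixed set of sizes; mixing that with a continuous distribution inside a single parameter is the part described as tricky further down.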

ptrblck (Collaborator) commented Jan 29, 2021

Good news! I've checked 600k shapes and no invalid values were found in any result tensor.

Testing script:
To fill the result with NaNs, I've added the following change:

    result = at::empty(input.numel() + 1, input.options());
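    // at::log(at::ones({1}) * -1.) evaluates to NaN, so the multiplication below fills the
    // whole padded buffer, including the extra element, with NaN before the real input is copied back in.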
    result = result * at::log(at::ones({1}) * -1.).item();
    result.resize_as_(input).copy_(input).transpose_(-1, -2);
  } else {
    result = cloneBatchedColumnMajor(upper ? self.transpose(-1, -2) : self);
  }

The testing script is defined as:

import torch
import sys
import pandas as pd
import numpy as np


def get_mat_dist_pow():
    p10 = 0.7 # a matrix_size of 10 is highly likely to fail, so increase probability
    p_other = (1 - p10) / 46
    dist_matrix_size = {v.long().item(): p_other for v in torch.linspace(4, 50, 47)}
    dist_matrix_size[10] = p10
    return dist_matrix_size


def get_batch_dist_pow():
    return {2**e.long().item(): 1/11. for e in torch.linspace(0, 10, 11)}


def get_matrix_size():
    # https://github.com/pytorch/pytorch/blob/47f0bda3ef8d196e0fa81a1749814dd75ffb1692/torch/utils/benchmark/utils/fuzzer.py#L120
    index = state.choice(
        np.arange(len(mat_dist_pow)),
        p=tuple(mat_dist_pow.values()))
    matrix_size = list(mat_dist_pow.keys())[index]
    return matrix_size


def get_batch_size(p_switch=0.5):
    # switch between uniform and pow
    if torch.empty(1).uniform_() > p_switch:
        batch_size = torch.empty(1).uniform_(1, 1024*1024).long()
    else:
        # https://github.com/pytorch/pytorch/blob/47f0bda3ef8d196e0fa81a1749814dd75ffb1692/torch/utils/benchmark/utils/fuzzer.py#L120
        index = state.choice(
            np.arange(len(batch_dist_pow)),
            p=tuple(batch_dist_pow.values()))
        batch_size = list(batch_dist_pow.keys())[index]
    return batch_size


dtype = torch.float32
nb_iters = 600000
columns=["matrix_size", "batch_size", "passed"]
result = pd.DataFrame(columns=columns)
device = 'cuda:0'
mat_dist_pow = get_mat_dist_pow()
batch_dist_pow = get_batch_dist_pow()
state = np.random.RandomState(2809)

for _ in range(nb_iters):
    try:
        # create fake data on device to move workload
        size = 2**torch.randint(0, 30, (1,))
        fake = torch.randn(size, device=device)
        print('Created fake tensor of {}MB'.format(fake.nelement()*4/1024**2))

        matrix_size = get_matrix_size()
        batch_size = get_batch_size()
        print('Using matrix_Size {}, batch_size {}'.format(matrix_size, batch_size))
        input = torch.eye(matrix_size, device=device, dtype=dtype).expand(batch_size, -1, -1)
        print('input.shape {}'.format(input.shape))

        # execute test multiple times
        for _ in range(3):
            ret = torch.cholesky(input)
        torch.cuda.synchronize()
        passed = torch.isfinite(ret).all()
        if not passed:
            print('ERROR! Invalid values found!')
            print(ret)
            torch.save(ret, 'invalid_ret01.pt')
            sys.exit(-1)

        # save results
        result = result.append(pd.DataFrame([[matrix_size, batch_size, passed.item()]], columns=columns))
        del fake
        del input
        del ret
    except Exception as e:
        print(e)
        if 'out of memory' in str(e):
            result = result.append(pd.DataFrame([[matrix_size, batch_size, 'oom']], columns=columns))
            continue
        else:
            result = result.append(pd.DataFrame([[matrix_size, batch_size, str(e)]], columns=columns))
            continue

result.to_csv('magma_ima_result01.csv', index=False)

For the matrix_size, I've used a uniform distribution for values in [4, 50] with a peak of p=0.7 for a matrix_size of 10, as this shape seems to trigger the IMA often.
The batch_size was randomly picked with p=0.5 between:

  • uniform distribution in [1, 1024*1024]
  • random choice in {2**e.long().item(): 1/11. for e in torch.linspace(0, 10, 11)} = [1, 2, 4, 8, ... 1024]

In each iteration a fake tensor is first created on the GPU in order to move the workload around a bit, as I thought this might also be useful.
The script was executed with a disabled caching allocator via PYTORCH_NO_CUDA_MEMORY_CACHING=1, as I was concerned the cache might "sanitize" the test.

I tried to use the Fuzzer (which is a nice utility, btw), but couldn't get it to pick randomly between the two distributions for the batch_size.

Attached is the result image, where blue dots show a passed test and red ones indicate an OOM.
The tests were performed on a TitanV using a master build with the necessary changes to fill result with NaNs.

[result image res02: blue dots mark passed tests, red ones mark OOM]

Also, sorry for the delay, but my estimation of the duration of this test was way off, and it took ~30h to finish.

Please let me know if you see any issues in the current test and/or if I should run additional tests.

facebook-github-bot deleted the gh/heitorschueroff/34/head branch January 29, 2021 15:21
heitorschueroff (Contributor Author) commented:

@ptrblck Thank you for this extensive exploration and for verifying that MAGMA is not corrupting the results with the out-of-bounds reads. I think at this point we can feel safe with the fix implemented here. If there are new IMA issues after this fix we can dive deeper, but as @mruberry suggested, we will likely transition most use cases to cuSOLVER.

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
…torch#50957)

Summary:
Pull Request resolved: pytorch#50957


Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D26050978

Pulled By: heitorschueroff

fbshipit-source-id: 7a5ba7e34c9d74b58568b2a0c631cc6d7ba63f86