
Linear algebra GPU library function bug tracking issue [magma/cusolver/cublas] #53879

@xwang233


This issue tracks known bugs in GPU library functions (e.g. MAGMA, cuSOLVER, cuBLAS). Known issues include crashes, large numerical mismatches, NaN outputs, etc.

| torch function | library function | affected library versions | workaround in PyTorch? | related PR | issue description |
|---|---|---|---|---|---|
| torch.cholesky | cusolverDnXpotrfBatched | fixed in CUDA 11.3 | #57788 | #53104 #56724 | Function creates NaN output for large ill-conditioned matrices. |
| | magma_Xpotrf_batched | ? | patched | #50957 | Function causes a CUDA illegal memory access for large inputs. |
| torch.eigh, torch.eigvalsh | cusolverDnXsyevjBatched | fixed in CUDA 11.3 Update 1 | not used | #53040 | Function sometimes produces wrong outputs. |
| torch.lu | cusolverDn&lt;t&gt;getrf | current | replace NaN with 0 | #56887 | For the non-pivoted variant with singular input, there is no check for 0/0, which results in NaN values; the cuSOLVER kernels are probably missing an "if element is zero: skip it" condition. |
| torch.triangular_solve (sparse) | cusparseSpSM | not fixed in 11.5, see #66401 (comment) | | #66401 | Function uses wrong strides to write the output. The bug can be reproduced using the code given in #66401 (comment). |

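A worked example of the getrf row, assuming the textbook unblocked, right-looking elimination order (cuSOLVER's actual kernels may be blocked and order operations differently): factor the singular 3x3 all-ones matrix, which is also the input of the getrf reproducer, without pivoting. The first step zeroes the entire trailing 2x2 block, so the second pivot is exactly zero and the next multiplier is 0/0.

```latex
\begin{aligned}
A &= \begin{pmatrix} 1 & 1 & 1 \\ 1 & 1 & 1 \\ 1 & 1 & 1 \end{pmatrix},
\qquad u_{1j} = a_{1j} = 1, \qquad \ell_{i1} = \frac{a_{i1}}{a_{11}} = 1 \quad (i = 2, 3) \\
a^{(1)}_{ij} &= a_{ij} - \ell_{i1} u_{1j} = 1 - 1 = 0 \quad (i, j \in \{2, 3\}),
\qquad u_{22} = a^{(1)}_{22} = 0, \qquad u_{23} = a^{(1)}_{23} = 0 \\
\ell_{32} &= \frac{a^{(1)}_{32}}{a^{(1)}_{22}} = \frac{0}{0} = \mathrm{NaN} \\
u_{33} &= a^{(1)}_{33} - \ell_{32} u_{23} = 0 - \mathrm{NaN} \cdot 0 = \mathrm{NaN} \\
\text{LU (combined getrf storage)} &= \begin{pmatrix} 1 & 1 & 1 \\ 1 & 0 & 0 \\ 1 & \mathrm{NaN} & \mathrm{NaN} \end{pmatrix}
\end{aligned}
```

If the division were skipped for a zero pivot, \(\ell_{32}\) would stay \(a^{(1)}_{32} = 0\), \(u_{33}\) would be \(0 - 0 \cdot 0 = 0\), and `info` would still report the zero pivot at index 2.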
Code for reproducing the NaN issue of getrf

Compile with `nvcc` and link against cuSOLVER (`-lcusolver`).

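```c
// Reproducer: cusolverDnDgetrf without pivoting (devIpiv = NULL) on a
// singular 3x3 all-ones matrix produces NaN values in the LU factors.
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <cuda_runtime.h>
#include <cusolverDn.h>
#include <cusolver_common.h>

int main() {
  int major = -1, minor = -1, patch = -1;
  cusolverGetProperty(MAJOR_VERSION, &major);
  cusolverGetProperty(MINOR_VERSION, &minor);
  cusolverGetProperty(PATCH_LEVEL, &patch);
  printf("CUSOLVER Version (Major,Minor,PatchLevel): %d.%d.%d\n", major, minor, patch);

  cusolverDnHandle_t handle = NULL;
  cudaStream_t stream = NULL;
  cusolverStatus_t status = CUSOLVER_STATUS_SUCCESS;
  cudaError_t cerr = cudaSuccess;

  // Host input: N x N all-ones matrix (column-major), which is singular.
  int N = 3;
  double *cA = (double*)malloc(N * N * sizeof(double));
  double *dA;
  for (int i = 0; i < N * N; ++i) {
    cA[i] = 1;
  }
  cudaMalloc((void**)&dA, N * N * sizeof(double));
  cudaMemcpy(dA, cA, N * N * sizeof(double), cudaMemcpyHostToDevice);
  // A pageable host-to-device cudaMemcpy may return before the DMA finishes;
  // make sure the upload is complete before getrf runs on a non-blocking stream.
  cudaDeviceSynchronize();

  // Keep calls with side effects outside assert() so they still run with NDEBUG.
  status = cusolverDnCreate(&handle);
  assert(status == CUSOLVER_STATUS_SUCCESS);
  cerr = cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);
  assert(cerr == cudaSuccess);
  status = cusolverDnSetStream(handle, stream);
  assert(status == CUSOLVER_STATUS_SUCCESS);

  int lwork;
  double *dwork;
  int *dipiv, *dinfo, *cinfo;
  cinfo = (int*)malloc(sizeof(int));
  cudaMalloc((void**)&dinfo, sizeof(int));
  cudaMalloc((void**)&dipiv, N * sizeof(int));  // not passed to getrf below
  status = cusolverDnDgetrf_bufferSize(handle, N, N, dA, N, &lwork);
  assert(status == CUSOLVER_STATUS_SUCCESS);
  cudaMalloc((void**)&dwork, sizeof(double) * lwork);

  // devIpiv = NULL selects the non-pivoted LU variant.
  status = cusolverDnDgetrf(handle, N, N, dA, N, dwork, NULL, dinfo);
  assert(status == CUSOLVER_STATUS_SUCCESS);

  // The stream is non-blocking, so it does not synchronize with the legacy
  // default stream used by cudaMemcpy; wait for getrf explicitly.
  cerr = cudaStreamSynchronize(stream);
  assert(cerr == cudaSuccess);

  cudaMemcpy(cA, dA, N * N * sizeof(double), cudaMemcpyDeviceToHost);
  cudaMemcpy(cinfo, dinfo, sizeof(int), cudaMemcpyDeviceToHost);
  printf("dinfo: %d\n", *cinfo);
  // Print the combined LU factors row by row (storage is column-major).
  for (int i = 0; i < N; ++i) {
    for (int j = 0; j < N; ++j) {
      printf("%lf ", cA[i + j * N]);
    }
    printf("\n");
  }

  // Release device and host resources.
  cudaFree(dwork);
  cudaFree(dipiv);
  cudaFree(dinfo);
  cudaFree(dA);
  free(cinfo);
  free(cA);
  cusolverDnDestroy(handle);
  cudaStreamDestroy(stream);
  return 0;
}
```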

See also

cc @ngimel @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk @ptrblck
