Skip to content

dask.array.dot doesn't work with cupyx-sparse-backed Dask arrays #6820

@daxiongshu

Description

@daxiongshu

`da.dot(dx, dw)` raises an error when `dx` and `dw` are Dask arrays whose chunks are `cupyx.scipy.sparse` matrices. Minimal reproducer:

import cupy
import dask.array as da
from cupyx.scipy.sparse import csr_matrix
x = cupy.arange(24, dtype=cupy.float32).reshape(4, 6)
dx = da.from_array(x, chunks=(2, 6), asarray=False, fancy=False)
dx = dx.map_blocks(csr_matrix, dtype=cupy.float32)
w = cupy.arange(18, dtype=cupy.float32).reshape(6,3)
dw = da.from_array(w, chunks=(2, 3), asarray=False, fancy=False)
dw = dw.map_blocks(csr_matrix, dtype=cupy.float32)
da.dot(dx, dw).compute() # error
Click to see the error

Error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-6-9460b0874b75> in <module>
     11 dw = dw.map_blocks(csr_matrix, dtype=cupy.float32)
     12 
---> 13 da.dot(dx, dw).compute()
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
    165         dask.base.compute
    166         """
--> 167         (result,) = compute(self, traverse=False, **kwargs)
    168         return result
    169 
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    452         postcomputes.append(x.__dask_postcompute__())
    453 
--> 454     results = schedule(dsk, keys, **kwargs)
    455     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    456 
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
     82         get_id=_thread_get_id,
     83         pack_exception=pack_exception,
---> 84         **kwargs
     85     )
     86 
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
    484                         _execute_task(task, data)  # Re-execute locally
    485                     else:
--> 486                         raise_exception(exc, tb)
    487                 res, worker_id = loads(res_info)
    488                 state["cache"][key] = res
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/local.py in reraise(exc, tb)
    314     if exc.__traceback__ is not tb:
    315         raise exc.with_traceback(tb)
--> 316     raise exc
    317 
    318 
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    220     try:
    221         task, data = loads(task_info)
--> 222         result = _execute_task(task, data)
    223         id = get_id()
    224         result = dumps((result, id))
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    119         # temporaries by their reference count and can execute certain
    120         # operations in-place.
--> 121         return func(*(_execute_task(a, cache) for a in args))
    122     elif not ishashable(arg):
    123         return arg
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/cupyx/scipy/sparse/compressed.py in __getitem__(self, slices)
    247                 return self._get_major_slice(major)
    248 
--> 249         raise ValueError('unsupported indexing')
    250 
    251     def _get_single(self, major, minor):
ValueError: unsupported indexing
Another example, with a different error

If I change this line in the code above

dw = da.from_array(w, chunks=(2, 3), asarray=False, fancy=False)

to

dw = da.from_array(w, chunks=(6, 1), asarray=False, fancy=False)

I get a different error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-5-185740c5f0c0> in <module>
     11 dw = dw.map_blocks(csr_matrix, dtype=cupy.float32)
     12 
---> 13 da.dot(dx, dw).compute()
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
    165         dask.base.compute
    166         """
--> 167         (result,) = compute(self, traverse=False, **kwargs)
    168         return result
    169 
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    452         postcomputes.append(x.__dask_postcompute__())
    453 
--> 454     results = schedule(dsk, keys, **kwargs)
    455     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    456 
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
     82         get_id=_thread_get_id,
     83         pack_exception=pack_exception,
---> 84         **kwargs
     85     )
     86 
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
    484                         _execute_task(task, data)  # Re-execute locally
    485                     else:
--> 486                         raise_exception(exc, tb)
    487                 res, worker_id = loads(res_info)
    488                 state["cache"][key] = res
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/local.py in reraise(exc, tb)
    314     if exc.__traceback__ is not tb:
    315         raise exc.with_traceback(tb)
--> 316     raise exc
    317 
    318 
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    220     try:
    221         task, data = loads(task_info)
--> 222         result = _execute_task(task, data)
    223         id = get_id()
    224         result = dumps((result, id))
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    119         # temporaries by their reference count and can execute certain
    120         # operations in-place.
--> 121         return func(*(_execute_task(a, cache) for a in args))
    122     elif not ishashable(arg):
    123         return arg
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/core.py in <genexpr>(.0)
    119         # temporaries by their reference count and can execute certain
    120         # operations in-place.
--> 121         return func(*(_execute_task(a, cache) for a in args))
    122     elif not ishashable(arg):
    123         return arg
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    119         # temporaries by their reference count and can execute certain
    120         # operations in-place.
--> 121         return func(*(_execute_task(a, cache) for a in args))
    122     elif not ishashable(arg):
    123         return arg
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/optimization.py in __call__(self, *args)
    961         if not len(args) == len(self.inkeys):
    962             raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
--> 963         return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))
    964 
    965     def __reduce__(self):
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/core.py in get(dsk, out, cache)
    149     for key in toposort(dsk):
    150         task = dsk[key]
--> 151         result = _execute_task(task, cache)
    152         cache[key] = result
    153     result = _execute_task(out, cache)
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    119         # temporaries by their reference count and can execute certain
    120         # operations in-place.
--> 121         return func(*(_execute_task(a, cache) for a in args))
    122     elif not ishashable(arg):
    123         return arg
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/core.py in <genexpr>(.0)
    119         # temporaries by their reference count and can execute certain
    120         # operations in-place.
--> 121         return func(*(_execute_task(a, cache) for a in args))
    122     elif not ishashable(arg):
    123         return arg
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    113     """
    114     if isinstance(arg, list):
--> 115         return [_execute_task(a, cache) for a in arg]
    116     elif istask(arg):
    117         func, args = arg[0], arg[1:]
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/core.py in <listcomp>(.0)
    113     """
    114     if isinstance(arg, list):
--> 115         return [_execute_task(a, cache) for a in arg]
    116     elif istask(arg):
    117         func, args = arg[0], arg[1:]
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    119         # temporaries by their reference count and can execute certain
    120         # operations in-place.
--> 121         return func(*(_execute_task(a, cache) for a in args))
    122     elif not ishashable(arg):
    123         return arg
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/utils.py in apply(func, args, kwargs)
     27 def apply(func, args, kwargs=None):
     28     if kwargs:
---> 29         return func(*args, **kwargs)
     30     else:
     31         return func(*args)
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/dask/array/routines.py in _tensordot(a, b, axes)
    227         )
    228     else:
--> 229         x = tensordot(a, b, axes=axes)
    230 
    231     ind = [slice(None, None)] * x.ndim
<__array_function__ internals> in tensordot(*args, **kwargs)
~/anaconda3/envs/dask_glm_test/lib/python3.7/site-packages/numpy/core/numeric.py in tensordot(a, b, axes)
   1073     else:
   1074         for k in range(na):
-> 1075             if as_[axes_a[k]] != bs[axes_b[k]]:
   1076                 equal = False
   1077                 break
IndexError: tuple index out of range

Environment:

  • Dask version: 2.30.0+17.g25a5db2d
  • cupy version: 8.0.0
  • Python version: 3.7
  • Operating System: ubuntu 18.04
  • Install method (conda, pip, source): source

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions