Skip to content

Linalg solve_triangular fails on cupy backed arrays #7536

@beckernick

Description

@beckernick

dask.array.linalg.solve_triangular fails on CuPy-backed arrays because the function explicitly calls the SciPy routine. Now that CuPy supports solve_triangular in the cupyx.scipy.linalg namespace, I suspect we could dispatch to it and provide support in dask.array.

dask/dask/array/linalg.py

Lines 1178 to 1191 in c5633c2

            dsk[_key(i, j)] = (_solve_triangular_lower, (a.name, i, i), target)
    else:
        for i in range(vchunks):
            for j in range(hchunks):
                target = _b_init(i, j)
                if i < vchunks - 1:
                    prevs = []
                    for k in range(i + 1, vchunks):
                        prev = name_mdot, i, k, k, j
                        dsk[prev] = (np.dot, (a.name, i, k), _key(k, j))
                        prevs.append(prev)
                    target = (operator.sub, target, (sum, prevs))
                dsk[_key(i, j)] = (
                    scipy.linalg.solve_triangular,

dask/dask/array/linalg.py

Lines 966 to 969 in c5633c2

def _solve_triangular_lower(a, b):
    import scipy.linalg

    return scipy.linalg.solve_triangular(a, b, lower=True)

import numpy as np
import dask.array as da

a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])
b = np.array([4, 2, 4, 2])

a, b = da.from_array(a), da.from_array(b)

da.linalg.solve_triangular(a, b).compute()
array([1.33333333, 2.        , 4.        , 2.        ])
import cupy as cp
import dask.array as da

a = cp.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])
b = cp.array([4, 2, 4, 2])

a, b = da.from_array(a), da.from_array(b)

da.linalg.solve_triangular(a, b).compute()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-68-9a877ba2c41f> in <module>
      7 a, b = da.from_array(a), da.from_array(b)
      8 
----> 9 da.linalg.solve_triangular(a, b).compute()

/raid/nicholasb/dev/dask/dask/base.py in compute(self, **kwargs)
    282         dask.base.compute
    283         """
--> 284         (result,) = compute(self, traverse=False, **kwargs)
    285         return result
    286 

/raid/nicholasb/dev/dask/dask/base.py in compute(*args, **kwargs)
    564         postcomputes.append(x.__dask_postcompute__())
    565 
--> 566     results = schedule(dsk, keys, **kwargs)
    567     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    568 

/raid/nicholasb/dev/dask/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
     77             pool = MultiprocessingPoolExecutor(pool)
     78 
---> 79     results = get_async(
     80         pool.submit,
     81         pool._max_workers,

/raid/nicholasb/dev/dask/dask/local.py in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
    512                             _execute_task(task, data)  # Re-execute locally
    513                         else:
--> 514                             raise_exception(exc, tb)
    515                     res, worker_id = loads(res_info)
    516                     state["cache"][key] = res

/raid/nicholasb/dev/dask/dask/local.py in reraise(exc, tb)
    323     if exc.__traceback__ is not tb:
    324         raise exc.with_traceback(tb)
--> 325     raise exc
    326 
    327 

/raid/nicholasb/dev/dask/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    221     try:
    222         task, data = loads(task_info)
--> 223         result = _execute_task(task, data)
    224         id = get_id()
    225         result = dumps((result, id))

/raid/nicholasb/dev/dask/dask/core.py in _execute_task(arg, cache, dsk)
    119         # temporaries by their reference count and can execute certain
    120         # operations in-place.
--> 121         return func(*(_execute_task(a, cache) for a in args))
    122     elif not ishashable(arg):
    123         return arg

/raid/nicholasb/miniconda3/envs/cuml-dev/lib/python3.8/site-packages/scipy/linalg/basic.py in solve_triangular(a, b, trans, lower, unit_diagonal, overwrite_b, debug, check_finite)
    331              'versions of SciPy.', DeprecationWarning, stacklevel=2)
    332 
--> 333     a1 = _asarray_validated(a, check_finite=check_finite)
    334     b1 = _asarray_validated(b, check_finite=check_finite)
    335     if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:

/raid/nicholasb/miniconda3/envs/cuml-dev/lib/python3.8/site-packages/scipy/_lib/_util.py in _asarray_validated(a, check_finite, sparse_ok, objects_ok, mask_ok, as_inexact)
    260             raise ValueError('masked arrays are not supported')
    261     toarray = np.asarray_chkfinite if check_finite else np.asarray
--> 262     a = toarray(a)
    263     if not objects_ok:
    264         if a.dtype is np.dtype('O'):

/raid/nicholasb/miniconda3/envs/cuml-dev/lib/python3.8/site-packages/numpy/lib/function_base.py in asarray_chkfinite(a, dtype, order)
    484 
    485     """
--> 486     a = asarray(a, dtype=dtype, order=order)
    487     if a.dtype.char in typecodes['AllFloat'] and not np.isfinite(a).all():
    488         raise ValueError(

/raid/nicholasb/miniconda3/envs/cuml-dev/lib/python3.8/site-packages/numpy/core/_asarray.py in asarray(a, dtype, order, like)
    100         return _asarray_with_like(a, dtype=dtype, order=order, like=like)
    101 
--> 102     return array(a, dtype, copy=False, order=order)
    103 
    104 

cupy/core/core.pyx in cupy.core.core.ndarray.__array__()

TypeError: Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.
conda list | grep "dask\|cupy\|numpy\|scipy"
cupy                      8.5.0            py38ha0d87d6_1    conda-forge
dask                      2021.4.0+6.gbc785e1a.dirty           dev_0    <develop>
dask-cuda                 0.19.0a210326           py38_45    rapidsai-nightly
dask-cudf                 0.19.0a210326   py38_gad5452d7eb_288    rapidsai-nightly
dask-glm                  0.2.1.dev52+g1daf4c5.d20210407           dev_0    <develop>
dask-labextension         4.0.1              pyhd8ed1ab_0    conda-forge
dask-ml                   1.8.0              pyhd8ed1ab_0    conda-forge
numpy                     1.20.1           py38h18fd61f_0    conda-forge
numpydoc                  1.1.0                      py_1    conda-forge
scipy                     1.6.0            py38hb2138dd_0    conda-forge

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions