Skip to content

Linalg lstsq fails on cupy backed arrays #7537

@beckernick

Description

@beckernick

dask.array.linalg.lstsq fails on CuPy-backed arrays because its call to solve_triangular explicitly uses the SciPy routine. Now that CuPy provides solve_triangular in the cupyx.scipy.linalg namespace, resolving #7536 would likely be one of the last pieces needed for lstsq to work on CuPy (linalg.qr already works).

dask/dask/array/linalg.py

Lines 1404 to 1405 in c5633c2

# Excerpt from dask.array.linalg.lstsq: QR-factorize `a`, then solve the
# triangular system r @ x = q^H @ b (q.T.conj() is the conjugate transpose).
# NOTE: `solve_triangular` here resolves to scipy.linalg.solve_triangular
# (see the scipy/linalg/basic.py frame in the traceback), which calls
# np.asarray on its inputs and therefore raises TypeError on CuPy chunks.
q, r = qr(a)
x = solve_triangular(r, q.T.conj().dot(b))

# Minimal reproducer: da.linalg.lstsq on CuPy-backed dask arrays raises
# "TypeError: Implicit conversion to a NumPy array is not allowed" at
# compute time, because the solve step goes through scipy.linalg.
import cupy as cp
import dask.array as da

a = cp.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])
b = cp.array([4, 2, 4, 2])

# Wrap the device arrays so each dask chunk is a cupy.ndarray.
a, b = da.from_array(a), da.from_array(b)

da.compute(da.linalg.lstsq(a, b))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-74-2d6bd9d8e7c5> in <module>
      7 a, b = da.from_array(a), da.from_array(b)
      8 
----> 9 da.compute(da.linalg.lstsq(a, b))

/raid/nicholasb/dev/dask/dask/base.py in compute(*args, **kwargs)
    564         postcomputes.append(x.__dask_postcompute__())
    565 
--> 566     results = schedule(dsk, keys, **kwargs)
    567     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    568 

/raid/nicholasb/dev/dask/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
     77             pool = MultiprocessingPoolExecutor(pool)
     78 
---> 79     results = get_async(
     80         pool.submit,
     81         pool._max_workers,

/raid/nicholasb/dev/dask/dask/local.py in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
    512                             _execute_task(task, data)  # Re-execute locally
    513                         else:
--> 514                             raise_exception(exc, tb)
    515                     res, worker_id = loads(res_info)
    516                     state["cache"][key] = res

/raid/nicholasb/dev/dask/dask/local.py in reraise(exc, tb)
    323     if exc.__traceback__ is not tb:
    324         raise exc.with_traceback(tb)
--> 325     raise exc
    326 
    327 

/raid/nicholasb/dev/dask/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    221     try:
    222         task, data = loads(task_info)
--> 223         result = _execute_task(task, data)
    224         id = get_id()
    225         result = dumps((result, id))

/raid/nicholasb/dev/dask/dask/core.py in _execute_task(arg, cache, dsk)
    119         # temporaries by their reference count and can execute certain
    120         # operations in-place.
--> 121         return func(*(_execute_task(a, cache) for a in args))
    122     elif not ishashable(arg):
    123         return arg

/raid/nicholasb/miniconda3/envs/cuml-dev/lib/python3.8/site-packages/scipy/linalg/basic.py in solve_triangular(a, b, trans, lower, unit_diagonal, overwrite_b, debug, check_finite)
    331              'versions of SciPy.', DeprecationWarning, stacklevel=2)
    332 
--> 333     a1 = _asarray_validated(a, check_finite=check_finite)
    334     b1 = _asarray_validated(b, check_finite=check_finite)
    335     if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:

/raid/nicholasb/miniconda3/envs/cuml-dev/lib/python3.8/site-packages/scipy/_lib/_util.py in _asarray_validated(a, check_finite, sparse_ok, objects_ok, mask_ok, as_inexact)
    260             raise ValueError('masked arrays are not supported')
    261     toarray = np.asarray_chkfinite if check_finite else np.asarray
--> 262     a = toarray(a)
    263     if not objects_ok:
    264         if a.dtype is np.dtype('O'):

/raid/nicholasb/miniconda3/envs/cuml-dev/lib/python3.8/site-packages/numpy/lib/function_base.py in asarray_chkfinite(a, dtype, order)
    484 
    485     """
--> 486     a = asarray(a, dtype=dtype, order=order)
    487     if a.dtype.char in typecodes['AllFloat'] and not np.isfinite(a).all():
    488         raise ValueError(

/raid/nicholasb/miniconda3/envs/cuml-dev/lib/python3.8/site-packages/numpy/core/_asarray.py in asarray(a, dtype, order, like)
    100         return _asarray_with_like(a, dtype=dtype, order=order, like=like)
    101 
--> 102     return array(a, dtype, copy=False, order=order)
    103 
    104 

cupy/core/core.pyx in cupy.core.core.ndarray.__array__()

TypeError: Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions