-
-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Closed
Description
`dask.array.linalg.solve_triangular` fails on CuPy-backed arrays because it explicitly calls the SciPy routine (`scipy.linalg.solve_triangular`), which tries to convert each chunk to a NumPy array. Now that CuPy supports `solve_triangular` in the `cupyx.scipy.linalg` namespace, I suspect we could dispatch on the chunk type and provide CuPy support in `dask.array`.
Lines 1178 to 1191 in c5633c2
                dsk[_key(i, j)] = (_solve_triangular_lower, (a.name, i, i), target)
    else:
        for i in range(vchunks):
            for j in range(hchunks):
                target = _b_init(i, j)
                if i < vchunks - 1:
                    prevs = []
                    for k in range(i + 1, vchunks):
                        prev = name_mdot, i, k, k, j
                        dsk[prev] = (np.dot, (a.name, i, k), _key(k, j))
                        prevs.append(prev)
                    target = (operator.sub, target, (sum, prevs))
                dsk[_key(i, j)] = (
                    scipy.linalg.solve_triangular,
Lines 966 to 969 in c5633c2
def _solve_triangular_lower(a, b):
    import scipy.linalg

    return scipy.linalg.solve_triangular(a, b, lower=True)
import numpy as np
import dask.array as da
a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])
b = np.array([4, 2, 4, 2])
a, b = da.from_array(a), da.from_array(b)
da.linalg.solve_triangular(a, b).compute()
array([1.33333333, 2. , 4. , 2. ])

import cupy as cp
import dask.array as da
a = cp.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])
b = cp.array([4, 2, 4, 2])
a, b = da.from_array(a), da.from_array(b)
da.linalg.solve_triangular(a, b).compute()
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-68-9a877ba2c41f> in <module>
7 a, b = da.from_array(a), da.from_array(b)
8
----> 9 da.linalg.solve_triangular(a, b).compute()
/raid/nicholasb/dev/dask/dask/base.py in compute(self, **kwargs)
282 dask.base.compute
283 """
--> 284 (result,) = compute(self, traverse=False, **kwargs)
285 return result
286
/raid/nicholasb/dev/dask/dask/base.py in compute(*args, **kwargs)
564 postcomputes.append(x.__dask_postcompute__())
565
--> 566 results = schedule(dsk, keys, **kwargs)
567 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
568
/raid/nicholasb/dev/dask/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
77 pool = MultiprocessingPoolExecutor(pool)
78
---> 79 results = get_async(
80 pool.submit,
81 pool._max_workers,
/raid/nicholasb/dev/dask/dask/local.py in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
512 _execute_task(task, data) # Re-execute locally
513 else:
--> 514 raise_exception(exc, tb)
515 res, worker_id = loads(res_info)
516 state["cache"][key] = res
/raid/nicholasb/dev/dask/dask/local.py in reraise(exc, tb)
323 if exc.__traceback__ is not tb:
324 raise exc.with_traceback(tb)
--> 325 raise exc
326
327
/raid/nicholasb/dev/dask/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
221 try:
222 task, data = loads(task_info)
--> 223 result = _execute_task(task, data)
224 id = get_id()
225 result = dumps((result, id))
/raid/nicholasb/dev/dask/dask/core.py in _execute_task(arg, cache, dsk)
119 # temporaries by their reference count and can execute certain
120 # operations in-place.
--> 121 return func(*(_execute_task(a, cache) for a in args))
122 elif not ishashable(arg):
123 return arg
/raid/nicholasb/miniconda3/envs/cuml-dev/lib/python3.8/site-packages/scipy/linalg/basic.py in solve_triangular(a, b, trans, lower, unit_diagonal, overwrite_b, debug, check_finite)
331 'versions of SciPy.', DeprecationWarning, stacklevel=2)
332
--> 333 a1 = _asarray_validated(a, check_finite=check_finite)
334 b1 = _asarray_validated(b, check_finite=check_finite)
335 if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
/raid/nicholasb/miniconda3/envs/cuml-dev/lib/python3.8/site-packages/scipy/_lib/_util.py in _asarray_validated(a, check_finite, sparse_ok, objects_ok, mask_ok, as_inexact)
260 raise ValueError('masked arrays are not supported')
261 toarray = np.asarray_chkfinite if check_finite else np.asarray
--> 262 a = toarray(a)
263 if not objects_ok:
264 if a.dtype is np.dtype('O'):
/raid/nicholasb/miniconda3/envs/cuml-dev/lib/python3.8/site-packages/numpy/lib/function_base.py in asarray_chkfinite(a, dtype, order)
484
485 """
--> 486 a = asarray(a, dtype=dtype, order=order)
487 if a.dtype.char in typecodes['AllFloat'] and not np.isfinite(a).all():
488 raise ValueError(
/raid/nicholasb/miniconda3/envs/cuml-dev/lib/python3.8/site-packages/numpy/core/_asarray.py in asarray(a, dtype, order, like)
100 return _asarray_with_like(a, dtype=dtype, order=order, like=like)
101
--> 102 return array(a, dtype, copy=False, order=order)
103
104
cupy/core/core.pyx in cupy.core.core.ndarray.__array__()
TypeError: Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.

conda list | grep "dask\|cupy\|numpy\|scipy"
cupy 8.5.0 py38ha0d87d6_1 conda-forge
dask 2021.4.0+6.gbc785e1a.dirty dev_0 <develop>
dask-cuda 0.19.0a210326 py38_45 rapidsai-nightly
dask-cudf 0.19.0a210326 py38_gad5452d7eb_288 rapidsai-nightly
dask-glm 0.2.1.dev52+g1daf4c5.d20210407 dev_0 <develop>
dask-labextension 4.0.1 pyhd8ed1ab_0 conda-forge
dask-ml 1.8.0 pyhd8ed1ab_0 conda-forge
numpy 1.20.1 py38h18fd61f_0 conda-forge
numpydoc 1.1.0 py_1 conda-forge
scipy 1.6.0 py38hb2138dd_0 conda-forge
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels