Skip to content

problems with numba ufunc + distributed #3450

@rabernat

Description

@rabernat

We have created a new software package called fastjmd95 that uses numba to accelerate computation of the ocean equation of state. Everything works fine with dask and a local scheduler. Now I want to run this code on a distributed dask cluster. It isn't working, I think because the workers are not able to deserialize the numba functions properly.

Original Full Example

This example with real data can be run on any pangeo cluster

Details
# Reproducer: evaluate fastjmd95's `rho` on fields loaded from the "SOSE"
# catalog entry, first with the default (local) scheduler and then again after
# a distributed Client has been created.
from intake import open_catalog
from fastjmd95 import rho

cat = open_catalog("https://raw.githubusercontent.com/pangeo-data/pangeo-datastore/master/intake-catalogs/ocean.yaml")
ds  = cat["SOSE"].to_dask()

rhonil = 1025
pa_to_dbar = 1.0/10000
# NOTE(review): `p` is computed here but never passed to `rho` below, which
# receives the literal 0 as its third argument.
p = ds.PHrefC * rhonil * pa_to_dbar
s = ds.SALT
t = ds.THETA
r = rho(s.data, t.data, 0)
# works fine with local scheduler
# (despite its name, `r_mean` is the slice r[:5], not a mean)
r_mean = r[:5].compute()

# now start distributed scheduler
from dask.distributed import Client
client = Client()
# Same expression as above; this is the call that raises the AttributeError
# shown in the traceback that follows.
r_mean = r[:5].compute()
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-4-7316322484d4> in <module>
----> 1 r_mean = r[:5].compute()

/srv/conda/envs/notebook/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
    163         dask.base.compute
    164         """
--> 165         (result,) = compute(self, traverse=False, **kwargs)
    166         return result
    167 

/srv/conda/envs/notebook/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    434     keys = [x.__dask_keys__() for x in collections]
    435     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 436     results = schedule(dsk, keys, **kwargs)
    437     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    438 

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in get(self, dsk, keys, restrictions, loose_restrictions, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
   2571                     should_rejoin = False
   2572             try:
-> 2573                 results = self.gather(packed, asynchronous=asynchronous, direct=direct)
   2574             finally:
   2575                 for f in futures.values():

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in gather(self, futures, errors, direct, asynchronous)
   1871                 direct=direct,
   1872                 local_worker=local_worker,
-> 1873                 asynchronous=asynchronous,
   1874             )
   1875 

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    766         else:
    767             return sync(
--> 768                 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    769             )
    770 

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
    332     if error[0]:
    333         typ, exc, tb = error[0]
--> 334         raise exc.with_traceback(tb)
    335     else:
    336         return result[0]

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/utils.py in f()
    316             if callback_timeout is not None:
    317                 future = gen.with_timeout(timedelta(seconds=callback_timeout), future)
--> 318             result[0] = yield future
    319         except Exception as exc:
    320             error[0] = sys.exc_info()

/srv/conda/envs/notebook/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
   1727                             exc = CancelledError(key)
   1728                         else:
-> 1729                             raise exception.with_traceback(traceback)
   1730                         raise exc
   1731                     if errors == "skip":

/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/pickle.py in loads()
     57 def loads(x):
     58     try:
---> 59         return pickle.loads(x)
     60     except Exception:
     61         logger.info("Failed to deserialize %s", x[:10000], exc_info=True)

/srv/conda/envs/notebook/lib/python3.7/site-packages/numpy/core/__init__.py in _ufunc_reconstruct()
    123     # scipy.special.expit for instance.
    124     mod = __import__(module, fromlist=[name])
--> 125     return getattr(mod, name)
    126 
    127 def _ufunc_reduce(func):

AttributeError: module '__main__' has no attribute 'rho'

Minimal Example

I believe this reproduces the core problem

# Minimal reproducer: a numba-vectorized ufunc defined in the interactive
# session is applied to a dask array while a distributed Client is active.
import numpy as np
from numba import vectorize, float64, float32
import dask.array as dsa
from dask.distributed import Client
client = Client()

# define a numba ufunc
# (two explicit signatures are registered: float64 -> float64 and float32 -> float32)
@vectorize([float64(float64), float32(float32)], nopython=True)
def test_numba(a):
    return a**2

# verify that the workers can run it: client.run executes the function on the
# workers, and the output below is keyed by worker address
def try_numba_on_client():
    data = np.arange(5, dtype='f4')
    return test_numba(data)
client.run(try_numba_on_client)
# works, output is:
# > {'tcp://127.0.0.1:37583': array([ 0.,  1.,  4.,  9., 16.]),
# > 'tcp://127.0.0.1:44855': array([ 0.,  1.,  4.,  9., 16.])}

# use in a computation
# (this is the step that fails: the worker log reports that module '__main__'
# has no attribute 'test_numba')
data_dask = dsa.arange(5, dtype='f4')
test_numba(data_dask).compute()

At this point I get a KilledWorker error. In the worker log, I can see the following error (sorry for the lack of formatting--that's how it comes out of the worker error logs)

distributed.worker - ERROR - module '__main__' has no attribute 'test_numba'
Traceback (most recent call last): File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/worker.py", line 905, in handle_scheduler comm, every_cycle=[self.ensure_communicating, self.ensure_computing] File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/core.py", line 456, in handle_stream msgs = await comm.read() File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/comm/tcp.py", line 222, in read frames, deserialize=self.deserialize, deserializers=deserializers File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/comm/utils.py", line 69, in from_frames res = _from_frames() File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/comm/utils.py", line 55, in _from_frames frames, deserialize=deserialize, deserializers=deserializers File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/core.py", line 124, in loads value = _deserialize(head, fs, deserializers=deserializers) File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/serialize.py", line 255, in deserialize deserializers=deserializers, File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/serialize.py", line 268, in deserialize return loads(header, frames) File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/serialize.py", line 62, in pickle_loads return pickle.loads(b"".join(frames)) File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/pickle.py", line 59, in loads return pickle.loads(x) File "/srv/conda/envs/notebook/lib/python3.7/site-packages/numpy/core/__init__.py", line 125, in _ufunc_reconstruct return getattr(mod, name)
AttributeError: module '__main__' has no attribute 'test_numba'

The basic error appears to be the same as in the full example.

This seems like a pretty straightforward use of numba + distributed, and I assumed this sort of usage was supported. Am I missing something obvious?

Installed versions

I'm on dask 2.9.0 and numba 0.48.0.

Details
>>> client.get_versions(check=True)
{'scheduler': {'host': (('python', '3.7.6.final.0'),
   ('python-bits', 64),
   ('OS', 'Linux'),
   ('OS-release', '4.19.76+'),
   ('machine', 'x86_64'),
   ('processor', 'x86_64'),
   ('byteorder', 'little'),
   ('LC_ALL', 'en_US.UTF-8'),
   ('LANG', 'en_US.UTF-8'),
   ('LOCALE', 'en_US.UTF-8')),
  'packages': {'required': (('dask', '2.9.0'),
    ('distributed', '2.9.0'),
    ('msgpack', '0.6.2'),
    ('cloudpickle', '1.2.2'),
    ('tornado', '6.0.3'),
    ('toolz', '0.10.0')),
   'optional': (('numpy', '1.17.3'),
    ('pandas', '0.25.3'),
    ('bokeh', '1.4.0'),
    ('lz4', '2.2.1'),
    ('dask_ml', '1.1.1'),
    ('blosc', '1.8.1'))}},
 'workers': {'tcp://10.32.181.10:45663': {'host': (('python', '3.7.6.final.0'),
    ('python-bits', 64),
    ('OS', 'Linux'),
    ('OS-release', '4.19.76+'),
    ('machine', 'x86_64'),
    ('processor', 'x86_64'),
    ('byteorder', 'little'),
    ('LC_ALL', 'en_US.UTF-8'),
    ('LANG', 'en_US.UTF-8'),
    ('LOCALE', 'en_US.UTF-8')),
   'packages': {'required': (('dask', '2.9.0'),
     ('distributed', '2.9.0'),
     ('msgpack', '0.6.2'),
     ('cloudpickle', '1.2.2'),
     ('tornado', '6.0.3'),
     ('toolz', '0.10.0')),
    'optional': (('numpy', '1.17.3'),
     ('pandas', '0.25.3'),
     ('bokeh', '1.4.0'),
     ('lz4', '2.2.1'),
     ('dask_ml', '1.1.1'),
     ('blosc', '1.8.1'))}},
  'tcp://10.32.181.11:37259': {'host': (('python', '3.7.6.final.0'),
    ('python-bits', 64),
    ('OS', 'Linux'),
    ('OS-release', '4.19.76+'),
    ('machine', 'x86_64'),
    ('processor', 'x86_64'),
    ('byteorder', 'little'),
    ('LC_ALL', 'en_US.UTF-8'),
    ('LANG', 'en_US.UTF-8'),
    ('LOCALE', 'en_US.UTF-8')),
   'packages': {'required': (('dask', '2.9.0'),
     ('distributed', '2.9.0'),
     ('msgpack', '0.6.2'),
     ('cloudpickle', '1.2.2'),
     ('tornado', '6.0.3'),
     ('toolz', '0.10.0')),
    'optional': (('numpy', '1.17.3'),
     ('pandas', '0.25.3'),
     ('bokeh', '1.4.0'),
     ('lz4', '2.2.1'),
     ('dask_ml', '1.1.1'),
     ('blosc', '1.8.1'))}}},
 'client': {'host': [('python', '3.7.6.final.0'),
   ('python-bits', 64),
   ('OS', 'Linux'),
   ('OS-release', '4.19.76+'),
   ('machine', 'x86_64'),
   ('processor', 'x86_64'),
   ('byteorder', 'little'),
   ('LC_ALL', 'en_US.UTF-8'),
   ('LANG', 'en_US.UTF-8'),
   ('LOCALE', 'en_US.UTF-8')],
  'packages': {'required': [('dask', '2.9.0'),
    ('distributed', '2.9.0'),
    ('msgpack', '0.6.2'),
    ('cloudpickle', '1.2.2'),
    ('tornado', '6.0.3'),
    ('toolz', '0.10.0')],
   'optional': [('numpy', '1.17.3'),
    ('pandas', '0.25.3'),
    ('bokeh', '1.4.0'),
    ('lz4', '2.2.1'),
    ('dask_ml', '1.1.1'),
    ('blosc', '1.8.1')]}}}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions