-
-
Notifications
You must be signed in to change notification settings - Fork 753
Description
We have created a new software package called fastjmd95 that uses numba to accelerate computation of the ocean equation of state. Everything works fine with dask and a local scheduler. Now I want to run this code on a distributed dask cluster. It isn't working, I think because the workers are not able to deserialize the numba functions properly.
Original Full Example
This example with real data can be run on any pangeo cluster
Details
from intake import open_catalog
from fastjmd95 import rho
cat = open_catalog("https://raw.githubusercontent.com/pangeo-data/pangeo-datastore/master/intake-catalogs/ocean.yaml")
ds = cat["SOSE"].to_dask()
rhonil = 1025
pa_to_dbar = 1.0/10000
p = ds.PHrefC * rhonil * pa_to_dbar
s = ds.SALT
t = ds.THETA
r = rho(s.data, t.data, 0)
# works fine with local scheduler
r_mean = r[:5].compute()
# now start distributed scheduler
from dask.distributed import Client
client = Client()
r_mean = r[:5].compute()
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-4-7316322484d4> in <module>
----> 1 r_mean = r[:5].compute()
/srv/conda/envs/notebook/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
163 dask.base.compute
164 """
--> 165 (result,) = compute(self, traverse=False, **kwargs)
166 return result
167
/srv/conda/envs/notebook/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
434 keys = [x.__dask_keys__() for x in collections]
435 postcomputes = [x.__dask_postcompute__() for x in collections]
--> 436 results = schedule(dsk, keys, **kwargs)
437 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
438
/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in get(self, dsk, keys, restrictions, loose_restrictions, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
2571 should_rejoin = False
2572 try:
-> 2573 results = self.gather(packed, asynchronous=asynchronous, direct=direct)
2574 finally:
2575 for f in futures.values():
/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in gather(self, futures, errors, direct, asynchronous)
1871 direct=direct,
1872 local_worker=local_worker,
-> 1873 asynchronous=asynchronous,
1874 )
1875
/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
766 else:
767 return sync(
--> 768 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
769 )
770
/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
332 if error[0]:
333 typ, exc, tb = error[0]
--> 334 raise exc.with_traceback(tb)
335 else:
336 return result[0]
/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/utils.py in f()
316 if callback_timeout is not None:
317 future = gen.with_timeout(timedelta(seconds=callback_timeout), future)
--> 318 result[0] = yield future
319 except Exception as exc:
320 error[0] = sys.exc_info()
/srv/conda/envs/notebook/lib/python3.7/site-packages/tornado/gen.py in run(self)
733
734 try:
--> 735 value = future.result()
736 except Exception:
737 exc_info = sys.exc_info()
/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
1727 exc = CancelledError(key)
1728 else:
-> 1729 raise exception.with_traceback(traceback)
1730 raise exc
1731 if errors == "skip":
/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/pickle.py in loads()
57 def loads(x):
58 try:
---> 59 return pickle.loads(x)
60 except Exception:
61 logger.info("Failed to deserialize %s", x[:10000], exc_info=True)
/srv/conda/envs/notebook/lib/python3.7/site-packages/numpy/core/__init__.py in _ufunc_reconstruct()
123 # scipy.special.expit for instance.
124 mod = __import__(module, fromlist=[name])
--> 125 return getattr(mod, name)
126
127 def _ufunc_reduce(func):
AttributeError: module '__main__' has no attribute 'rho'
Minimal Example
I believe this reproduces the core problem
import numpy as np
from numba import vectorize, float64, float32
import dask.array as dsa
from dask.distributed import Client
client = Client()
# define a numba ufunc
@vectorize([float64(float64), float32(float32)], nopython=True)
def test_numba(a):
return a**2
# verify that the client can run it
def try_numba_on_client():
data = np.arange(5, dtype='f4')
return test_numba(data)
client.run(try_numba_on_client)
# works, output is:
# > {'tcp://127.0.0.1:37583': array([ 0., 1., 4., 9., 16.]),
# > 'tcp://127.0.0.1:44855': array([ 0., 1., 4., 9., 16.])}
# use in a computation
data_dask = dsa.arange(5, dtype='f4')
test_numba(data_dask).compute()
At this point I get a KilledWorker error. In the worker log, I can see the following error (sorry for the lack of formatting — that's how it comes out of the worker error logs):
distributed.worker - ERROR - module '__main__' has no attribute 'test_numba'
Traceback (most recent call last): File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/worker.py", line 905, in handle_scheduler comm, every_cycle=[self.ensure_communicating, self.ensure_computing] File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/core.py", line 456, in handle_stream msgs = await comm.read() File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/comm/tcp.py", line 222, in read frames, deserialize=self.deserialize, deserializers=deserializers File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/comm/utils.py", line 69, in from_frames res = _from_frames() File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/comm/utils.py", line 55, in _from_frames frames, deserialize=deserialize, deserializers=deserializers File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/core.py", line 124, in loads value = _deserialize(head, fs, deserializers=deserializers) File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/serialize.py", line 255, in deserialize deserializers=deserializers, File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/serialize.py", line 268, in deserialize return loads(header, frames) File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/serialize.py", line 62, in pickle_loads return pickle.loads(b"".join(frames)) File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/pickle.py", line 59, in loads return pickle.loads(x) File "/srv/conda/envs/notebook/lib/python3.7/site-packages/numpy/core/__init__.py", line 125, in _ufunc_reconstruct return getattr(mod, name)
AttributeError: module '__main__' has no attribute 'test_numba'
The basic error appears to be the same as in the full example.
This seems like a pretty straightforward use of numba + distributed, and I assumed this sort of usage was supported. Am I missing something obvious?
Installed versions
I'm on dask 2.9.0 and numba 0.48.0.
Details
>>> client.get_versions(check=True)
{'scheduler': {'host': (('python', '3.7.6.final.0'),
('python-bits', 64),
('OS', 'Linux'),
('OS-release', '4.19.76+'),
('machine', 'x86_64'),
('processor', 'x86_64'),
('byteorder', 'little'),
('LC_ALL', 'en_US.UTF-8'),
('LANG', 'en_US.UTF-8'),
('LOCALE', 'en_US.UTF-8')),
'packages': {'required': (('dask', '2.9.0'),
('distributed', '2.9.0'),
('msgpack', '0.6.2'),
('cloudpickle', '1.2.2'),
('tornado', '6.0.3'),
('toolz', '0.10.0')),
'optional': (('numpy', '1.17.3'),
('pandas', '0.25.3'),
('bokeh', '1.4.0'),
('lz4', '2.2.1'),
('dask_ml', '1.1.1'),
('blosc', '1.8.1'))}},
'workers': {'tcp://10.32.181.10:45663': {'host': (('python', '3.7.6.final.0'),
('python-bits', 64),
('OS', 'Linux'),
('OS-release', '4.19.76+'),
('machine', 'x86_64'),
('processor', 'x86_64'),
('byteorder', 'little'),
('LC_ALL', 'en_US.UTF-8'),
('LANG', 'en_US.UTF-8'),
('LOCALE', 'en_US.UTF-8')),
'packages': {'required': (('dask', '2.9.0'),
('distributed', '2.9.0'),
('msgpack', '0.6.2'),
('cloudpickle', '1.2.2'),
('tornado', '6.0.3'),
('toolz', '0.10.0')),
'optional': (('numpy', '1.17.3'),
('pandas', '0.25.3'),
('bokeh', '1.4.0'),
('lz4', '2.2.1'),
('dask_ml', '1.1.1'),
('blosc', '1.8.1'))}},
'tcp://10.32.181.11:37259': {'host': (('python', '3.7.6.final.0'),
('python-bits', 64),
('OS', 'Linux'),
('OS-release', '4.19.76+'),
('machine', 'x86_64'),
('processor', 'x86_64'),
('byteorder', 'little'),
('LC_ALL', 'en_US.UTF-8'),
('LANG', 'en_US.UTF-8'),
('LOCALE', 'en_US.UTF-8')),
'packages': {'required': (('dask', '2.9.0'),
('distributed', '2.9.0'),
('msgpack', '0.6.2'),
('cloudpickle', '1.2.2'),
('tornado', '6.0.3'),
('toolz', '0.10.0')),
'optional': (('numpy', '1.17.3'),
('pandas', '0.25.3'),
('bokeh', '1.4.0'),
('lz4', '2.2.1'),
('dask_ml', '1.1.1'),
('blosc', '1.8.1'))}}},
'client': {'host': [('python', '3.7.6.final.0'),
('python-bits', 64),
('OS', 'Linux'),
('OS-release', '4.19.76+'),
('machine', 'x86_64'),
('processor', 'x86_64'),
('byteorder', 'little'),
('LC_ALL', 'en_US.UTF-8'),
('LANG', 'en_US.UTF-8'),
('LOCALE', 'en_US.UTF-8')],
'packages': {'required': [('dask', '2.9.0'),
('distributed', '2.9.0'),
('msgpack', '0.6.2'),
('cloudpickle', '1.2.2'),
('tornado', '6.0.3'),
('toolz', '0.10.0')],
'optional': [('numpy', '1.17.3'),
('pandas', '0.25.3'),
('bokeh', '1.4.0'),
('lz4', '2.2.1'),
('dask_ml', '1.1.1'),
('blosc', '1.8.1')]}}}