This repository was archived by the owner on Nov 17, 2023. It is now read-only.
net.optimize_for doesn't work with numpy semantics #19446
Closed
Description
The forward pass fails after calling optimize_for with a specific backend. I'm not sure what this error means, but I found a way to work around it (an ugly way :)).
The problem occurs on both the master and 1.x branches.
Error Message
Traceback (most recent call last):
  File "../d.py", line 23, in <module>
    print(net(a, b))
  File "/home/bgawrych/Desktop/mxnet/python/mxnet/gluon/block.py", line 1407, in __call__
    return super().__call__(x, *args)
  File "/home/bgawrych/Desktop/mxnet/python/mxnet/gluon/block.py", line 716, in __call__
    _check_all_np_ndarrays(out)
  File "/home/bgawrych/Desktop/mxnet/python/mxnet/gluon/utils.py", line 480, in _check_all_np_ndarrays
    raise TypeError("Block's output ndarrays/symbols must be of type `mxnet.numpy.ndarray`"
TypeError: Block's output ndarrays/symbols must be of type `mxnet.numpy.ndarray` or `mxnet.symbol.numpy._Symbol`, while got output type <class 'mxnet.ndarray.ndarray.NDArray'>
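To make the failure easier to reason about, the output check that raises in the traceback above can be modeled roughly as follows. This is a simplified sketch inferred from the error message only: the class and function names are hypothetical stand-ins, not MXNet's code, and it assumes (as in MXNet) that the NumPy-style ndarray is a subclass of the legacy NDArray.

```python
# Hypothetical stand-ins for MXNet's array types. The NumPy-style array is
# modeled as a subclass of the legacy one (an assumption mirroring MXNet).
class NDArray:  # stand-in for mxnet.ndarray.NDArray (legacy semantics)
    pass


class NpNDArray(NDArray):  # stand-in for mxnet.numpy.ndarray
    pass


def check_all_np_ndarrays(out):
    """Raise TypeError if any (possibly nested) output is a legacy NDArray."""
    # A legacy array that is not the NumPy-style subclass is rejected.
    if isinstance(out, NDArray) and not isinstance(out, NpNDArray):
        raise TypeError(
            "Block's output ndarrays must be numpy-style, "
            f"while got output type {type(out)}"
        )
    # Lists and tuples of outputs are checked element by element.
    if isinstance(out, (list, tuple)):
        for item in out:
            check_all_np_ndarrays(item)


# NumPy-style outputs pass; a single legacy output anywhere is rejected.
check_all_np_ndarrays([NpNDArray(), (NpNDArray(),)])
try:
    check_all_np_ndarrays([NpNDArray(), NDArray()])
except TypeError as err:
    print("rejected:", err)
```

Under this model, the error suggests the graph produced by optimize_for returns legacy NDArray outputs even though NumPy semantics are enabled, which is consistent with the workaround of reloading the symbol through the NumPy JSON loader.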
To Reproduce
import mxnet as mx
from mxnet.gluon import HybridBlock

mx.npx.set_np()  # enable NumPy semantics

class TestBlock(HybridBlock):
    def __init__(self):
        super(TestBlock, self).__init__()
        self.d = mx.gluon.nn.Dense(1)

    def hybrid_forward(self, F, a, b, *args):
        # Reuse Dense's hybrid_forward directly with a and b as its inputs
        res = self.d.hybrid_forward(F, a, b)
        return res

a = mx.np.random.uniform(low=-1, high=1, size=(1,1))
b = mx.np.random.uniform(low=-1, high=1, size=(1,1))
net = TestBlock()
net.initialize()
net.hybridize()
print(net(a, b))  # forward pass before optimize_for
net.optimize_for(a, b, backend="MKLDNN")
#print(net(a, b))  # <---- this line doesn't work now - we need to reload the symbol from JSON
# Workaround: round-trip the optimized graph through the NumPy symbol JSON loader
inputs, sym = net._cached_graph
sym = mx.sym.np._symbol.load_json(sym.tojson())
x = mx.gluon.SymbolBlock(sym, [mx.sym.var('data0'), mx.sym.var('data1')], net.collect_params())
print(x(a, b))
What have you tried to solve it?
- Adding ConvertShapeAttrToNumPyCompatible(&g); in MXOptimizeForBackend: doesn't help.
@samskalicky, maybe you will be able to help?