-
Notifications
You must be signed in to change notification settings - Fork 68
[Bug] Fusion rewrite fails (#99)
Copy link
Copy link
Closed
Description
Running this snippet:
import torch
import hidet
import onnx
class Foo(torch.nn.Module):
    """Minimal repro module: divide ``x`` by the minimum of ``y`` over dimension 0."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        """Return ``x`` divided by ``torch.min(y, dim=0)`` (values only).

        ``torch.min(..., dim=0)`` yields a ``(values, indices)`` pair; only the
        reduced values are used. The division broadcasts ``x`` against them.
        """
        reduced_min = torch.min(y, dim=0).values
        return x / reduced_min
# Repro script: run Foo eagerly, export it to ONNX, re-import it with hidet,
# optimize the traced graph and execute it through a CUDA graph.
# Requires a CUDA device.
device = 'cuda'
onnx_path = 'tmp.onnx'  # single source of truth for the export/load path (cwd-relative)

torch_model = Foo()
torch_model.to(device)
x = torch.rand([1, 1, 1, 1, 1], device=device)
y = torch.rand([2, 2], device=device)
z = torch_model(x, y)
# Eager reference: min over dim 0 of a [2, 2] tensor gives [2], which broadcasts
# against the 5-d x to [1, 1, 1, 1, 2].
print(z.shape)

torch.onnx.export(torch_model, (x, y), onnx_path,
                  input_names=['x', 'y'], output_names=['z'])
# Use a distinct name so the torch module is not shadowed by the ONNX ModelProto.
onnx_model = onnx.load(onnx_path)

# NOTE(review): dynamo_config presumably only configures the torch.compile
# (dynamo) backend, not this ONNX -> FlowGraph path; confirm whether a global
# hidet search-space option was intended instead.
hidet.torch.dynamo_config.search_space(1)

# Keep the torch tensors (x, y) intact and hold the hidet copies separately.
x_hidet = hidet.from_torch(x)
y_hidet = hidet.from_torch(y)
symbol_data = [hidet.symbol_like(x_hidet), hidet.symbol_like(y_hidet)]
hidet_onnx_module = hidet.graph.frontend.from_onnx(onnx_model)
symbol_output = hidet_onnx_module(*symbol_data)
graph: hidet.FlowGraph = hidet.trace_from(symbol_output, inputs=symbol_data)

# Per the issue report, the steps below end in a ValueError raised from
# hidet/transforms/tools/apply_prologue_epilogue.py (5 axes vs 1 store index).
with hidet.graph.PassContext():
    graph_opt: hidet.FlowGraph = hidet.graph.optimize(graph)
cuda_graph = graph_opt.cuda_graph()
outputs = cuda_graph.run([x_hidet, y_hidet])
raises the following error:
File "/home/su/accdiff/thirdparty/hidet/python/hidet/transforms/tools/apply_prologue_epilogue.py", line 172, in visit_BufferStoreStmt
remap: Dict[Var, Expr] = {a: b for a, b in strict_zip(tc.axes, out_indices)}
File "/home/su/accdiff/thirdparty/hidet/python/hidet/utils/py.py", line 55, in strict_zip
raise ValueError(
ValueError: Expect two sequence have the same length in zip, got length 5 and 1.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels