
[Bug] fusion rewrite fails #99

@soodoshll


Running this snippet:

```python
import torch
import hidet
import onnx


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # Reduce y (shape [2, 2]) over dim 0 -> shape [2]
        y = torch.min(y, dim=0)[0]
        # Broadcast x (shape [1, 1, 1, 1, 1]) against y (shape [2]) -> shape [1, 1, 1, 1, 2]
        z = x / y
        return z


device = 'cuda'

model = Foo()
model.to(device)

x = torch.rand([1, 1, 1, 1, 1], device=device)
y = torch.rand([2, 2], device=device)

# The model runs fine in PyTorch
z = model(x, y)
print(z.shape)  # torch.Size([1, 1, 1, 1, 2])

# Export to ONNX, then import the ONNX model into hidet
torch.onnx.export(model, (x, y), 'tmp.onnx', input_names=['x', 'y'],
                  output_names=['z'])
model = onnx.load('tmp.onnx')

hidet.torch.dynamo_config.search_space(1)

x = hidet.from_torch(x)
y = hidet.from_torch(y)
symbol_data = [hidet.symbol_like(x), hidet.symbol_like(y)]
hidet_onnx_module = hidet.graph.frontend.from_onnx(model)
symbol_output = hidet_onnx_module(*symbol_data)
graph: hidet.FlowGraph = hidet.trace_from(symbol_output, inputs=symbol_data)

# Run the graph-level optimization passes, then build and run a CUDA graph
with hidet.graph.PassContext() as ctx:
    graph_opt: hidet.FlowGraph = hidet.graph.optimize(graph)
cuda_graph = graph_opt.cuda_graph()
outputs = cuda_graph.run([x, y])
```

raises the following error (traceback excerpt):

```
  File "/home/su/accdiff/thirdparty/hidet/python/hidet/transforms/tools/apply_prologue_epilogue.py", line 172, in visit_BufferStoreStmt
    remap: Dict[Var, Expr] = {a: b for a, b in strict_zip(tc.axes, out_indices)}
  File "/home/su/accdiff/thirdparty/hidet/python/hidet/utils/py.py", line 55, in strict_zip
    raise ValueError(
ValueError: Expect two sequence have the same length in zip, got length 5 and 1.
```
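The lengths in the error (5 and 1) match the ranks in the model: `z` has rank 5 after broadcasting, while `torch.min(y, dim=0)[0]` has rank 1. One possible reading, which is an assumption and not confirmed against hidet's source, is that the fusion rewrite pairs the axes of one tensor with the indices of another one-to-one, which only works when the ranks agree. Under broadcasting, the index into a lower-rank operand has to be derived instead: drop the leading output axes the operand lacks and pin size-1 axes to 0. The sketch below illustrates that mapping in plain Python. `broadcast_shape` and `broadcast_indices` are hypothetical helpers, not hidet APIs, and this is not a proposed hidet fix.

```python
from itertools import zip_longest
from typing import List, Sequence


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> List[int]:
    """Hypothetical helper: NumPy/PyTorch-style broadcast of two shapes.

    Shapes are aligned from the right; a missing or size-1 dim takes the
    other operand's extent.
    """
    result = []
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da != db and 1 not in (da, db):
            raise ValueError(f"cannot broadcast dims {da} and {db}")
        result.append(db if da == 1 else da)
    return list(reversed(result))


def broadcast_indices(out_axes: Sequence[str], in_shape: Sequence[int]) -> List[str]:
    """Hypothetical helper: map output axes onto a (possibly lower-rank) input.

    Leading output axes the input does not have are dropped, and any axis
    where the input has extent 1 is pinned to index 0.
    """
    offset = len(out_axes) - len(in_shape)
    if offset < 0:
        raise ValueError("input rank exceeds output rank")
    return ["0" if extent == 1 else out_axes[offset + i]
            for i, extent in enumerate(in_shape)]


x_shape = [1, 1, 1, 1, 1]
y_min_shape = [2]  # shape of torch.min(y, dim=0)[0] for y of shape [2, 2]
z_shape = broadcast_shape(x_shape, y_min_shape)
axes = [f"i{k}" for k in range(len(z_shape))]  # symbolic axes of z

print(z_shape)                               # [1, 1, 1, 1, 2]
print(broadcast_indices(axes, y_min_shape))  # ['i4']
print(broadcast_indices(axes, x_shape))      # ['0', '0', '0', '0', '0']

# Pairing z's axes one-to-one with the reduced operand's indices would need
# equal lengths; here they are 5 and 1, the same numbers as in the traceback.
print(len(axes), len(broadcast_indices(axes, y_min_shape)))  # 5 1
```

The point of the sketch is that a rank-mismatched pair is legal under broadcasting, so any axis remapping applied across a broadcast has to go through a rank-aware mapping like `broadcast_indices` rather than a strict one-to-one zip.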
