[Bug] Use int64 in argmax #103
Closed
Description
Currently, hidet uses int32 as the index return type of `ArgReduceTask`:

```python
extent=x_shape[dim], fcompute=reduce_fcompute, reduce_type=reduce_type, index_dtype='int32'
```

This is misaligned with torch and onnx, which both return int64, and leads to incompatibility with other operators. For example, concatenation requires all inputs to have the same dtype, so concatenating the output of argmax with an int64 tensor is legal in torch but illegal in hidet.
Here is a simple snippet:

```python
import torch
import hidet
import onnx

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        y = torch.argmax(y, dim=0)
        print(y.dtype)  # int64
        return torch.concat([x, y])

device = 'cuda'
model = Foo()
model.to(device)
x = torch.ones([5], dtype=torch.int64, device=device)
y = torch.rand([5, 5], device=device)
z = model(x, y)
print(z.shape)
torch.onnx.export(model, (x, y), 'tmp.onnx', input_names=['x', 'y'],
                  output_names=['z'])

model = onnx.load('tmp.onnx')
hidet.torch.dynamo_config.search_space(1)
x = hidet.from_torch(x)
y = hidet.from_torch(y)
symbol_data = [hidet.symbol_like(x), hidet.symbol_like(y)]
hidet_onnx_module = hidet.graph.frontend.from_onnx(model)
symbol_output = hidet_onnx_module(*symbol_data)
graph: hidet.FlowGraph = hidet.trace_from(symbol_output, inputs=symbol_data)
with hidet.graph.PassContext() as ctx:
    graph_opt: hidet.FlowGraph = hidet.graph.optimize(graph)
cuda_graph = graph_opt.cuda_graph()
outputs = cuda_graph.run([x, y])
```
which raises an error:

```
ValueError: concat: expect all tensors have the same dtype, but got:
Tensor(shape=(5,), dtype='int64', device='cuda:0')
Tensor(shape=(5,), dtype='int32', device='cuda:0')
```
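For reference, the int64 convention can be verified with torch alone, independent of hidet (a minimal check on CPU, so it runs without CUDA):

```python
import torch

# torch.argmax always returns int64 indices, matching the ONNX ArgMax
# operator, whose output is specified as tensor(int64).
y = torch.rand(5, 5)
idx = torch.argmax(y, dim=0)
print(idx.dtype)  # torch.int64

# Concatenating the indices with another int64 tensor is therefore
# legal in torch -- the operation hidet currently rejects.
x = torch.ones([5], dtype=torch.int64)
z = torch.concat([x, idx])
print(z.dtype, z.shape)  # torch.int64 torch.Size([10])
```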