[Bug] Some hidet tensor methods do not support symbolic tensors? #213
Closed
Labels: bug (Something isn't working)
Description
Hi, thanks for the great work!
I am wondering why some hidet tensor methods (e.g., to, cuda, and cpu) do not support symbolic tensors.
    import torch
    import torch.nn as nn
    import hidet

    # Assumption: `device` was not defined in the original snippet; a CUDA device is implied.
    device = torch.device("cuda")

    class TestMode(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Linear(10, 10)

        def forward(self, x):
            # The .to(device) call here is what triggers the error under hidet.
            z = x.unsqueeze(0).expand(4, 4, 512).to(torch.device("cuda"))
            return z

    if __name__ == "__main__":
        model = TestMode()
        model = model.eval().half()
        model = model.to(device)
        hidet.torch.dynamo_config.search_space(2)
        hidet.torch.dynamo_config.use_fp16()
        model_opt = torch.compile(model, backend='hidet')
        tokens = torch.zeros(20, 10).cuda()
        model_opt(tokens)

In the above test case, the following exception is raised:

    NotImplementedError: hidet: Tensor.to(..., device=...) is not supported for symbolic tensors., occurred when calling tensor_to(Tensor(shape=(4, 4, 512), dtype='bool', device='cuda:0'), device(type='cuda'))
I think .to(device) is a common operation in deep learning models; for example, it appears in the Hugging Face LLaMA implementation.
Are there any concerns or limitations regarding these operations under symbolic tracing?
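For context, here is a minimal sketch of how eager PyTorch treats this kind of call when the tensor is already on the target device (in the error above, the tensor is on cuda:0 and the target is cuda). It uses the CPU as an assumption so that it runs without a GPU, and the variable names are only illustrative. It shows eager behaviour only and says nothing about how hidet handles the call internally.

```python
import torch

# Assumption: CPU stands in for the CUDA device used in the report,
# so the sketch runs on machines without a GPU.
target = torch.device("cpu")

x = torch.zeros(20, 10)  # already on the target device
y = x.to(target)         # same device and dtype requested

# PyTorch documents that Tensor.to returns self when the tensor
# already has the requested dtype and device, so no copy is made.
same_object = y is x
print(same_object)
```

In eager mode the call is therefore just a pass-through in this situation, which is part of why it tends to appear freely in model code.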
Look forward to your response. Thanks!