Skip to content

[Fixbug] Fix a bug when fused operator has no input#263

Merged
yaoyaoding merged 1 commit intohidet-org:mainfrom
yaoyaoding:allen
Jun 1, 2023
Merged

[Fixbug] Fix a bug when fused operator has no input#263
yaoyaoding merged 1 commit intohidet-org:mainfrom
yaoyaoding:allen

Conversation

@yaoyaoding
Copy link
Copy Markdown
Member

Now, we can pass the following test script

import hidet
# add w1, and w2 to get around the error when there are no weights in the fused graph object

w1 = hidet.randn([1], device='cuda', dtype='float16')
w2 = hidet.randn_like(w1)
def hidet_make_causal_mask(seq_len, dtype, device, past_key_values_length):
    # pylint: disable=protected-access
    x = hidet.ops.tri(
        n=seq_len, m=seq_len + past_key_values_length, k=past_key_values_length, dtype=dtype, device=device
    )
    t = (1 - x) * float(dtype._min_value)
    return t + w1, t + w2

s = hidet.symbol(["seqlen", "pastseqlen"], device='cuda')

y = hidet_make_causal_mask(s.shape[0], s.dtype, s.device, s.shape[1])
graph = hidet.trace_from(list(y), s)

from hidet.graph.transforms.subgraph_rewrite import subgraph_rewrite_pass
from hidet.graph.transforms.automatic_mix_precision import automatic_mix_precision_pass
from hidet.graph.transforms.resolve_variant import resolve_variant_pass
from hidet.graph.transforms.fuse_operator import fuse_operator_pass
graph = subgraph_rewrite_pass()(graph)
graph = automatic_mix_precision_pass()(graph)
graph = resolve_variant_pass()(graph)
graph = fuse_operator_pass()(graph)

@yaoyaoding yaoyaoding merged commit a1706b2 into hidet-org:main Jun 1, 2023
@yaoyaoding yaoyaoding deleted the allen branch June 5, 2023 01:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant