Post-Scheduling Fusion with TensorCore #295

@SunflowerAries

Description

Hello, I just read your Hidet paper, and it looks pretty powerful, but I have a few questions.

Does Hidet currently support code generation for Tensor Cores? And can it generate code for fused operators like batch_matmul + add + reshape + transpose? This fused pattern comes from BERT, and I want to fuse these operators into one to speed up network execution.

Below is my test script. Although I've set the precision and mma flags, I do not find wmma instructions in the generated CUDA code. How can I run this operator on Tensor Cores? Hoping for your help, thanks.

import hidet

# change the cache directory
hidet.option.cache_dir('./outs/cache')

# save the tensor program level ir in operator cache
hidet.option.save_lower_ir()


def main():
    # construct a simple graph
    x = hidet.symbol([16, 256, 512], device='cuda')
    w = hidet.randn([16, 512, 512], device='cuda')
    b = hidet.randn([512], device='cuda')
    x = hidet.ops.batch_matmul(x, w)
    x = x + b
    x = hidet.ops.reshape(x, [16, 256, 8, 64])
    x = hidet.ops.transpose(x, [0, 2, 1, 3])
    
    # x = hidet.ops.pad(x, [3, 3, 3, 3])
    # x = hidet.ops.conv2d(x, w, stride=2)
    # x = hidet.ops.relu(x)
    
    graph = hidet.trace_from(x)
    print(graph)

    # graph optimizations
    with hidet.graph.PassContext() as ctx:
        # save the computation graph level ir
        ctx.save_graph_instrument(out_dir='./outs/graphs')
        ctx.set_precision(dtype='float16')
        ctx.set_reduce_precision(dtype='float32')
        ctx.set_mma('mma')
        graph_opt = hidet.graph.optimize(graph)

    # run the optimized graph
    xx = hidet.randn([16, 256, 512], device='cuda')
    yy = graph_opt(xx)


if __name__ == '__main__':
    main()

Metadata

Labels: bug (Something isn't working)