
[Bug] Error out for "hidet.ops.matmul" #326

@liangdzou

Description


Describe the bug
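nvcc fails to compile the generated source.cu unless -std=c++11 is passed. The host C++ headers in the log come from /usr/include/c++/4.8.2 (a GCC 4.8 toolchain, which defaults to C++98), while hidet/runtime/symbols.h includes <cstdint>, a header that requires C++11. Here is the error log: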

Tensor(shape=(2, 3), dtype='float32', device='cuda:0')
[[ 0.6858963  -0.3749267  -0.00427682]
 [ 0.19206195  2.0178227   0.16703984]]
Compiling cuda task matmul(a=float32(2, 3), b=float32(3, 2), c=float32(2, 2))...
Traceback (most recent call last):
  File "matmul.py", line 7, in <module>
    d = hidet.ops.matmul(a, b)
  File "/home/users/liang.zou/software/virtualenv/py38/lib/python3.8/site-packages/hidet/graph/ops/matmul/matmul.py", line 31, in matmul
    return MatmulOp(a, b, require_prologue=require_prologue).get_output(0)
  File "/home/users/liang.zou/software/virtualenv/py38/lib/python3.8/site-packages/hidet/graph/ops/matmul/matmul.py", line 27, in __init__
    super().__init__(inputs=[a, b], attributes={'require_prologue': require_prologue}, task=task)
  File "/home/users/liang.zou/software/virtualenv/py38/lib/python3.8/site-packages/hidet/graph/operator.py", line 45, in __init__
    self.outputs = self._run()
  File "/home/users/liang.zou/software/virtualenv/py38/lib/python3.8/site-packages/hidet/graph/operator.py", line 100, in _run
    return self.imperative_run(self.inputs)
  File "/home/users/liang.zou/software/virtualenv/py38/lib/python3.8/site-packages/hidet/graph/operator.py", line 156, in imperative_run
    outputs = self.compiled_task.run_async(inputs)
  File "/home/users/liang.zou/software/virtualenv/py38/lib/python3.8/site-packages/hidet/graph/operator.py", line 93, in compiled_task
    self._compiled_task = self.task.build(target=self.build_target)
  File "/home/users/liang.zou/software/virtualenv/py38/lib/python3.8/site-packages/hidet/ir/task.py", line 250, in build
    return build_task(self, target=target, load=load)
  File "/home/users/liang.zou/software/virtualenv/py38/lib/python3.8/site-packages/hidet/drivers/build_task.py", line 244, in build_task
    build_task_module(task, candidates, task_dir, target)
  File "/home/users/liang.zou/software/virtualenv/py38/lib/python3.8/site-packages/hidet/drivers/build_task.py", line 133, in build_task_module
    build_ir_module(ir_module=task_ir_module, output_dir=task_dir, output_kind='.so', target=target)
  File "/home/users/liang.zou/software/virtualenv/py38/lib/python3.8/site-packages/hidet/drivers/build_module.py", line 62, in build_ir_module
    compile_source(
  File "/home/users/liang.zou/software/virtualenv/py38/lib/python3.8/site-packages/hidet/backend/build.py", line 288, in compile_source
    compiler.compile(
  File "/home/users/liang.zou/software/virtualenv/py38/lib/python3.8/site-packages/hidet/backend/build.py", line 181, in compile
    self.run_compile_command(" ".join(command), src_path, out_lib_path)
  File "/home/users/liang.zou/software/virtualenv/py38/lib/python3.8/site-packages/hidet/backend/build.py", line 75, in run_compile_command
    raise CompilationFailed(src_path, message)
hidet.backend.build.CompilationFailed: failed to compile file:///home/users/liang.zou/.hidet/cache/ops/cuda_space_0/matmul/9563d8248463c985/source.cu
Command: /usr/local/cuda-11.6/bin/nvcc -I/home/users/liang.zou/software/virtualenv/py38/lib/python3.8/site-packages/hidet/include -L/home/users/liang.zou/software/virtualenv/py38/lib/python3.8/site-packages/hidet/lib -O3 -Xcompiler -fopenmp,-fPIC,-m64,-mavx2,-march=native,-O3,-funroll-loops,-ffast-math -gencode arch=compute_61,code=sm_61 --ptxas-options=-v -lineinfo -ftz=true -prec-div=false -lhidet_runtime --cudart shared --diag-suppress 177 --diag-suppress 179 --diag-suppress 39 --shared  /home/users/liang.zou/.hidet/cache/ops/cuda_space_0/matmul/9563d8248463c985/source.cu -o /home/users/liang.zou/.hidet/cache/ops/cuda_space_0/matmul/9563d8248463c985/lib.so
In file included from /usr/include/c++/4.8.2/cstdint:35:0,
                 from /home/users/liang.zou/software/virtualenv/py38/lib/python3.8/site-packages/hidet/include/hidet/runtime/symbols.h:16,
                 from /home/users/liang.zou/.hidet/cache/ops/cuda_space_0/matmul/9563d8248463c985/source.cu:2:
/usr/include/c++/4.8.2/bits/c++0x_warning.h:32:2: error: #error This file requires compiler and library support for the ISO C++ 2011 standard. This support is currently experimental, and must be enabled with the -std=c++11 or -std=gnu++11 compiler options.
 #error This file requires compiler and library support for the \
  ^
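The #error above comes from libstdc++'s bits/c++0x_warning.h, which <cstdint> pulls in when __cplusplus is below 201103L, i.e. when the host compiler runs in C++98 mode. One way to confirm this is to ask the host compiler which standard it uses by default. The following is a minimal sketch, assuming g++ is the host compiler nvcc invokes (nvcc follows the host compiler's default dialect when no -std flag is given); the helpers std_from_cplusplus and default_host_std are hypothetical and not part of hidet.

```python
import shutil
import subprocess
from typing import Optional

# Hypothetical diagnostic helpers; not part of hidet.
# Known values of the __cplusplus macro and the matching -std= dialect.
_CPLUSPLUS_TO_STD = {
    199711: "c++98",
    201103: "c++11",
    201402: "c++14",
    201703: "c++17",
    202002: "c++20",
}


def std_from_cplusplus(value: int) -> str:
    """Map a __cplusplus value to a dialect name (e.g. 199711 -> 'c++98')."""
    # Pick the newest dialect whose value does not exceed the macro, so
    # intermediate (pre-release) values map to the preceding standard.
    best = "pre-c++98"
    for threshold, name in sorted(_CPLUSPLUS_TO_STD.items()):
        if value >= threshold:
            best = name
    return best


def default_host_std(compiler: str = "g++") -> Optional[str]:
    """Return the host compiler's default C++ dialect, or None if unknown."""
    if shutil.which(compiler) is None:
        return None  # compiler not found on PATH
    # -dM -E dumps every predefined macro for an empty C++ translation unit.
    out = subprocess.run(
        [compiler, "-dM", "-E", "-x", "c++", "-"],
        input="", capture_output=True, text=True, check=True,
    ).stdout
    for line in out.splitlines():
        parts = line.split()  # e.g. ['#define', '__cplusplus', '199711L']
        if len(parts) == 3 and parts[1] == "__cplusplus":
            return std_from_cplusplus(int(parts[2].rstrip("L")))
    return None


if __name__ == "__main__":
    # On a GCC 4.8 toolchain this is expected to print 'c++98'.
    print(default_host_std())
```

If this reports c++98, any source that includes <cstdint> will hit the same #error unless a newer dialect is requested explicitly.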

To Reproduce
Version:

pip show hidet
Name: hidet
Version: 0.2.4
Summary: Hidet: a compilation-based DNN inference framework.
Home-page: https://docs.hidet.org
Author: 
Author-email: 
License: Apache-2.0
Location: /home/users/liang.zou/software/virtualenv/py38/lib/python3.8/site-packages
Requires: astunparse, click, cuda-python, numpy, nvtx, packaging, psutil, tabulate, tqdm
Required-by:

Python source file (matmul.py):

import hidet

a = hidet.randn([2, 3], device='cuda')
print(a)
b = hidet.randn([3, 2], device='cuda')
c = hidet.randn([2], device='cuda')
d = hidet.ops.matmul(a, b)
d = d + c  # 'd + c' is equivalent to 'hidet.ops.add(d, c)'
print(d)

Expected behavior
The script should compute the matrix product of a and b, add c, and print the result. Instead, nvcc fails to compile the generated matmul kernel.
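For reference, the computation the script is expected to perform can be sketched with NumPy as a CPU stand-in (this is not hidet code), assuming hidet's add follows NumPy-style broadcasting so that c of shape (2,) is added to each row of the (2, 2) product.

```python
import numpy as np

# NumPy stand-in for the hidet script: random float32 inputs with the same shapes.
rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3), dtype=np.float32)
b = rng.standard_normal((3, 2), dtype=np.float32)
c = rng.standard_normal((2,), dtype=np.float32)

d = np.matmul(a, b)  # (2, 3) @ (3, 2) -> (2, 2)
d = d + c  # c has shape (2,), so it broadcasts along the last axis (added to every row)
print(d)
```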

Environment

  • OS: CentOS 7.6.1810
  • GPU: NVIDIA TITAN Xp
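  • Others: NVIDIA GPU Driver Version: 515.76, CUDA Version: 11.7
  • nvcc: /usr/local/cuda-11.6/bin/nvcc (CUDA toolkit 11.6, per the log)
  • Host C++ headers: /usr/include/c++/4.8.2 (GCC 4.8 series, per the log)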

Additional context
After adding -std=c++11 to the nvcc command shown in the log, the file compiles successfully.
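The manual workaround can be scripted as a minimal sketch: take the full "Command:" line printed in the log, insert -std=c++11 after the nvcc executable, and re-run it. The helpers add_std_flag and recompile are hypothetical and not part of hidet; they only patch one compile invocation and do not change how hidet itself calls nvcc.

```python
import shlex
import subprocess
from typing import List


# Hypothetical helpers; not part of hidet.
def add_std_flag(command: str, std: str = "c++11") -> List[str]:
    """Split an nvcc command line and insert -std=<std> unless a -std flag is already present."""
    argv = shlex.split(command)
    if any(arg.startswith("-std") or arg.startswith("--std") for arg in argv):
        return argv  # respect an existing dialect choice
    # Place the flag right after the nvcc executable.
    return [argv[0], f"-std={std}"] + argv[1:]


def recompile(command: str) -> int:
    """Re-run the failing nvcc command with the extra flag and return its exit code."""
    return subprocess.run(add_std_flag(command)).returncode
```

Pass recompile the complete Command: string from the log; an exit code of 0 means the cached source.cu built into lib.so.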

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
