Skip to content

[BUG] Conversion compile issue#150

Merged
yaoyaoding merged 6 commits intohidet-org:mainfrom
xinli-git:conversion_fp16
Apr 3, 2023
Merged

[BUG] Conversion compile issue#150
yaoyaoding merged 6 commits intohidet-org:mainfrom
xinli-git:conversion_fp16

Conversation

@xinli-git
Copy link
Copy Markdown
Contributor

Hidet generated kernels contains cast operations such as:

__global__ void __launch_bounds__(500) hidet_compute_z(half* src, uint64_t* dst){
  dst[threadIdx.x] = (uint64_t)src[threadIdx.x];

will not compile due to the following error, as there is no explicit conversion defined from half to uint64_t
The proper conversion is defined here:

After some testing, only int64 and short cannot be directly casted, but would require additional handling, as seen in PR

source.cu(35): error: more than one conversion function from "half" to "uint64_t" applies:
            function "__half::operator float() const"
/usr/local/cuda/bin/../targets/x86_64-linux/include/cuda_fp16.hpp(204): here
            function "__half::operator short() const"
/usr/local/cuda/bin/../targets/x86_64-linux/include/cuda_fp16.hpp(222): here
            function "__half::operator unsigned short() const"
/usr/local/cuda/bin/../targets/x86_64-linux/include/cuda_fp16.hpp(225): here
            function "__half::operator int() const"
/usr/local/cuda/bin/../targets/x86_64-linux/include/cuda_fp16.hpp(228): here
            function "__half::operator unsigned int() const"
/usr/local/cuda/bin/../targets/x86_64-linux/include/cuda_fp16.hpp(231): here
            function "__half::operator long long() const"
/usr/local/cuda/bin/../targets/x86_64-linux/include/cuda_fp16.hpp(234): here
            function "__half::operator unsigned long long() const"
/usr/local/cuda/bin/../targets/x86_64-linux/include/cuda_fp16.hpp(237): here
            function "__half::operator __nv_bool() const"
/usr/local/cuda/bin/../targets/x86_64-linux/include/cuda_fp16.hpp(241): here

1 error detected in the compilation of "source.cu".

return '(int64_t)(__half2ll_rz(' + self(e.expr) + '))'
elif dst_dtype == dtypes.uint64:
return '(uint64_t)(' + self(e.expr) + ')'
return '(uint64_t)(__half2ull_rz(' + self(e.expr) + '))'
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Xin, what if the type of e.expr is not half?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the type of e.expr is guarded by the if statement above

if isinstance(src_dtype, DataType) and isinstance(dst_dtype, DataType) and src_dtype == dtypes.float16:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense. I did not notice the guarding if statement.

I am wondering if (int64_t)(long long)(expr) works or not?

Copy link
Copy Markdown
Contributor Author

@xinli-git xinli-git Apr 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does, and thanks for pointing it out, that would be a better idea because it is more generic, I made the change

Copy link
Copy Markdown
Member

@yaoyaoding yaoyaoding left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @xinli-git, looks good to me!

@xinli-git
Copy link
Copy Markdown
Contributor Author

Hi @yaoyaoding all looks good now, could you take a final look ?

@yaoyaoding
Copy link
Copy Markdown
Member

Thanks @xinli-git!

@yaoyaoding yaoyaoding merged commit 2b6fa72 into hidet-org:main Apr 3, 2023
@xinli-git xinli-git deleted the conversion_fp16 branch April 3, 2023 13:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants