Skip to content

Commit d33a5e2

Browse files
jeffdaily authored and pytorchmergebot committed
[ROCm] fastSpecializedAtomicAdd for MI300 (pytorch#135770)
MI300 adds HW support for packed bfloat16 and fp16. Enable via existing fastSpecializedAtomicAdd. Pull Request resolved: pytorch#135770 Approved by: https://github.com/xw285cornell, https://github.com/jianyuh
1 parent c9653bf commit d33a5e2

1 file changed

Lines changed: 89 additions & 8 deletions

File tree

aten/src/ATen/native/cuda/KernelUtils.cuh

Lines changed: 89 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,74 @@
55
#include <cuda_bf16.h>
66
#endif
77

8+
// ROCm 6.3 is planned to have these functions, but until then here they are.
9+
#if defined(USE_ROCM) && ROCM_VERSION >= 60201
10+
#include <hip/hip_bf16.h>
11+
#include <hip/hip_fp16.h>
12+
13+
// Packed bfloat16x2 atomic add (preview of the ROCm 6.3 unsafeAtomicAdd).
// On gfx940/941/942 with the compiler builtin available, both 16-bit lanes
// are added by a single HW instruction; on every other target the add is
// emulated with a relaxed 32-bit compare-exchange loop at agent scope.
// Returns the packed value that was stored at *address before the add.
__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {
#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && \
    __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16)
  // The builtin operates on an ext_vector_type(2) of unsigned short;
  // type-pun the raw bf16x2 payload through a union.
  typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
  static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw));
  union pun_t {
    __hip_bfloat162_raw raw;
    vec_short2 vec;
  } pun{static_cast<__hip_bfloat162_raw>(value)};
  pun.vec = __builtin_amdgcn_flat_atomic_fadd_v2bf16(
      reinterpret_cast<vec_short2*>(address), pun.vec);
  return static_cast<__hip_bfloat162>(pun.raw);
#else
  // CAS-loop fallback: reinterpret the bf16x2 pair as one 32-bit word.
  static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw));
  union bits_t {
    __hip_bfloat162_raw val;
    unsigned int bits;
  };
  bits_t observed, desired;
  observed.bits = __hip_atomic_load(
      reinterpret_cast<unsigned int*>(address),
      __ATOMIC_RELAXED,
      __HIP_MEMORY_SCOPE_AGENT);
  bool swapped = false;
  while (!swapped) {
    desired.val = __hadd2(observed.val, value);
    // On failure, observed.bits is refreshed with the current contents of
    // *address, so the next iteration recomputes the sum from fresh data.
    swapped = __hip_atomic_compare_exchange_strong(
        reinterpret_cast<unsigned int*>(address),
        &observed.bits,
        desired.bits,
        __ATOMIC_RELAXED,
        __ATOMIC_RELAXED,
        __HIP_MEMORY_SCOPE_AGENT);
  }
  return observed.val;
#endif
}
40+
41+
// Packed fp16x2 atomic add (preview of the ROCm 6.3 unsafeAtomicAdd).
// On gfx940/941/942 with the compiler builtin available, both half lanes
// are added by a single HW instruction; otherwise the add is emulated with
// a relaxed 32-bit compare-exchange loop at agent scope.
// Returns the packed value that was stored at *address before the add.
__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) {
#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && \
    __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16)
  // The builtin operates on an ext_vector_type(2) of _Float16;
  // type-pun the raw half2 payload through a union.
  typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162;
  static_assert(sizeof(vec_fp162) == sizeof(__half2_raw));
  union pun_t {
    __half2_raw raw;
    vec_fp162 vec;
  } pun{static_cast<__half2_raw>(value)};
  pun.vec = __builtin_amdgcn_flat_atomic_fadd_v2f16(
      reinterpret_cast<vec_fp162*>(address), pun.vec);
  return static_cast<__half2>(pun.raw);
#else
  // CAS-loop fallback: reinterpret the half2 pair as one 32-bit word.
  static_assert(sizeof(__half2_raw) == sizeof(unsigned int));
  union bits_t {
    __half2_raw val;
    unsigned int bits;
  };
  bits_t observed, desired;
  observed.bits = __hip_atomic_load(
      reinterpret_cast<unsigned int*>(address),
      __ATOMIC_RELAXED,
      __HIP_MEMORY_SCOPE_AGENT);
  bool swapped = false;
  while (!swapped) {
    desired.val = __hadd2(observed.val, value);
    // On failure, observed.bits is refreshed with the current contents of
    // *address, so the next iteration recomputes the sum from fresh data.
    swapped = __hip_atomic_compare_exchange_strong(
        reinterpret_cast<unsigned int*>(address),
        &observed.bits,
        desired.bits,
        __ATOMIC_RELAXED,
        __ATOMIC_RELAXED,
        __HIP_MEMORY_SCOPE_AGENT);
  }
  return observed.val;
#endif
}
69+
#define ATOMICADD preview_unsafeAtomicAdd
70+
#define NATIVE_ZERO_BF16 __float2bfloat16(0.0f)
71+
#else
72+
#define ATOMICADD atomicAdd
73+
#define NATIVE_ZERO_BF16 __int2bfloat16_rz(0)
74+
#endif
75+
876
namespace at::native {
977

1078
__device__ __forceinline__ size_t
@@ -47,7 +115,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
47115
const index_t numel,
48116
scalar_t value) {
49117
#if ( \
50-
(defined(USE_ROCM)) || \
118+
(defined(USE_ROCM) && ROCM_VERSION < 60201) || \
51119
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
52120
gpuAtomicAddNoReturn(
53121
reinterpret_cast<at::Half*>(tensor) + index,
@@ -61,17 +129,22 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
61129
__half2 value2;
62130
value2.x = static_cast<__half>(value);
63131
value2.y = __int2half_rz(0);
64-
atomicAdd(reinterpret_cast<__half2*>(target_addr), value2);
132+
ATOMICADD(reinterpret_cast<__half2*>(target_addr), value2);
65133

66134
} else if (!low_byte && index > 0) {
67135
__half2 value2;
68136
value2.x = __int2half_rz(0);
69137
value2.y = static_cast<__half>(value);
70-
atomicAdd(reinterpret_cast<__half2*>(target_addr - 1), value2);
138+
ATOMICADD(reinterpret_cast<__half2*>(target_addr - 1), value2);
71139

72140
} else {
141+
#ifdef USE_ROCM
142+
gpuAtomicAddNoReturn(
143+
reinterpret_cast<at::Half*>(tensor) + index, static_cast<at::Half>(value));
144+
#else
73145
atomicAdd(
74146
reinterpret_cast<__half*>(tensor) + index, static_cast<__half>(value));
147+
#endif
75148
}
76149
#endif
77150
}
@@ -87,7 +160,7 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
87160
const index_t numel,
88161
scalar_t value) {
89162
#if ( \
90-
(defined(USE_ROCM)) || \
163+
(defined(USE_ROCM) && ROCM_VERSION < 60201) || \
91164
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
92165
gpuAtomicAddNoReturn(
93166
reinterpret_cast<at::BFloat16*>(tensor) + index,
@@ -100,18 +173,23 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd(
100173
if (low_byte && index < (numel - 1)) {
101174
__nv_bfloat162 value2;
102175
value2.x = *reinterpret_cast<__nv_bfloat16*>(&value);
103-
value2.y = __int2bfloat16_rz(0);
104-
atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr), value2);
176+
value2.y = NATIVE_ZERO_BF16;
177+
ATOMICADD(reinterpret_cast<__nv_bfloat162*>(target_addr), value2);
105178

106179
} else if (!low_byte && index > 0) {
107180
__nv_bfloat162 value2;
108-
value2.x = __int2bfloat16_rz(0);
181+
value2.x = NATIVE_ZERO_BF16;
109182
value2.y = *reinterpret_cast<__nv_bfloat16*>(&value);
110-
atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2);
183+
ATOMICADD(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2);
111184

112185
} else {
186+
#ifdef USE_ROCM
187+
gpuAtomicAddNoReturn(
188+
reinterpret_cast<at::BFloat16*>(tensor) + index, static_cast<at::BFloat16>(value));
189+
#else
113190
atomicAdd(
114191
reinterpret_cast<__nv_bfloat16*>(tensor) + index, *reinterpret_cast<__nv_bfloat16*>(&value));
192+
#endif
115193
}
116194
#endif
117195
}
@@ -144,4 +222,7 @@ __device__ __forceinline__ void fastAtomicAdd(
144222
}
145223
}
146224

225+
#undef ATOMICADD
226+
#undef NATIVE_ZERO_BF16
227+
147228
} // namespace at::native

0 commit comments

Comments
 (0)