@@ -85,34 +85,20 @@ __device__ __forceinline__ T gelu_tanh(const T& x) {
 void silu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl) {
   int d = input.size(-1) / 2;
   int64_t num_tokens = input.numel() / input.size(-1);
+  dim3 grid(num_tokens);
+
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  const c10::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
 
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
     uint32_t vec_size = 16 / sizeof(c_type);
-#if USE_ROCM
-    dim3 grid(num_tokens);
     dim3 block(std::min(d / vec_size, 1024U));
+#if USE_ROCM
     sgl_hip::activation::act_and_mul_kernel<c_type, silu>
         <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
 #else
-    cudaLaunchConfig_t config;
-    config.gridDim = num_tokens;
-    config.blockDim = std::min(d / vec_size, 1024U);
-    config.dynamicSmemBytes = 0;
-    config.stream = stream;
-    cudaLaunchAttribute attrs[1];
-    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-    attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
-    config.numAttrs = 1;
-    config.attrs = attrs;
-
-    auto kernel = flashinfer::activation::act_and_mul_kernel<c_type, silu>;
-    cudaLaunchKernelEx(
-        &config, kernel, static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
-
-    cudaError_t err = cudaGetLastError();
-    TORCH_CHECK(err == cudaSuccess, "Failed to launch kernel: ", cudaGetErrorString(err));
+    flashinfer::activation::act_and_mul_kernel<c_type, silu>
+        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
 #endif
     return true;
   });
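
For context, the CUDA branch removed here launched through cudaLaunchKernelEx so that programmatic dependent launch (PDL) could be toggled per call via enable_pdl; the commit replaces it with the same plain <<<grid, block>>> launch the ROCm branch uses. Below is a minimal, self-contained sketch of the removed launch pattern, assuming CUDA 11.8+ (where cudaLaunchKernelEx and the PDL attribute were introduced); the add_one kernel and its launch sizes are hypothetical stand-ins, not part of this commit.

    #include <cuda_runtime.h>
    #include <cstdio>

    // Hypothetical kernel standing in for act_and_mul_kernel.
    __global__ void add_one(float* x, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) x[i] += 1.0f;
    }

    int main() {
      int n = 1 << 20;
      float* x = nullptr;
      cudaMalloc(&x, n * sizeof(float));
      cudaMemset(x, 0, n * sizeof(float));

      // Launch geometry and stream live in a cudaLaunchConfig_t
      // instead of the <<<...>>> syntax.
      cudaLaunchConfig_t config = {};
      config.gridDim = dim3((n + 255) / 256);
      config.blockDim = dim3(256);
      config.dynamicSmemBytes = 0;
      config.stream = 0;  // default stream

      // Opting in to programmatic stream serialization is what enables PDL:
      // this launch may overlap the tail of the previous kernel in the
      // stream instead of waiting for it to fully complete.
      cudaLaunchAttribute attrs[1];
      attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
      attrs[0].val.programmaticStreamSerializationAllowed = 1;  // was enable_pdl
      config.attrs = attrs;
      config.numAttrs = 1;

      // cudaLaunchKernelEx takes the kernel and its arguments directly.
      cudaError_t err = cudaLaunchKernelEx(&config, add_one, x, n);
      if (err != cudaSuccess) {
        printf("Failed to launch kernel: %s\n", cudaGetErrorString(err));
      }

      cudaDeviceSynchronize();
      cudaFree(x);
      return 0;
    }

Kernels launched this way typically call cudaGridDependencySynchronize() in their body before reading memory written by the preceding kernel, so the overlap stays safe.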