Skip to content

Commit f5d1c41

Browse files
hexagon: dma optimizations (mostly fixing regressions) (#21137)
* hex-fa: add simple dma cache for Mask I noticed that we were refetching the mask rows over and over. This simple cache avoids that. * hex-dma: unset in-order desc bit which caused significant perf regression We don't rely on true in-order processing of the DMA descriptors anywhere. Turns out this mode caused a significant regression of around 3-4 TPS during token gen. * hex-rope: update comment to clarify that we don't need in-order DMA completions
1 parent 2405d59 commit f5d1c41

File tree

3 files changed

+74
-17
lines changed

3 files changed

+74
-17
lines changed

ggml/src/ggml-hexagon/htp/flash-attn-ops.c

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
346346

347347
const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);
348348

349+
dma_cache m_cache;
350+
dma_cache_init(&m_cache, spad_m, factx->size_m_block, DMA_CACHE_MAX_SIZE);
351+
349352
for (uint32_t ir = ir0; ir < ir1; ++ir) {
350353
const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
351354
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
@@ -389,9 +392,8 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
389392
// Mask
390393
if (mask) {
391394
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
392-
uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;
393395
// Mask is 1D contiguous for this row
394-
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
396+
dma_cache_push(dma, &m_cache, m_src, current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
395397
}
396398

397399
// FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
@@ -554,7 +556,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
554556
// Mask
555557
if (mask) {
556558
const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
557-
dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
559+
dma_cache_push(dma, &m_cache, m_src, next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
558560
}
559561

560562
// FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
@@ -684,7 +686,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
684686
octx->src0_spad.size_per_thread = size_q_block * 1;
685687
octx->src1_spad.size_per_thread = factx.size_k_block * 2;
686688
octx->src2_spad.size_per_thread = factx.size_v_block * 2;
687-
octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0;
689+
octx->src3_spad.size_per_thread = mask ? factx.size_m_block * DMA_CACHE_MAX_SIZE : 0;
688690
octx->dst_spad.size_per_thread = size_vkq_acc;
689691

690692
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
@@ -705,6 +707,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
705707
octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
706708
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
707709

710+
// FARF(ERROR, "fa: qrows-per-thread %u", factx.qrows_per_thread);
711+
708712
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
709713
worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
710714
}

ggml/src/ggml-hexagon/htp/hex-dma.h

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -143,16 +143,20 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
143143
desc->desc_size = 0; // 1D mode
144144
desc->src_bypass = dma_src_l2_bypass_on;
145145
desc->dst_bypass = dma_dst_l2_bypass_on;
146-
desc->order = 1;
146+
desc->order = 0;
147147
desc->done = 0;
148148
desc->src = (void *) dptr.src;
149149
desc->dst = (void *) dptr.dst;
150150
desc->size = size;
151151

152152
q->dptr[q->push_idx] = dptr;
153153

154-
dmlink(q->tail, desc);
155-
q->tail = (dma_descriptor_2d *) desc;
154+
if (size) {
155+
dmlink(q->tail, desc);
156+
q->tail = (dma_descriptor_2d *) desc;
157+
} else {
158+
desc->done = 1;
159+
}
156160

157161
// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
158162
q->push_idx = (q->push_idx + 1) & q->idx_mask;
@@ -175,7 +179,7 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t
175179
desc->dst_bypass = dma_dst_l2_bypass_on;
176180
desc->src_comp = 0;
177181
desc->dst_comp = 0;
178-
desc->order = 1;
182+
desc->order = 0;
179183
desc->done = 0;
180184
desc->src_stride = src_stride;
181185
desc->dst_stride = dst_stride;
@@ -197,8 +201,12 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t
197201

198202
q->dptr[q->push_idx] = dptr;
199203

200-
dmlink(q->tail, desc);
201-
q->tail = desc;
204+
if (nrows) {
205+
dmlink(q->tail, desc);
206+
q->tail = desc;
207+
} else {
208+
desc->done = 1;
209+
}
202210

203211
// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
204212
q->push_idx = (q->push_idx + 1) & q->idx_mask;
@@ -215,12 +223,9 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
215223
dma_descriptor_2d * desc = &q->desc[q->pop_idx];
216224

217225
// Wait for desc to complete
218-
while (1) {
219-
dmpoll();
220-
if (desc->done) {
221-
break;
222-
}
226+
while (!desc->done) {
223227
// FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
228+
dmpoll();
224229
}
225230

226231
dptr = q->dptr[q->pop_idx];
@@ -312,6 +317,54 @@ static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_
312317
return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
313318
}
314319

320+
// Upper bound on the number of lines a single dma_cache can track.
#define DMA_CACHE_MAX_SIZE 64U

// Bookkeeping for a small cache of DMA destination lines laid out back to back
// in one caller-provided buffer. Line i starts at base + i * line_size.
typedef struct {
    uint8_t  *base;                     // start of the buffer that holds the lines
    uint32_t  line_size;                // byte distance between consecutive lines
    uint32_t  capacity;                 // number of lines in use (at most DMA_CACHE_MAX_SIZE)
    uint32_t  src[DMA_CACHE_MAX_SIZE];  // per-line tag: source address held by the line (0 after init)
    uint16_t  age[DMA_CACHE_MAX_SIZE];  // per-line age: 0 when just used, grows while the line is not hit
} dma_cache;
329+
330+
// Set up a DMA cache over the buffer starting at `base`.
//
//   c         : cache state to initialize
//   base      : start of the buffer that holds the cache lines
//   line_size : byte distance between consecutive lines
//   capacity  : requested number of lines; clamped to DMA_CACHE_MAX_SIZE
//
// The tag and age of every usable line are reset to 0.
static inline void dma_cache_init(dma_cache *c, uint8_t *base, uint32_t line_size, uint32_t capacity)
{
    c->base      = base;
    c->line_size = line_size;
    c->capacity  = capacity;
    if (c->capacity > DMA_CACHE_MAX_SIZE) {
        c->capacity = DMA_CACHE_MAX_SIZE;
    }

    // Clear the bookkeeping for the lines that will actually be used
    uint32_t remaining = c->capacity;
    while (remaining > 0) {
        remaining--;
        c->src[remaining] = 0;
        c->age[remaining] = 0;
    }
}
341+
342+
// Queue a DMA load of `src` into the cache, skipping the copy when `src` is
// already resident.
//
//   q          : DMA queue the request is pushed onto
//   c          : cache state (see dma_cache_init)
//   src        : source address; its 32-bit value is the lookup tag
//   dst_stride,
//   src_stride,
//   row_size,
//   nrows      : transfer geometry, forwarded to dma_queue_push()
//
// Hit : the matching line's age is reset and dma_queue_push() is called with
//       nrows == 0 and the line as destination (a dummy request).
// Miss: the line with the largest age is retagged with `src` and
//       dma_queue_push() is called with the caller's nrows.
//
// Returns whatever dma_queue_push() returns.
static inline bool dma_cache_push(dma_queue *q, dma_cache *c, const uint8_t * src, uint32_t dst_stride, uint32_t src_stride, uint32_t row_size, uint32_t nrows)
{
    // Tags are 32-bit source addresses; convert via uintptr_t so the
    // pointer-to-integer conversion is explicit.
    const uint32_t tag = (uint32_t) (uintptr_t) src;

    uint32_t  o_idx = 0;     // index of the oldest non-matching line seen so far
    uint16_t  o_age = 0;     // age of that line
    uint8_t * dst   = NULL;  // destination line, set on a hit
    bool      hit   = false;

    for (uint32_t i = 0; i < c->capacity; i++) {
        if (c->src[i] == tag) {
            c->age[i] = 0;
            dst   = c->base + (i * c->line_size);
            nrows = 0; // dummy dma
            hit   = true;
            // FARF(ERROR, "dma-cache: found %p", src);
        } else {
            // Saturate instead of wrapping: a wrapped counter would make the
            // oldest line look like the newest and get the wrong line evicted.
            if (c->age[i] < UINT16_MAX) {
                c->age[i]++;
            }
            if (c->age[i] > o_age) {
                o_age = c->age[i];
                o_idx = i;
            }
        }
    }

    if (!hit) {
        // FARF(ERROR, "dma-cache: replacing #%u : age %u %p -> %p", o_idx, c->age[o_idx], (void *) c->src[o_idx], src);
        c->age[o_idx] = 0;
        c->src[o_idx] = tag;
        dst = c->base + o_idx * c->line_size; // normal nrows dma
    }

    return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
}
367+
315368
#ifdef __cplusplus
316369
} // extern "C"
317370
#endif

ggml/src/ggml-hexagon/htp/rope-ops.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,8 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
333333
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
334334
}
335335

336-
// Skip DMA transactions from prev block (if any)
337-
// No need to wait for these since the DMA is setup for in-order processing
336+
// Skip output DMA transactions from prev block (if any)
337+
// No need to wait for those here since we're explicitly waiting for the latest prefetches below.
338338
for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); }
339339

340340
// Compute loop

0 commit comments

Comments
 (0)