
Implement continuous batching #1642

Open
rltakashige wants to merge 27 commits into main from leo/implement-continuous-batching

Conversation

rltakashige (Collaborator) commented Mar 2, 2026

Motivation

Following the changes made in #1632!
Closes #1020

Changes

Why It Works

Test Plan

Manual Testing

Automated Testing

rltakashige force-pushed the leo/implement-continuous-batching branch from f0938d6 to e660cab on March 2, 2026 02:29
rltakashige force-pushed the leo/prepare-batch-implementation branch from 904b12f to 33f57c6 on March 2, 2026 02:30
rltakashige force-pushed the leo/implement-continuous-batching branch from e660cab to b64d929 on March 2, 2026 02:30
rltakashige force-pushed the leo/prepare-batch-implementation branch 2 times, most recently from 7ae0586 to b771f67, on March 2, 2026 03:09
rltakashige force-pushed the leo/prepare-batch-implementation branch from b771f67 to 8cb9bac on March 2, 2026 03:09
rltakashige force-pushed the leo/implement-continuous-batching branch from b64d929 to fbd14c7 on March 2, 2026 03:33
rltakashige requested a review from Evanev7 on March 2, 2026 03:34
rltakashige marked this pull request as ready for review on March 2, 2026 03:34
rltakashige force-pushed the leo/prepare-batch-implementation branch 2 times, most recently from ca53647 to 0b00f1a, on March 2, 2026 17:00
Evanev7 force-pushed the leo/prepare-batch-implementation branch from 593cd59 to b05ddff on March 2, 2026 17:03
rltakashige force-pushed the leo/prepare-batch-implementation branch 2 times, most recently from b9c4199 to f6eccf1, on March 2, 2026 17:26
Evanev7 force-pushed the leo/prepare-batch-implementation branch from f6eccf1 to 6962838 on March 3, 2026 10:49
Base automatically changed from leo/prepare-batch-implementation to main on March 3, 2026 14:38
rltakashige force-pushed the leo/implement-continuous-batching branch from 7622da0 to c79e0b1 on March 3, 2026 15:49
rltakashige force-pushed the leo/implement-continuous-batching branch 3 times, most recently from 6153a8f to d82a7ad, on March 4, 2026 23:10
rltakashige force-pushed the leo/implement-continuous-batching branch from d82a7ad to c87f1c7 on March 5, 2026 00:21
rltakashige and others added 6 commits March 5, 2026 10:44
Evanev7 (Member) commented Mar 6, 2026

ready for review? or not quite yet

rltakashige force-pushed the leo/implement-continuous-batching branch from d4d3c03 to 0164aee on March 6, 2026 14:26
rltakashige force-pushed the leo/implement-continuous-batching branch from 911d362 to 2f5b829 on March 7, 2026 20:39
Evanev7 (Member) left a comment

exciting!!! these are mostly stylistic changes with one or two minor correctness things we were probably doing wrong before anyway.

!!!!! continuous batching !!!!!

else:
cache = make_kv_cache(self.model)

seed = task_params.seed or 42

minor: should use explicit None check - this will override 0 with 42.
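A minimal sketch of the difference (the default value is taken from the diff above; function names are illustrative):

```python
DEFAULT_SEED = 42

def resolve_seed_with_or(seed):
    # Current behaviour: `or` treats any falsy seed, including 0, as unset.
    return seed or DEFAULT_SEED

def resolve_seed_explicit(seed):
    # Suggested fix: only None triggers the default, so seed=0 is respected.
    return DEFAULT_SEED if seed is None else seed
```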


last_tokens = prompt_tokens[-2:]

logits_processors: list[Any] = []

i believe we have some repetition penalty logits processors? assuming that merged, this should presumably duplicate that logic (or maybe a single make_logits_processors idk)
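One possible shape for a shared `make_logits_processors` helper, sketched with plain Python lists rather than mx arrays (the helper name and penalty rule are hypothetical, not from the PR):

```python
def make_repetition_penalty(penalty):
    def processor(generated_ids, logits):
        # Penalise logits of already-generated token ids (simplified,
        # list-based stand-in for the array version).
        out = list(logits)
        for tok in set(generated_ids):
            out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
        return out
    return processor

def make_logits_processors(repetition_penalty=None):
    # Centralising construction here keeps the repetition-penalty logic
    # in one place instead of duplicating it per call site.
    processors = []
    if repetition_penalty is not None and repetition_penalty != 1.0:
        processors.append(make_repetition_penalty(repetition_penalty))
    return processors
```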


max_tokens = task_params.max_output_tokens or MAX_TOKENS

uids = self._mlx_gen.insert(

can this ever return multiple uids? we should guard that case


reading further this seems to be for multiple insertion - assuming we have no interest in multiple insertion, we should just assert it's a single uid.
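A sketch of the suggested guard, assuming insert returns a list of uids (the helper name is hypothetical):

```python
def insert_single(uids):
    # Single-request insertion path: anything other than exactly one uid
    # means the batch generator did something we did not ask for.
    assert len(uids) == 1, f"expected exactly one uid, got {uids!r}"
    return uids[0]
```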

return []

responses = self._mlx_gen.next()
mx.clear_cache()

this clear_cache could use a comment imo

results: list[tuple[int, GenerationResponse]] = []

for response in responses:
if response.uid not in self._active_tasks:

feels like an error we should report

] = field(default_factory=dict, init=False)

def __post_init__(self) -> None:
self._mlx_gen = ExoBatchGenerator(

_mlx_gen -> _exo_gen

while self._queue and len(self._active_tasks) < EXO_MAX_CONCURRENT_REQUESTS:
task = self._queue.popleft()
try:
uid = self._build_generator(task)

it is not clear that both _build_generator and _mlx_gen.submit run prefill immediately - i think we should keep behaviour as is but change the interface slightly
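One way to make that explicit without changing behaviour, sketched with hypothetical names: keep submission cheap and give the prompt pass its own visibly named step.

```python
class BatchGenerator:
    """Toy stand-in showing the interface split, not the real generator."""

    def __init__(self):
        self.pending = []
        self.prefilled = []

    def submit(self, prompt_tokens):
        # Submission only queues the request and hands back a uid.
        uid = len(self.pending) + len(self.prefilled)
        self.pending.append((uid, prompt_tokens))
        return uid

    def prefill(self):
        # Explicitly named so callers can see when the expensive prompt
        # pass runs; returns how many requests were prefilled.
        self.prefilled.extend(self.pending)
        n = len(self.pending)
        self.pending.clear()
        return n
```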

self._active_tasks[uid] = (task, queue, output_generator)

if not self._mlx_gen.has_work:
return self._drain_cancellations()

similarly "drain cancellations" imo implies removing a cancellation rather than removing its corresponding task

] = []
for uid, response in results:
if uid not in self._active_tasks:
# should we error here?

ref comment and review comment above, i think at least a log is due here
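A sketch of the suggested logging (function and variable names are illustrative, not from the PR):

```python
import logging

logger = logging.getLogger(__name__)

def filter_known(results, active_tasks):
    known = []
    for uid, response in results:
        if uid not in active_tasks:
            # A uid we never registered (or already cleaned up) points at a
            # bookkeeping bug; log it rather than silently dropping the
            # response.
            logger.warning("dropping response for unknown uid %s", uid)
            continue
        known.append((uid, response))
    return known
```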

def build(
self,
) -> InferenceGenerator:
import os

no reason to import late here
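A minimal sketch of the change, assuming the late import is just `os` (the env-var name is borrowed from the surrounding diff; the default and function name are illustrative):

```python
import os  # module level: the dependency is visible and the import runs once

def max_concurrent_requests():
    # Previously `import os` lived inside the builder; there is no import
    # cycle or optional dependency here, so a top-level import is simpler.
    return int(os.environ.get("EXO_MAX_CONCURRENT_REQUESTS", "4"))
```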


Development

Successfully merging this pull request may close these issues.

[FEATURE] Concurrent inference / continuous batching.

2 participants