def _compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
# batch_size * [3 * seq_len]
batch_size = self.seq_lens.shape[0]
    # One (3, seq_len) tensor per request; every slot is assigned below.
    mrope_positions_list = [None] * batch_size
for batch_idx in range(batch_size):
mm_input = batch.multimodal_inputs[batch_idx]
if self.forward_mode.is_decode():
            # Decode: each request contributes a single (3, 1) position column.
            if mm_input is None:
                # Text-only: all three rotary sections share the plain 1-D
                # position of the new token, seq_lens[batch_idx] - 1.
mrope_positions_list[batch_idx] = torch.full(
(3, 1),
self.seq_lens[batch_idx] - 1,
dtype=torch.int64,
device=model_runner.device,
)
            else:
                # Multimodal: shift the text position by the request's mRoPE
                # delta. Note the host-to-device copy on every decode step.
mrope_position_deltas = mm_input.mrope_position_delta.flatten().to(
model_runner.device, non_blocking=True
)
mrope_positions_list[batch_idx] = (
(mrope_position_deltas + self.seq_lens[batch_idx] - 1)
.unsqueeze(0)
.repeat(3, 1)
)
        elif self.forward_mode.is_extend():
            # Extend/prefill: emit positions for the newly extended tokens only.
extend_seq_len, extend_prefix_len = (
batch.extend_seq_lens[batch_idx],
batch.extend_prefix_lens[batch_idx],
)
if mm_input is None:
                # Text-only: positions are the contiguous range
                # [prefix, prefix + extend_len), replicated across the
                # three mRoPE sections.
                mrope_positions = (
                    torch.arange(
                        extend_prefix_len,
                        extend_prefix_len + extend_seq_len,
                        dtype=torch.int64,
                    )
                    .unsqueeze(0)
                    .repeat(3, 1)
                )
            else:
                # Multimodal: slice the precomputed (3, total_len) positions
                # down to the tokens extended in this step.
mrope_positions = mm_input.mrope_positions[
:,
extend_prefix_len : extend_prefix_len + extend_seq_len,
]
mrope_positions_list[batch_idx] = mrope_positions
    self.mrope_positions = torch.cat(
        [pos.to(device=model_runner.device) for pos in mrope_positions_list],
        dim=1,
    ).to(dtype=torch.int64)
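
For intuition, the two text-only cases above behave like ordinary 1-D RoPE replicated across the three mRoPE sections: decode emits a single column at position seq_len - 1, and extend emits the contiguous range of newly added positions. A minimal standalone sketch (the concrete lengths are illustrative):

import torch

# Text-only decode with seq_len = 7: the one new token sits at position 6,
# replicated across the three mRoPE sections (e.g. temporal/height/width).
decode_positions = torch.full((3, 1), 7 - 1, dtype=torch.int64)
# tensor([[6], [6], [6]])

# Text-only extend with prefix 4 and 3 new tokens: positions 4..6,
# again identical in all three sections.
extend_positions = (
    torch.arange(4, 4 + 3, dtype=torch.int64).unsqueeze(0).repeat(3, 1)
)
# tensor([[4, 5, 6], [4, 5, 6], [4, 5, 6]])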
In each decode step, the call mrope_position_deltas = mm_input.mrope_position_delta.flatten().to(model_runner.device, non_blocking=True) performs a host-to-device copy for every multimodal request. Repeating this transfer on every step can become a performance bottleneck on devices with low host-to-device bandwidth or weak CPU performance.
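
One way to avoid the repeated copies is to materialize each request's delta on the device once and reuse it across decode steps. A minimal sketch of such a cache follows; the helper name _get_device_mrope_delta, the _mrope_delta_cache dict, and the use of id(mm_input) as a cache key are illustrative assumptions, not part of the existing code:

import torch

def _get_device_mrope_delta(self, mm_input, device):
    # Hypothetical helper: returns mm_input.mrope_position_delta as a
    # flattened tensor on `device`, paying the host-to-device transfer
    # at most once per multimodal input.
    cache = getattr(self, "_mrope_delta_cache", None)
    if cache is None:
        cache = self._mrope_delta_cache = {}
    key = id(mm_input)  # assumes the mm_input object is stable across steps
    delta = cache.get(key)
    if delta is None:
        delta = mm_input.mrope_position_delta.flatten().to(
            device, non_blocking=True
        )
        cache[key] = delta
    return delta

The decode branch would then read mrope_position_deltas = self._get_device_mrope_delta(mm_input, model_runner.device). Cache entries would also need to be evicted when a request finishes; that bookkeeping is omitted here.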