
Avoid repeated host-to-device data transfers for mm_input.mrope_position_delta #11046

@ash-sigh

Description

def _compute_mrope_positions(
        self, model_runner: ModelRunner, batch: ModelWorkerBatch
    ):
        # batch_size * [3 * seq_len]
        batch_size = self.seq_lens.shape[0]
        mrope_positions_list = [[]] * batch_size
        for batch_idx in range(batch_size):
            mm_input = batch.multimodal_inputs[batch_idx]
            if self.forward_mode.is_decode():
                # 3 * N
                if mm_input is None:
                    mrope_positions_list[batch_idx] = torch.full(
                        (3, 1),
                        self.seq_lens[batch_idx] - 1,
                        dtype=torch.int64,
                        device=model_runner.device,
                    )
                else:
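                    # NOTE: this .to(...) re-issues a host-to-device copy of the
                    # same delta on every decode step for this request (see below).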
                    mrope_position_deltas = mm_input.mrope_position_delta.flatten().to(
                        model_runner.device, non_blocking=True
                    )
                    mrope_positions_list[batch_idx] = (
                        (mrope_position_deltas + self.seq_lens[batch_idx] - 1)
                        .unsqueeze(0)
                        .repeat(3, 1)
                    )
            elif self.forward_mode.is_extend():
                extend_seq_len, extend_prefix_len = (
                    batch.extend_seq_lens[batch_idx],
                    batch.extend_prefix_lens[batch_idx],
                )
                if mm_input is None:
                    # text only
                    mrope_positions = torch.tensor(
                        [
                            [
                                pos
                                for pos in range(
                                    extend_prefix_len,
                                    extend_prefix_len + extend_seq_len,
                                )
                            ]
                        ]
                        * 3
                    )
                else:
                    mrope_positions = mm_input.mrope_positions[
                        :,
                        extend_prefix_len : extend_prefix_len + extend_seq_len,
                    ]
                mrope_positions_list[batch_idx] = mrope_positions

        self.mrope_positions = torch.cat(
            [pos.to(device=model_runner.device) for pos in mrope_positions_list],
            dim=1,
        ).to(dtype=torch.int64, device=model_runner.device)

In each decode step, the call

    mrope_position_deltas = mm_input.mrope_position_delta.flatten().to(model_runner.device, non_blocking=True)

re-issues a host-to-device copy of mm_input.mrope_position_delta, even though the delta does not change between decode steps. Because this runs for every multimodal request on every decode iteration, the repeated transfers can become a performance bottleneck on devices with low host-to-device bandwidth or weak CPU performance.
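One way to avoid the repeated copy is to move mrope_position_delta to the device once and reuse the cached tensor on later decode steps. The sketch below illustrates the idea; the helper name and the _mrope_delta_device attribute are illustrative assumptions (not part of the existing SGLang code), and it assumes the target device stays fixed over a request's lifetime.

import torch


def get_mrope_position_delta_on_device(mm_input, device) -> torch.Tensor:
    # Return mm_input.mrope_position_delta flattened and resident on `device`,
    # paying the host-to-device copy at most once per request.
    cached = getattr(mm_input, "_mrope_delta_device", None)  # hypothetical cache slot
    if cached is None:
        # First decode step for this request: do the transfer once and cache it.
        cached = mm_input.mrope_position_delta.flatten().to(device, non_blocking=True)
        mm_input._mrope_delta_device = cached
    return cached

The decode branch above would then become:

mrope_position_deltas = get_mrope_position_delta_on_device(
    mm_input, model_runner.device
)
mrope_positions_list[batch_idx] = (
    (mrope_position_deltas + self.seq_lens[batch_idx] - 1)
    .unsqueeze(0)
    .repeat(3, 1)
)

An equivalent fix is to perform the transfer once where mrope_position_delta is first computed at prefill time and store the device-resident tensor on the multimodal input, so the decode path never touches host memory.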
