Skip to content

feat(xtoken): support TP/CP/diff-DP sharded cross-tokenizer distillation#2745

Open
RayenTian wants to merge 2 commits into
mainfrom
ruit/xtoken-tp-cp
Open

feat(xtoken): support TP/CP/diff-DP sharded cross-tokenizer distillation#2745
RayenTian wants to merge 2 commits into
mainfrom
ruit/xtoken-tp-cp

Conversation

@RayenTian

@RayenTian RayenTian commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

What does this PR do ?

Adds TP / CP / heterogeneous-DP sharding support to the cross-tokenizer (xtoken) off-policy distillation loss (P-KL + gold) on the DTensor v2 worker, fixes a context-parallel loss-invariance bug, and adds a parallelism-invariance nightly.

Issues

Closes #2682

Following #2792 (Make building the (student, teacher) projection matrix inflight )

Result

Teacher : Qwen/Qwen3-4B TP2 CP2 DP2
Student : meta-llama/Llama-3.2-1B TP4 CP2 DP1

KL Loss

image

Gold Loss

image

XT Loss

image

Usage

# 1. Build the (student, teacher) projection matrix from the tokenizer pair (one-time, offline) -> writes <output-dir>/<output-filename>_special.pt.
uv run python -m tools.x_token.minimal_projection_via_multitoken --student-model meta-llama/Llama-3.2-1B --teacher-model Qwen/Qwen3-4B --top-k 4 --enable-special-token-mapping --enable-exact-match --disable-reverse-pass --disable-scale-trick --output-filename xtoken_proj --output-dir /tmp/xtoken_proj
# 2. Distill: student TP4×CP2 from a teacher running TP2×CP2 (heterogeneous layout); the loss stays parallelism-invariant across TP/CP/DP.
uv run examples/run_xtoken_off_policy_distillation.py --config examples/configs/recipes/llm/xtoken-off-policy-distillation-qwen3-4b-to-llama3.2-1b-1n8g-dtensor-tp4cp2.yaml loss_fn.projection_matrix_path=/tmp/xtoken_proj/xtoken_proj_special.pt
# key knobs: policy.dtensor_cfg.{tensor_parallel_size=4,context_parallel_size=2}, teacher.dtensor_cfg.{tensor_parallel_size=2,context_parallel_size=2}

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

@copy-pr-bot

copy-pr-bot Bot commented Jun 9, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@RayenTian RayenTian force-pushed the ruit/xtoken-tp-cp branch from cb37de4 to 1537d7d Compare June 9, 2026 14:32
@RayenTian RayenTian force-pushed the ruit/xtoken-tp-cp branch from 1537d7d to bbe782f Compare June 10, 2026 04:57
@RayenTian RayenTian force-pushed the ruit/xtoken-tp-cp branch from 39cf1b2 to 5185829 Compare June 11, 2026 02:36
@RayenTian RayenTian added the CI:Lfast Runs a fast test suite and re-use nightly `main` container (but sync dependencies to PRs version) label Jun 12, 2026
@RayenTian

Copy link
Copy Markdown
Contributor Author

/ok to test 040969c

@RayenTian

Copy link
Copy Markdown
Contributor Author

/ok to test 17bec2f

RayenTian and others added 2 commits June 11, 2026 23:18
Extend cross-tokenizer off-policy distillation to run with the student under
tensor- and context-parallelism (and a data-parallel degree that may differ
from the teacher's) on the automodel/DTensor policy worker.

- Teacher full-vocab logits are exported per rank and shipped to the student
  via CUDA IPC (FullLogitsPostProcessor), then reassembled on the consumer
  across its CP group for heterogeneous teacher/student TP/CP layouts.
- The loss runs TP/CP-aware: vocab-parallel log-softmax / argmax / projection,
  CP load-balanced -> contiguous re-layout, CP-aware next-token shift, and
  partial chunk-average + grad-preserving all-reduce. The generic collectives
  live in model_utils; the cross-tokenizer orchestration in x_token/loss_utils.
- TP/CP process groups are derived from the student logits' own device mesh
  instead of being threaded through the generic LossPostProcessor, so the
  SFT / GRPO / distillation LOGPROB paths keep using the DTensor branch.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: ruit <ruit@nvidia.com>
…on nightly

- Unit tests for the TP/CP/diff-DP setup grid and the vocab-/CP-parallel loss
  helpers (real 2-GPU NCCL actors), plus per-step IPC buffer release. Adds the
  x_token unit-test package __init__.py so the Ray actor FQN is importable.
- Nightly recipe distillation-xtoken-off-policy-qwen3-4b-to-llama3.2-1b-1n8g-
  dtensor-tp4cp2 (student TP4xCP2 <- teacher TP2xCP2) and its driver, wired into
  nightly.txt; guards sharded-loss parallelism-invariance.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: ruit <ruit@nvidia.com>
@RayenTian

Copy link
Copy Markdown
Contributor Author

/ok to test b02211f

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CI:Lfast Runs a fast test suite and re-use nightly `main` container (but sync dependencies to PRs version)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

test(xtoken): add nightly with heterogeneous TP/CP parallel plan for cross-tokenizer distillation

1 participant