Skip to content

feat(xtoken): cross-tokenizer off-policy distillation#2508

Merged
yuki-97 merged 75 commits into
NVIDIA-NeMo:mainfrom
avenkateshha:avenkateshha/xtoken-off-policy-distillation
Jun 9, 2026
Merged

feat(xtoken): cross-tokenizer off-policy distillation#2508
yuki-97 merged 75 commits into
NVIDIA-NeMo:mainfrom
avenkateshha:avenkateshha/xtoken-off-policy-distillation

Conversation

@avenkateshha

@avenkateshha avenkateshha commented May 16, 2026

Copy link
Copy Markdown
Contributor

What does this PR do ?

Adds cross-tokenizer (x-token) off-policy distillation — distilling a teacher into a student that does not share its tokenizer (e.g. Qwen3-4B → Llama-3.2-1B) by projecting student logits into the teacher's vocab through a precomputed projection matrix.

Summary

  • TokenAligner (nemo_rl/algorithms/x_token/token_aligner.py) — DP alignment over canonicalized tokens; loads the projection matrix.
  • CrossTokenizerDistillationLossFn with three modes: P-KL (full-vocab teacher logits via CUDA IPC + microbatch-global top-k), gold-loss (exact-mapped common KL + L1 uncommon tail), H-KL (gold + xtoken) (gold-loss with relaxed >= 0.6 exact-map + multi-token collision resolution).
  • Backend: DTensor V2 only (tensor/context parallel must be 1); teacher logits travel via same-node CUDA IPC. Megatron and sequence packing are not supported.
  • Projection-prep CLIs under tools/x_token/ (minimal_projection_generator, minimal_projection_via_multitoken, reapply_exact_map, sort_and_cut_projection_matrix), chained by tools/x_token/build_projection_matrix.sh.
  • Docs guide at docs/guides/xtoken-off-policy-distillation.md (registered in docs/index.md).

This PR supersedes #2347 (which was force-closed when its head branch was renamed). RayenTian's review feedback on #2347 is addressed in commit 755fb8e4:

  • minimal_projection_via_multitoken.py: restored --output-filename override; extended gemma vocab-size branch to also fire for qwen3.5; documented .pt save-key schema.
  • minimal_projection_generator.py: matching schema annotation near torch.save.
  • reapply_exact_map.py: validate loaded projection map is a dict with indices/likelihoods; raise ValueError with the file path on mismatch.
  • sort_and_cut_projection_matrix.py: factored argparse into parse_arguments(); factored verbose stats into print_projection_statistics(); replaced positional input_path with --initial-projection-path to match the other tools.
  • docs/guides/xtoken-distillation.md: synced the Step 4 example with the new CLI.

Issues

Pending: #2682

Usage

# 1. Build the (student, teacher) projection matrix
./tools/x_token/build_projection_matrix.sh \
    --student-model meta-llama/Llama-3.2-1B \
    --teacher-model Qwen/Qwen3-4B \
    --runtime-top-k 4

# 2. Launch distillation
uv run python examples/run_xtoken_off_policy_distillation.py \
    --config examples/configs/xtoken_off_policy_distillation.yaml \
    loss_fn.projection_matrix_path=cross_tokenizer_data/projection_matrix_Llama-3.2_Qwen3_top4.pt

Results — 100-step P-KL run

Llama-3.2-1B ← Qwen3-4B, default config (gbs 96, mbs 1, seq 2048, 2 nodes) on the default Nemotron-Pretraining-Specialized-v1.1 (Formal-Logic) corpus, 100 steps:

train/loss, train/kl_loss, train/ce_loss, train/accuracy over 100 P-KL distillation steps

Metric Start → End
train/loss 1.51 → 0.78
train/kl_loss 2.67 → 0.78
train/ce_loss 0.75 → 0.39
train/accuracy 0.82 → 0.88

Throughput/memory per step: mean step 4.07 s · ≈48k valid tok/s · peak 29.5 GB/GPU · teacher-logit IPC tray ≈0.6 GB/sample-step ([T_t≈2048, V_t≈151,936] bf16, zero-copy same-node handle). Convergence curves: see the guide's "Results" section.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests? (loss, token aligner, collator, projection-prep CLIs)
  • Did you run the unit tests and functional tests locally?
  • Did you add or update any necessary documentation? (docs/guides/xtoken-off-policy-distillation.md)

Test plan

  • CI green
  • Projection-prep reproducibility: re-running the CLI tools on the Llama-3.2-3B ↔ Qwen3-4B-Base pair reproduces the canonical llama_qwen_best_special_exact_map_remapped.pt bitwise (torch.equal on indices and likelihoods).
  • Smoke run: 2-node Llama-3.2-1B ← Qwen3-4B, 100-step P-KL on the default Nemotron-Pretraining-Specialized-v1.1 (Formal-Logic) corpus.

Additional Information

@avenkateshha avenkateshha requested review from a team as code owners May 16, 2026 02:05
@copy-pr-bot

copy-pr-bot Bot commented May 16, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@github-actions github-actions Bot added Documentation Improvements or additions to documentation community-request labels May 16, 2026
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-maintainers Waiting on maintainers to respond label May 18, 2026
Comment thread nemo_rl/algorithms/loss/loss_functions.py
Comment thread nemo_rl/algorithms/loss/loss_functions.py Outdated
Comment thread nemo_rl/algorithms/loss/loss_functions.py Outdated
Comment thread nemo_rl/algorithms/loss/loss_functions.py Outdated
Comment thread nemo_rl/algorithms/loss/loss_functions.py Outdated
Comment thread nemo_rl/algorithms/loss/loss_functions.py Outdated
@svcnvidia-nemo-ci svcnvidia-nemo-ci added waiting-on-customer Waiting on the original author to respond and removed waiting-on-maintainers Waiting on maintainers to respond labels May 20, 2026
Comment thread nemo_rl/algorithms/loss/loss_functions.py Outdated
Comment thread nemo_rl/algorithms/loss/loss_functions.py Outdated
Comment thread nemo_rl/algorithms/loss/loss_functions.py Outdated
Comment thread nemo_rl/algorithms/loss/loss_functions.py Outdated
Comment thread nemo_rl/algorithms/loss/loss_functions.py Outdated
Comment thread nemo_rl/algorithms/loss/loss_functions.py Outdated
Comment thread nemo_rl/algorithms/loss/loss_functions.py Outdated
Comment thread nemo_rl/algorithms/loss/loss_functions.py Outdated
Comment thread nemo_rl/data/cross_tokenizer_collate.py Outdated
Comment thread nemo_rl/algorithms/loss/loss_functions.py Outdated
@RayenTian

Copy link
Copy Markdown
Contributor

Hi, @avenkateshha! Thanks for putting this together, and thanks @yuki-97 for the review help. I have only partly reviewed the PR so far, but I left a few comments. I’ll continue reviewing the remaining files separately.

avenkateshha added a commit to avenkateshha/RL that referenced this pull request May 20, 2026
Comments addressed: #3, #5, NVIDIA-NeMo#7, NVIDIA-NeMo#8, NVIDIA-NeMo#9, NVIDIA-NeMo#10, NVIDIA-NeMo#11.

- Rename _load_M -> _get_sparse_projection_matrix and
  _load_dense_projection -> _get_topk_projection (later removed in
  favor of module-level cache helpers below).
- Drop unused alignment_student_spans / alignment_teacher_spans
  from the cross-tokenizer batch payload.
- Remove NRL_XTOKEN_LOSS_DUMP_DIR debug-dump side effect.
- Move Fp32SparseMM, chunk_average_log_probs, valid_chunk_mask to a
  new shared module nemo_rl/algorithms/x_token/utils.py.
- Extract projection-file parsing into utils.parse_projection_file;
  tokenalign.py and loss_functions.py both go through it.
- Move per-instance projection-matrix caches to process-local caches
  in utils.get_sparse_projection_matrix / get_topk_projection. The
  driver no longer holds large CUDA tensors; each Ray worker fills
  its own cache on first loss call.

Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
The lint check's whitelist guard requires every file with zero pyrefly
errors to appear in project-includes. This PR adds new x_token, arrow_text,
and tools/x_token files that are zero-error but were not yet whitelisted,
so the guard failed (it reports only the first offender because the step
runs under bash -e). Add all six to the whitelist.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
@avenkateshha

Copy link
Copy Markdown
Contributor Author

/ok to test 96378e4

get_topk_projection's cache-miss path returned a freshly-built
(indices, likelihoods) tuple instead of the one stored in
_TOPK_PROJECTION_CACHE, so the first (miss) call and subsequent (hit)
calls handed back different tuple objects. This broke
test_repeat_call_hits_cache's identity assertion (a is b). Store and
return the same tuple.

Also update test_dtensor_hsdp_dispatches_distinct_batches: Policy.train
now threads check_dim_skip_keys through common_kwargs (added in this PR
for cross-tokenizer skip-keys plumbing), so the expected
run_all_workers_sharded_data call must include 'check_dim_skip_keys': None.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <adithyakvh@gmail.com>
@avenkateshha

Copy link
Copy Markdown
Contributor Author

/ok to test b33b53f

@terrykong terrykong left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review: PR #2508 — feat(xtoken): cross-tokenizer off-policy distillation

Reviewed by a coordinated team of 5 agents (rl-expert, test-agent, bug-finder, comment-reviewer, devil's-advocate).

1 blocker (failing test), 9 suggestions (doc fixes, test gaps, defensive hardening, convention alignment), 3 informational.

Prior review threads

29 of 30 existing review threads from @RayenTian are verified resolved in HEAD — confirming replies with permalinks posted below. The 1 genuinely open thread (alignment inside collator, ID:3275131900) is an architectural question for @yuki-97, not a correctness bug.

Devil's advocate summary

9 confirmed | 0 disputed | 2 downgraded | 2 filtered (below confidence threshold)
PR necessity: confirmed — substantial new feature.

Generated by Claude Code

Comment thread tests/unit/tools/x_token/test_projection_tools.py Outdated
Comment thread nemo_rl/algorithms/loss/loss_functions.py Outdated
Comment thread nemo_rl/algorithms/loss/loss_functions.py
Comment thread examples/run_xtoken_off_policy_distillation.py
Comment thread nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
Comment thread nemo_rl/algorithms/x_token/loss_utils.py
Comment thread nemo_rl/algorithms/xtoken_off_policy_distillation.py
Comment thread tools/x_token/build_projection_matrix.sh Outdated
Comment thread nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
Comment thread nemo_rl/data/cross_tokenizer_collate.py
clean_model_name_for_filename strips parameter counts (digits + B/M) and
known suffixes; it correctly keeps model-family version numbers, so
'microsoft/phi-4-Base' -> 'microsoft/phi-4'. The assertion expected
'microsoft/phi', which also contradicts the same test's first case that
keeps the 'Llama-3.2' version. Correct the expected value.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <adithyakvh@gmail.com>
@avenkateshha

Copy link
Copy Markdown
Contributor Author

/ok to test 113bb2f

avenkateshha and others added 4 commits June 8, 2026 11:19
…fresh prep docs

Address review comments on PR NVIDIA-NeMo#2508:
- cross_tokenizer_collate: pin tokenizer padding_side="right" so
  input_lengths / [:length] slicing and token-chunk alignment stay
  correct for tokenizers that default to left-padding.
- xtoken setup: drop hidden .get() defaults in the backend-gate asserts
  (_v2 -> .get("_v2"); tensor/context_parallel_size -> direct reads) per
  config-conventions.
- get_full_logits_ipc: synchronize after copy_ into the persistent IPC
  buffer so the consumer can't read a partially written buffer.
- xtoken train loop: refresh the stale release_ipc_buffer comment (it is
  a no-op under the persistent-buffer design).
- build_projection_matrix.sh: renumber prep steps to match the guide
  (seed + Steps 1-3 instead of 1/2/3/4).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
…cstring

The class docstring claimed the gold-loss and xtoken-loss modes raise
NotImplementedError and that gold-loss uses CE on paired tokens. Both are
implemented, and the gold path is (kl_common + l1_uncommon) * T**2 with no CE
term (CE is computed only in the P-KL path). Rewrite the mode table and the
per-path Returns block to match the implementation, and note that
(gold_loss=False, xtoken_loss=True) is rejected in __init__.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
…onLossFn

Address review comment on PR NVIDIA-NeMo#2508: the CT loss fn was previously exercised
only via mocks. Add CPU synthetic-tensor tests in test_loss_functions.py:
- gold path: assert loss == (kl_common + l1_uncommon) * T**2 and independently
  recompute kl_common from the public helpers, pinning the next-token shift,
  chunk averaging, common-index slice, forward-KL direction, sample_mask
  gating, and valid-chunk normalization.
- gold path: all-masked sample_mask -> zero loss and zero valid chunks.
- _compute_ce: uniform logits -> log(V_student); a masked sample contributes 0.

Also drop two dangling references to local-only test files (test_gpu_smoke.py,
test_alignment_snapshot.py) from the test_loss_utils.py and
test_cross_tokenizer_collate.py module docstrings.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Address review comment on PR NVIDIA-NeMo#2508: build_exact_token_map's collision
handling had no value assertions (only cache-identity tests). Add two
hand-constructed cases to TestBuildExactTokenMap:
- relaxed mode (xtoken_loss=True): on a teacher collision the highest
  first-projection weight wins, ties break to the lowest student index, and
  the >= 0.6 threshold (including the 0.6 boundary) is exercised.
- strict mode (xtoken_loss=False): weight-1.0 + -1-sentinel exact maps,
  collision picks the lowest student index, fuzzy rows excluded.
Both assert the full common/uncommon student and teacher partitions.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
@avenkateshha

Copy link
Copy Markdown
Contributor Author

/ok to test a947e52

avenkateshha and others added 5 commits June 8, 2026 19:18
The wrapper unconditionally ran minimal_projection_generator as a seed
pass and fed its output into the multitoken step, so the one-command
Quickstart produced a different matrix than the documented 3-step recipe.
Start the chain at Step 1 (minimal_projection_via_multitoken) building
from scratch. minimal_projection_generator.py is left in place for manual
use; only the wrapper stops invoking it.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Document x-token (cross-tokenizer) off-policy distillation on the landing
page: a new section with single-node and multi-node instructions mirroring
the GRPO / On-policy Distillation layout, links to the
Nemotron-Pretraining-Specialized-v1.1 dataset and the implementation guide,
a Support Matrix row, and a Features entry.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Add a Results section to the cross-tokenizer guide with convergence curves
(train/kl_loss, train/ce_loss, train/accuracy) from the 100-step P-KL smoke
run (Llama-3.2-1B <- Qwen3-4B), plus throughput (~47k valid tok/s), peak GPU
memory (29.5 GB/GPU), and the ~0.6 GB/sample-step teacher-logit IPC tray
size. Addresses the review request for quantitative results.

Rename the gold_loss=true, xtoken_loss=true mode label from "Gold + x-token
loss" to "H-KL (gold + xtoken)".

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Remove the "smoke run" label from the results heading/body and the
downstream-eval note, per review preference to present the 100-step P-KL
results without the throwaway-run framing.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 9, 2026

Copy link
Copy Markdown

/ok to test 6a0e03a

@RayenTian, there was an error processing your request: E2

See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/

@RayenTian

Copy link
Copy Markdown
Contributor

/ok to test e7b5538

Replace the results curves/numbers with a 100-step P-KL run on the default
Nemotron-Pretraining-Specialized-v1.1 (Formal-Logic) corpus, and add the
train/loss panel (now a 4-panel figure): loss 1.51->0.78, kl 2.67->0.78,
ce 0.75->0.39, accuracy 0.82->0.88; ~48k tok/s, peak 29.5 GB/GPU.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
@RayenTian

Copy link
Copy Markdown
Contributor

/ok to test 8a2af80

RayenTian and others added 2 commits June 8, 2026 22:59
setup() asserts policy/teacher dtensor_cfg tensor_parallel_size==1 and
context_parallel_size==1 via direct indexing, but the test fixture's
dtensor_cfg omitted those keys, causing a KeyError in
test_setup_injects_vocab_sizes_into_loss_config. Add the keys to match
real configs.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: ruit <ruit@nvidia.com>
@RayenTian

Copy link
Copy Markdown
Contributor

/ok to test 5c2cdd6

RayenTian
RayenTian previously approved these changes Jun 9, 2026
terrykong
terrykong previously approved these changes Jun 9, 2026

@yuki-97 yuki-97 left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: return "text" and "messages" are duplicated in nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py, let's remove it in a subsequent PR to not block this PR and merge as it.

@yuki-97

yuki-97 commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

5c2cdd6 already passed in https://github.com/NVIDIA-NeMo/RL/actions/runs/27187146550, so merge without rerun it again.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CI:Lfast Runs a fast test suite and re-use nightly `main` container (but sync dependencies to PRs version) Documentation Improvements or additions to documentation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants