feat(xtoken): cross-tokenizer off-policy distillation#2508
Conversation
|
Hi, @avenkateshha! Thanks for putting this together, and thanks @yuki-97 for the review help. I have only partly reviewed the PR so far, but I left a few comments. I’ll continue reviewing the remaining files separately. |
Comments addressed: #3, #5, NVIDIA-NeMo#7, NVIDIA-NeMo#8, NVIDIA-NeMo#9, NVIDIA-NeMo#10, NVIDIA-NeMo#11. - Rename _load_M -> _get_sparse_projection_matrix and _load_dense_projection -> _get_topk_projection (later removed in favor of module-level cache helpers below). - Drop unused alignment_student_spans / alignment_teacher_spans from the cross-tokenizer batch payload. - Remove NRL_XTOKEN_LOSS_DUMP_DIR debug-dump side effect. - Move Fp32SparseMM, chunk_average_log_probs, valid_chunk_mask to a new shared module nemo_rl/algorithms/x_token/utils.py. - Extract projection-file parsing into utils.parse_projection_file; tokenalign.py and loss_functions.py both go through it. - Move per-instance projection-matrix caches to process-local caches in utils.get_sparse_projection_matrix / get_topk_projection. The driver no longer holds large CUDA tensors; each Ray worker fills its own cache on first loss call. Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
The lint check's whitelist guard requires every file with zero pyrefly errors to appear in project-includes. This PR adds new x_token, arrow_text, and tools/x_token files that are zero-error but were not yet whitelisted, so the guard failed (it reports only the first offender because the step runs under bash -e). Add all six to the whitelist. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
|
/ok to test 96378e4 |
get_topk_projection's cache-miss path returned a freshly-built (indices, likelihoods) tuple instead of the one stored in _TOPK_PROJECTION_CACHE, so the first (miss) call and subsequent (hit) calls handed back different tuple objects. This broke test_repeat_call_hits_cache's identity assertion (a is b). Store and return the same tuple. Also update test_dtensor_hsdp_dispatches_distinct_batches: Policy.train now threads check_dim_skip_keys through common_kwargs (added in this PR for cross-tokenizer skip-keys plumbing), so the expected run_all_workers_sharded_data call must include 'check_dim_skip_keys': None. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <adithyakvh@gmail.com>
|
/ok to test b33b53f |
terrykong
left a comment
There was a problem hiding this comment.
Review: PR #2508 — feat(xtoken): cross-tokenizer off-policy distillation
Reviewed by a coordinated team of 5 agents (rl-expert, test-agent, bug-finder, comment-reviewer, devil's-advocate).
1 blocker (failing test), 9 suggestions (doc fixes, test gaps, defensive hardening, convention alignment), 3 informational.
Prior review threads
29 of 30 existing review threads from @RayenTian are verified resolved in HEAD — confirming replies with permalinks posted below. The 1 genuinely open thread (alignment inside collator, ID:3275131900) is an architectural question for @yuki-97, not a correctness bug.
Devil's advocate summary
9 confirmed | 0 disputed | 2 downgraded | 2 filtered (below confidence threshold)
PR necessity: confirmed — substantial new feature.
Generated by Claude Code
clean_model_name_for_filename strips parameter counts (digits + B/M) and known suffixes; it correctly keeps model-family version numbers, so 'microsoft/phi-4-Base' -> 'microsoft/phi-4'. The assertion expected 'microsoft/phi', which also contradicts the same test's first case that keeps the 'Llama-3.2' version. Correct the expected value. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <adithyakvh@gmail.com>
|
/ok to test 113bb2f |
…fresh prep docs Address review comments on PR NVIDIA-NeMo#2508: - cross_tokenizer_collate: pin tokenizer padding_side="right" so input_lengths / [:length] slicing and token-chunk alignment stay correct for tokenizers that default to left-padding. - xtoken setup: drop hidden .get() defaults in the backend-gate asserts (_v2 -> .get("_v2"); tensor/context_parallel_size -> direct reads) per config-conventions. - get_full_logits_ipc: synchronize after copy_ into the persistent IPC buffer so the consumer can't read a partially written buffer. - xtoken train loop: refresh the stale release_ipc_buffer comment (it is a no-op under the persistent-buffer design). - build_projection_matrix.sh: renumber prep steps to match the guide (seed + Steps 1-3 instead of 1/2/3/4). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
…cstring The class docstring claimed the gold-loss and xtoken-loss modes raise NotImplementedError and that gold-loss uses CE on paired tokens. Both are implemented, and the gold path is (kl_common + l1_uncommon) * T**2 with no CE term (CE is computed only in the P-KL path). Rewrite the mode table and the per-path Returns block to match the implementation, and note that (gold_loss=False, xtoken_loss=True) is rejected in __init__. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
…onLossFn Address review comment on PR NVIDIA-NeMo#2508: the CT loss fn was previously exercised only via mocks. Add CPU synthetic-tensor tests in test_loss_functions.py: - gold path: assert loss == (kl_common + l1_uncommon) * T**2 and independently recompute kl_common from the public helpers, pinning the next-token shift, chunk averaging, common-index slice, forward-KL direction, sample_mask gating, and valid-chunk normalization. - gold path: all-masked sample_mask -> zero loss and zero valid chunks. - _compute_ce: uniform logits -> log(V_student); a masked sample contributes 0. Also drop two dangling references to local-only test files (test_gpu_smoke.py, test_alignment_snapshot.py) from the test_loss_utils.py and test_cross_tokenizer_collate.py module docstrings. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Address review comment on PR NVIDIA-NeMo#2508: build_exact_token_map's collision handling had no value assertions (only cache-identity tests). Add two hand-constructed cases to TestBuildExactTokenMap: - relaxed mode (xtoken_loss=True): on a teacher collision the highest first-projection weight wins, ties break to the lowest student index, and the >= 0.6 threshold (including the 0.6 boundary) is exercised. - strict mode (xtoken_loss=False): weight-1.0 + -1-sentinel exact maps, collision picks the lowest student index, fuzzy rows excluded. Both assert the full common/uncommon student and teacher partitions. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
|
/ok to test a947e52 |
The wrapper unconditionally ran minimal_projection_generator as a seed pass and fed its output into the multitoken step, so the one-command Quickstart produced a different matrix than the documented 3-step recipe. Start the chain at Step 1 (minimal_projection_via_multitoken) building from scratch. minimal_projection_generator.py is left in place for manual use; only the wrapper stops invoking it. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Document x-token (cross-tokenizer) off-policy distillation on the landing page: a new section with single-node and multi-node instructions mirroring the GRPO / On-policy Distillation layout, links to the Nemotron-Pretraining-Specialized-v1.1 dataset and the implementation guide, a Support Matrix row, and a Features entry. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Add a Results section to the cross-tokenizer guide with convergence curves (train/kl_loss, train/ce_loss, train/accuracy) from the 100-step P-KL smoke run (Llama-3.2-1B <- Qwen3-4B), plus throughput (~47k valid tok/s), peak GPU memory (29.5 GB/GPU), and the ~0.6 GB/sample-step teacher-logit IPC tray size. Addresses the review request for quantitative results. Rename the gold_loss=true, xtoken_loss=true mode label from "Gold + x-token loss" to "H-KL (gold + xtoken)". Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
Remove the "smoke run" label from the results heading/body and the downstream-eval note, per review preference to present the 100-step P-KL results without the throwaway-run framing. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
@RayenTian, there was an error processing your request: See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/ |
|
/ok to test e7b5538 |
Replace the results curves/numbers with a 100-step P-KL run on the default Nemotron-Pretraining-Specialized-v1.1 (Formal-Logic) corpus, and add the train/loss panel (now a 4-panel figure): loss 1.51->0.78, kl 2.67->0.78, ce 0.75->0.39, accuracy 0.82->0.88; ~48k tok/s, peak 29.5 GB/GPU. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
|
/ok to test 8a2af80 |
setup() asserts policy/teacher dtensor_cfg tensor_parallel_size==1 and context_parallel_size==1 via direct indexing, but the test fixture's dtensor_cfg omitted those keys, causing a KeyError in test_setup_injects_vocab_sizes_into_loss_config. Add the keys to match real configs. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: ruit <ruit@nvidia.com>
|
/ok to test 5c2cdd6 |
yuki-97
left a comment
There was a problem hiding this comment.
nit: return "text" and "messages" are duplicated in nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py, let's remove it in a subsequent PR to not block this PR and merge as it.
|
5c2cdd6 already passed in https://github.com/NVIDIA-NeMo/RL/actions/runs/27187146550, so merge without rerun it again. |
What does this PR do ?
Adds cross-tokenizer (x-token) off-policy distillation — distilling a teacher into a student that does not share its tokenizer (e.g. Qwen3-4B → Llama-3.2-1B) by projecting student logits into the teacher's vocab through a precomputed projection matrix.
Summary
TokenAligner(nemo_rl/algorithms/x_token/token_aligner.py) — DP alignment over canonicalized tokens; loads the projection matrix.CrossTokenizerDistillationLossFnwith three modes: P-KL (full-vocab teacher logits via CUDA IPC + microbatch-global top-k), gold-loss (exact-mapped common KL + L1 uncommon tail), H-KL (gold + xtoken) (gold-loss with relaxed>= 0.6exact-map + multi-token collision resolution).tools/x_token/(minimal_projection_generator,minimal_projection_via_multitoken,reapply_exact_map,sort_and_cut_projection_matrix), chained bytools/x_token/build_projection_matrix.sh.docs/guides/xtoken-off-policy-distillation.md(registered indocs/index.md).This PR supersedes #2347 (which was force-closed when its head branch was renamed). RayenTian's review feedback on #2347 is addressed in commit
755fb8e4:minimal_projection_via_multitoken.py: restored--output-filenameoverride; extended gemma vocab-size branch to also fire for qwen3.5; documented.ptsave-key schema.minimal_projection_generator.py: matching schema annotation neartorch.save.reapply_exact_map.py: validate loaded projection map is a dict withindices/likelihoods; raiseValueErrorwith the file path on mismatch.sort_and_cut_projection_matrix.py: factored argparse intoparse_arguments(); factored verbose stats intoprint_projection_statistics(); replaced positionalinput_pathwith--initial-projection-pathto match the other tools.docs/guides/xtoken-distillation.md: synced the Step 4 example with the new CLI.Issues
Pending: #2682
Usage
Results — 100-step P-KL run
Llama-3.2-1B ← Qwen3-4B, default config (gbs 96, mbs 1, seq 2048, 2 nodes) on the default Nemotron-Pretraining-Specialized-v1.1 (Formal-Logic) corpus, 100 steps:
Throughput/memory per step: mean step 4.07 s · ≈48k valid tok/s · peak 29.5 GB/GPU · teacher-logit IPC tray ≈0.6 GB/sample-step (
[T_t≈2048, V_t≈151,936]bf16, zero-copy same-node handle). Convergence curves: see the guide's "Results" section.Before your PR is "Ready for review"
Pre checks:
docs/guides/xtoken-off-policy-distillation.md)Test plan
llama_qwen_best_special_exact_map_remapped.ptbitwise (torch.equalon indices and likelihoods).Additional Information
755fb8e4above for the migrated review feedback).