Xtoken/off policy distillation gh#5
Closed
avenkateshha wants to merge 22 commits into
Closed
Conversation
added 22 commits
April 9, 2026 17:46
- Off-policy distillation pipeline (teacher Llama-3.1-8B, student Llama-3.2-1B) with arrow dataset support and inline MATH/MMLU generation-based evaluation - Compute distributed log_softmax before top-k for correct KL divergence - Add CUDA IPC buffer mechanism to avoid Ray object store bottleneck for large top-k logprob tensors (based on dtensor_sharath.py approach) - Update loss function to skip re-normalization of teacher log probabilities - Add submit scripts, configs, and eval benchmarks Made-with: Cursor
Made-with: Cursor
- Flatten teacher IPC data structure and use mb_idx*mbs indexing instead of cumulative mb_offset/mb_size for microbatch slicing - Add log_softmax for teacher top-k logits in standard (non-IPC) path - Restore output_replicated for teacher path and add kl_loss/nll_loss to aggregated results - Make arrow config self-contained with explicit settings Made-with: Cursor
Made-with: Cursor
Refactor teacher logit sharing to use per-microbatch IPC buffers instead of accumulating all teacher logits post-loop. Update loss functions to handle optional microbatch indexing. Bump config to TP=4 and 10k steps. Made-with: Cursor
Made-with: Cursor
- Add `use_ipc` config flag to switch between IPC (in-process communication) and non-IPC (data-dict) teacher logprob paths - Simplify KL loss to use (k+1)-dim distributions with a "rest" bucket, unifying IPC and non-IPC code paths - Branch teacher inference and student training for both train and validation loops based on use_ipc setting - Update submit script with IPC test experiment config Made-with: Cursor
- Add x_token/ module with TokenAligner for cross-vocabulary distillation - Rewrite CrossTokenizerDistillationLossFn with chunk-averaged KL that handles 1:1, 1:many, many:1, and many:many token alignments, matching the original TokenAligner.compute_KL_loss_optimized() exactly (verified via sanity check: 0.00% difference) - Fix teacher IPC to send full logits (topk_logits=None) instead of topk_logits=0 which produced empty tensors - Pass global_top_indices through to projection for memory optimization (2.3GB -> 125MB for projection output tensor) - Add cross-tokenizer data processing in training loop (dual tokenization, alignment, teacher data dict) - Add unbuffered stdout for better SLURM log visibility - Add example config and sanity check script Made-with: Cursor
Documents architecture, all new/modified files, usage instructions, configuration reference, and design decisions for the cross-tokenizer off-policy distillation feature built on NeMo RL v0.5.0. Made-with: Cursor
…compat - Switch default teacher from Qwen3-8B to Phi-4-mini-instruct - Add gold loss (common-vocab KL + uncommon-vocab L1) and xtoken loss modes - Add CE loss with dynamic loss scaling option - Replace dense projection with CSR sparse matmul for memory efficiency - Add MMLU 5-shot evaluation benchmark - Fix NotRequired import for Python <3.11 compatibility (16 files) - Add submit_cross_tokenizer.sh sbatch script - Add sanity check script for alignment and loss verification - Update LR schedule for 80k step training (warmup 4k, cosine 76k) Made-with: Cursor
Made-with: Cursor
…parse projection - Cache CrossTokenizerDistillationLossFn on policy workers at init and pass None to train() calls, eliminating repeated Ray serialization of the loss function (which includes large sparse matrices) each step. - Add set_loss_fn() and update_cross_tokenizer_data() to Policy and DTensorPolicyWorkerV2 to support per-step cross-tokenizer data updates. - Optimize sparse token projection by pre-reducing the sparse matrix with index_select before projection instead of projecting full vocab and slicing afterward. - Use AutoConfig.from_pretrained() for vocab sizes in sanity check script. Made-with: Cursor
…P rank Reduced training time with this optimization. Avoid Ray serialization of the loss function by having each worker construct CrossTokenizerDistillationLossFn from config + shared filesystem. Shard teacher_input_ids and aligned_pairs per data-parallel rank instead of broadcasting the full batch to every worker. Made-with: Cursor
- Add O(n+m) character-offset alignment via two-pointer walk on tokenizer offset mappings, with automatic DP fallback for failed samples - Precompute canonical token ID maps at startup to skip convert_ids_to_tokens - Add Numba JIT-accelerated DP kernel and banded DP variant - Add KD preprocessor preserving raw text for teacher tokenization - Add numba dependency - Update config: expand arrow data glob, set max_num_epochs=1 - Update submit script: bump max_num_steps=10, rename experiment to raw-text-kd-16node Made-with: Cursor
Introduce CUDA kernel and Python integration module for faster TokenAligner dynamic programming base-case computation. Made-with: Cursor
… preprocessing with current-step GPU training while keeping alignment behavior unchanged. Add explicit token-aligner runtime switches (`use_char_offset`, `use_align_fast`, CUDA-DP toggles), clean up dead/duplicated paths, and simplify the step orchestration with typed prefetch payloads and helper extraction for maintainability. Made-with: Cursor
Set explicit token aligner defaults and document the total-GPUs/2 heuristic for cross_tokenizer_num_workers so large-batch off-policy runs can iterate on stable, reproducible CT pool sizing. Made-with: Cursor
Align rebased cross-tokenizer distillation code with main-branch APIs and config expectations. - update stale loss/interface import paths after rebase - migrate DTensorPolicyWorkerV2 init/checkpoint setup to main-style flow - add required sequence_packing config keys instead of setup-side fallbacks - add scoped Phi RoPE buffer repair for post-load meta/non-finite buffer cases Made-with: Cursor
Align rebased cross-tokenizer distillation code with main-branch APIs and config expectations. - update stale loss/interface import paths after rebase - migrate DTensorPolicyWorkerV2 init/checkpoint setup to main-style flow - add required sequence_packing config keys instead of setup-side fallbacks - add scoped Phi RoPE buffer repair for post-load meta/non-finite buffer cases - include cross-tokenizer launch shell scripts used for runs Made-with: Cursor
…token projection tools Refactor off-policy cross-tokenizer scripts to use the new data.train/default schema, remove hard-coded dataset paths, and support user-provided Arrow or dataset_path/HF inputs at submit time. Rename callsites to off_policy_distillation_train for consistency and move xtoken projection matrix generation utilities under the utils folder. Made-with: Cursor
…n training path. Update DTensor worker and lm/off-policy orchestration for IPC-based teacher-student flow, add shared IPC tensor-handle helpers in ipc_utils, and fix xtoken post-processor microbatch buffer handling/config integration to avoid runtime shape and key errors during policy training. Made-with: Cursor
…before KL Prevent loss-magnitude explosions in off-policy distillation when non-IPC teacher top-k paths return raw logits instead of log-probs. Add normalization in off-policy train/validation data preparation so DistillationLoss receives stable top-k log-probs, while keeping shared/common loss code unchanged. Align llama_off_policy_arrow.yaml with the structured data.train/data.default schema, remove legacy hard-coded dataset fields, and bake in common run settings (teacher instruct variant, checkpoint optimizer save, and policy batching/packing defaults) to reduce CLI override churn. Made-with: Cursor
|
This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days. |
|
This PR was closed because it has been inactive for 7 days since being marked as stale. |
avenkateshha
added a commit
that referenced
this pull request
May 20, 2026
Comments addressed: #3, #5, NVIDIA-NeMo#7, NVIDIA-NeMo#8, NVIDIA-NeMo#9, NVIDIA-NeMo#10, NVIDIA-NeMo#11. - Rename _load_M -> _get_sparse_projection_matrix and _load_dense_projection -> _get_topk_projection (later removed in favor of module-level cache helpers below). - Drop unused alignment_student_spans / alignment_teacher_spans from the cross-tokenizer batch payload. - Remove NRL_XTOKEN_LOSS_DUMP_DIR debug-dump side effect. - Move Fp32SparseMM, chunk_average_log_probs, valid_chunk_mask to a new shared module nemo_rl/algorithms/x_token/utils.py. - Extract projection-file parsing into utils.parse_projection_file; tokenalign.py and loss_functions.py both go through it. - Move per-instance projection-matrix caches to process-local caches in utils.get_sparse_projection_matrix / get_topk_projection. The driver no longer holds large CUDA tensors; each Ray worker fills its own cache on first loss call. Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
avenkateshha
added a commit
that referenced
this pull request
May 23, 2026
…t API Address easy-batch items from PR NVIDIA-NeMo#2508 review. - Move sinkhorn_one_dim / apply_canonicalization_if_enabled / clean_model_name_for_filename / project_token_likelihoods to nemo_rl/utils/x_token/_shared.py and import from there in the four projection-prep CLIs (#5). - Drop unused sinkhorn (10-iter), debug_projection_map, and generate_projection_map from minimal_projection_generator.py (NVIDIA-NeMo#8). - Move minimal_projection_generator.py's CLI parsing under `if __name__ == "__main__":` so the module is importable for the P1 dedup harness. - Rename seq1 / seq2 (and the derivative s1_*/s2_*/used_seq*/joined_seq*/ seg* names) to student_tokens / teacher_tokens throughout TokenAligner._align_single / _align_with_anchors / _align_dp; all call sites are internal (NVIDIA-NeMo#11). - Replace **kwargs in TokenAligner._align_with_anchors with the five explicit keyword-only scoring knobs already declared on _align_dp, and pass them through named at the _align_single call site (NVIDIA-NeMo#12). - Remove dead TokenAligner.load_projection_matrix + the private _load_projection_components plus the now-orphaned self._projection_* attributes; the live projection-load path is nemo_rl/algorithms/x_token/utils.py::{get_sparse_projection_matrix, get_topk_projection} added in 6336464 (NVIDIA-NeMo#13, NVIDIA-NeMo#14). - git mv nemo_rl/algorithms/x_token/tokenalign.py -> token_aligner.py and update import sites in __init__.py, data/cross_tokenizer_collate.py, the four utils.x_token CLIs, and the xtoken-distillation guide; update Sphinx :mod: references in algorithms/x_token/utils.py (NVIDIA-NeMo#16). Incidental: the CLIs previously called TokenAligner._canonical_token, which was never a class attribute (the helper is module-level); the new shared helper imports _canonical_token directly so the use_canonicalization branch isn't broken at runtime. Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
avenkateshha
added a commit
that referenced
this pull request
May 27, 2026
Comments addressed: #3, #5, NVIDIA-NeMo#7, NVIDIA-NeMo#8, NVIDIA-NeMo#9, NVIDIA-NeMo#10, NVIDIA-NeMo#11. - Rename _load_M -> _get_sparse_projection_matrix and _load_dense_projection -> _get_topk_projection (later removed in favor of module-level cache helpers below). - Drop unused alignment_student_spans / alignment_teacher_spans from the cross-tokenizer batch payload. - Remove NRL_XTOKEN_LOSS_DUMP_DIR debug-dump side effect. - Move Fp32SparseMM, chunk_average_log_probs, valid_chunk_mask to a new shared module nemo_rl/algorithms/x_token/utils.py. - Extract projection-file parsing into utils.parse_projection_file; tokenalign.py and loss_functions.py both go through it. - Move per-instance projection-matrix caches to process-local caches in utils.get_sparse_projection_matrix / get_topk_projection. The driver no longer holds large CUDA tensors; each Ray worker fills its own cache on first loss call. Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
avenkateshha
added a commit
that referenced
this pull request
May 27, 2026
…t API Address easy-batch items from PR NVIDIA-NeMo#2508 review. - Move sinkhorn_one_dim / apply_canonicalization_if_enabled / clean_model_name_for_filename / project_token_likelihoods to nemo_rl/utils/x_token/_shared.py and import from there in the four projection-prep CLIs (#5). - Drop unused sinkhorn (10-iter), debug_projection_map, and generate_projection_map from minimal_projection_generator.py (NVIDIA-NeMo#8). - Move minimal_projection_generator.py's CLI parsing under `if __name__ == "__main__":` so the module is importable for the P1 dedup harness. - Rename seq1 / seq2 (and the derivative s1_*/s2_*/used_seq*/joined_seq*/ seg* names) to student_tokens / teacher_tokens throughout TokenAligner._align_single / _align_with_anchors / _align_dp; all call sites are internal (NVIDIA-NeMo#11). - Replace **kwargs in TokenAligner._align_with_anchors with the five explicit keyword-only scoring knobs already declared on _align_dp, and pass them through named at the _align_single call site (NVIDIA-NeMo#12). - Remove dead TokenAligner.load_projection_matrix + the private _load_projection_components plus the now-orphaned self._projection_* attributes; the live projection-load path is nemo_rl/algorithms/x_token/utils.py::{get_sparse_projection_matrix, get_topk_projection} added in 6336464 (NVIDIA-NeMo#13, NVIDIA-NeMo#14). - git mv nemo_rl/algorithms/x_token/tokenalign.py -> token_aligner.py and update import sites in __init__.py, data/cross_tokenizer_collate.py, the four utils.x_token CLIs, and the xtoken-distillation guide; update Sphinx :mod: references in algorithms/x_token/utils.py (NVIDIA-NeMo#16). Incidental: the CLIs previously called TokenAligner._canonical_token, which was never a class attribute (the helper is module-level); the new shared helper imports _canonical_token directly so the use_canonicalization branch isn't broken at runtime. Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
avenkateshha
added a commit
that referenced
this pull request
Jun 4, 2026
Comments addressed: #3, #5, NVIDIA-NeMo#7, NVIDIA-NeMo#8, NVIDIA-NeMo#9, NVIDIA-NeMo#10, NVIDIA-NeMo#11. - Rename _load_M -> _get_sparse_projection_matrix and _load_dense_projection -> _get_topk_projection (later removed in favor of module-level cache helpers below). - Drop unused alignment_student_spans / alignment_teacher_spans from the cross-tokenizer batch payload. - Remove NRL_XTOKEN_LOSS_DUMP_DIR debug-dump side effect. - Move Fp32SparseMM, chunk_average_log_probs, valid_chunk_mask to a new shared module nemo_rl/algorithms/x_token/utils.py. - Extract projection-file parsing into utils.parse_projection_file; tokenalign.py and loss_functions.py both go through it. - Move per-instance projection-matrix caches to process-local caches in utils.get_sparse_projection_matrix / get_topk_projection. The driver no longer holds large CUDA tensors; each Ray worker fills its own cache on first loss call. Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
avenkateshha
added a commit
that referenced
this pull request
Jun 4, 2026
…t API Address easy-batch items from PR NVIDIA-NeMo#2508 review. - Move sinkhorn_one_dim / apply_canonicalization_if_enabled / clean_model_name_for_filename / project_token_likelihoods to nemo_rl/utils/x_token/_shared.py and import from there in the four projection-prep CLIs (#5). - Drop unused sinkhorn (10-iter), debug_projection_map, and generate_projection_map from minimal_projection_generator.py (NVIDIA-NeMo#8). - Move minimal_projection_generator.py's CLI parsing under `if __name__ == "__main__":` so the module is importable for the P1 dedup harness. - Rename seq1 / seq2 (and the derivative s1_*/s2_*/used_seq*/joined_seq*/ seg* names) to student_tokens / teacher_tokens throughout TokenAligner._align_single / _align_with_anchors / _align_dp; all call sites are internal (NVIDIA-NeMo#11). - Replace **kwargs in TokenAligner._align_with_anchors with the five explicit keyword-only scoring knobs already declared on _align_dp, and pass them through named at the _align_single call site (NVIDIA-NeMo#12). - Remove dead TokenAligner.load_projection_matrix + the private _load_projection_components plus the now-orphaned self._projection_* attributes; the live projection-load path is nemo_rl/algorithms/x_token/utils.py::{get_sparse_projection_matrix, get_topk_projection} added in 6336464 (NVIDIA-NeMo#13, NVIDIA-NeMo#14). - git mv nemo_rl/algorithms/x_token/tokenalign.py -> token_aligner.py and update import sites in __init__.py, data/cross_tokenizer_collate.py, the four utils.x_token CLIs, and the xtoken-distillation guide; update Sphinx :mod: references in algorithms/x_token/utils.py (NVIDIA-NeMo#16). Incidental: the CLIs previously called TokenAligner._canonical_token, which was never a class attribute (the helper is module-level); the new shared helper imports _canonical_token directly so the use_canonicalization branch isn't broken at runtime. Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
avenkateshha
added a commit
that referenced
this pull request
Jun 7, 2026
Comments addressed: #3, #5, NVIDIA-NeMo#7, NVIDIA-NeMo#8, NVIDIA-NeMo#9, NVIDIA-NeMo#10, NVIDIA-NeMo#11. - Rename _load_M -> _get_sparse_projection_matrix and _load_dense_projection -> _get_topk_projection (later removed in favor of module-level cache helpers below). - Drop unused alignment_student_spans / alignment_teacher_spans from the cross-tokenizer batch payload. - Remove NRL_XTOKEN_LOSS_DUMP_DIR debug-dump side effect. - Move Fp32SparseMM, chunk_average_log_probs, valid_chunk_mask to a new shared module nemo_rl/algorithms/x_token/utils.py. - Extract projection-file parsing into utils.parse_projection_file; tokenalign.py and loss_functions.py both go through it. - Move per-instance projection-matrix caches to process-local caches in utils.get_sparse_projection_matrix / get_topk_projection. The driver no longer holds large CUDA tensors; each Ray worker fills its own cache on first loss call. Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
avenkateshha
added a commit
that referenced
this pull request
Jun 7, 2026
…t API Address easy-batch items from PR NVIDIA-NeMo#2508 review. - Move sinkhorn_one_dim / apply_canonicalization_if_enabled / clean_model_name_for_filename / project_token_likelihoods to nemo_rl/utils/x_token/_shared.py and import from there in the four projection-prep CLIs (#5). - Drop unused sinkhorn (10-iter), debug_projection_map, and generate_projection_map from minimal_projection_generator.py (NVIDIA-NeMo#8). - Move minimal_projection_generator.py's CLI parsing under `if __name__ == "__main__":` so the module is importable for the P1 dedup harness. - Rename seq1 / seq2 (and the derivative s1_*/s2_*/used_seq*/joined_seq*/ seg* names) to student_tokens / teacher_tokens throughout TokenAligner._align_single / _align_with_anchors / _align_dp; all call sites are internal (NVIDIA-NeMo#11). - Replace **kwargs in TokenAligner._align_with_anchors with the five explicit keyword-only scoring knobs already declared on _align_dp, and pass them through named at the _align_single call site (NVIDIA-NeMo#12). - Remove dead TokenAligner.load_projection_matrix + the private _load_projection_components plus the now-orphaned self._projection_* attributes; the live projection-load path is nemo_rl/algorithms/x_token/utils.py::{get_sparse_projection_matrix, get_topk_projection} added in 6336464 (NVIDIA-NeMo#13, NVIDIA-NeMo#14). - git mv nemo_rl/algorithms/x_token/tokenalign.py -> token_aligner.py and update import sites in __init__.py, data/cross_tokenizer_collate.py, the four utils.x_token CLIs, and the xtoken-distillation guide; update Sphinx :mod: references in algorithms/x_token/utils.py (NVIDIA-NeMo#16). Incidental: the CLIs previously called TokenAligner._canonical_token, which was never a class attribute (the helper is module-level); the new shared helper imports _canonical_token directly so the use_canonicalization branch isn't broken at runtime. Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do ?
Add a one line overview of what this PR aims to accomplish.
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use thisBefore your PR is "Ready for review"
Pre checks:
Additional Information