
Xtoken/off policy distillation gh #5

Closed
avenkateshha wants to merge 22 commits into main from xtoken/off-policy-distillation-gh



@avenkateshha

Owner

What does this PR do?

Add a one line overview of what this PR aims to accomplish.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Adithyakrishna Hanasoge added 22 commits April 9, 2026 17:46
- Off-policy distillation pipeline (teacher Llama-3.1-8B, student Llama-3.2-1B)
  with arrow dataset support and inline MATH/MMLU generation-based evaluation
- Compute distributed log_softmax before top-k for correct KL divergence
- Add CUDA IPC buffer mechanism to avoid Ray object store bottleneck
  for large top-k logprob tensors (based on dtensor_sharath.py approach)
- Update loss function to skip re-normalization of teacher log probabilities
- Add submit scripts, configs, and eval benchmarks

Made-with: Cursor
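
The ordering in the second bullet matters because normalizing after truncation discards the probability mass of the dropped tokens. The following is a minimal plain-Python sketch of that difference, not the PR's distributed implementation; the helper names and the toy logits are illustrative assumptions.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a full list of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def top_k(values, k):
    """Return (indices, values) of the k largest entries, largest first."""
    idx = sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k]
    return idx, [values[i] for i in idx]

logits = [2.0, 1.0, 0.5, -1.0, -3.0]  # toy "vocabulary" of five tokens

# Normalize over the full vocabulary, then truncate to the top 2.
idx_a, logp_full_then_topk = top_k(log_softmax(logits), 2)

# Truncate first, then normalize over only the two survivors.
idx_b, raw_topk = top_k(logits, 2)
logp_topk_then_norm = log_softmax(raw_topk)

mass_a = sum(math.exp(v) for v in logp_full_then_topk)  # below 1: tail mass is still accounted for
mass_b = sum(math.exp(v) for v in logp_topk_then_norm)  # equals 1: tail mass was thrown away
```

Both orderings pick the same token ids, but only the first keeps log-probabilities that are consistent with the full distribution, which is what a KL term against the student needs.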
Made-with: Cursor
- Flatten teacher IPC data structure and use mb_idx*mbs indexing instead
  of cumulative mb_offset/mb_size for microbatch slicing
- Add log_softmax for teacher top-k logits in standard (non-IPC) path
- Restore output_replicated for teacher path and add kl_loss/nll_loss
  to aggregated results
- Make arrow config self-contained with explicit settings

Made-with: Cursor
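
The `mb_idx*mbs` indexing from the first bullet can be sketched as follows. This assumes every microbatch holds exactly `mbs` rows; the function name and the toy data are hypothetical and not taken from the PR.

```python
def microbatch_slice(flat_rows, mb_idx, mbs):
    """Rows of microbatch mb_idx when every microbatch holds exactly mbs rows."""
    start = mb_idx * mbs
    return flat_rows[start : start + mbs]

teacher_rows = ["r0", "r1", "r2", "r3", "r4", "r5"]  # teacher data flattened across microbatches
mbs = 2
slices = [microbatch_slice(teacher_rows, i, mbs) for i in range(3)]
```

Because the start offset is computed from the index alone, no running offset has to be carried between loop iterations.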
Made-with: Cursor
Refactor teacher logit sharing to use per-microbatch IPC buffers instead
of accumulating all teacher logits post-loop. Update loss functions to
handle optional microbatch indexing. Bump config to TP=4 and 10k steps.

Made-with: Cursor
- Add `use_ipc` config flag to switch between IPC (inter-process
  communication) and non-IPC (data-dict) teacher logprob paths
- Simplify KL loss to use (k+1)-dim distributions with a "rest"
  bucket, unifying IPC and non-IPC code paths
- Branch teacher inference and student training for both train
  and validation loops based on use_ipc setting
- Update submit script with IPC test experiment config

Made-with: Cursor
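
One way to read the (k+1)-dim "rest" bucket: the k top-k probabilities are extended with a single extra entry holding the remaining mass, so teacher and student are compared as proper distributions. The sketch below is a plain-Python illustration under that reading. The KL direction (teacher relative to student), the epsilon clamp, and all names are assumptions rather than the PR's code.

```python
import math

def add_rest_bucket(topk_logprobs, eps=1e-10):
    """Append log(1 - sum(p_topk)) so the k+1 entries sum to 1 in probability space."""
    rest = max(1.0 - sum(math.exp(lp) for lp in topk_logprobs), eps)
    return topk_logprobs + [math.log(rest)]

def kl_divergence(teacher_logp, student_logp):
    """KL(teacher || student) over matching buckets."""
    return sum(math.exp(t) * (t - s) for t, s in zip(teacher_logp, student_logp))

# Log-probs at the teacher's top-k token ids (k = 2), normalized over the full vocabulary.
teacher_topk = [math.log(0.6), math.log(0.3)]
student_at_same_ids = [math.log(0.5), math.log(0.2)]

teacher_k1 = add_rest_bucket(teacher_topk)         # probabilities 0.6, 0.3, 0.1
student_k1 = add_rest_bucket(student_at_same_ids)  # probabilities 0.5, 0.2, 0.3
kl = kl_divergence(teacher_k1, student_k1)
```

Since both paths only need top-k log-probs at shared token ids, the same function can serve the IPC and non-IPC cases.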
- Add x_token/ module with TokenAligner for cross-vocabulary distillation
- Rewrite CrossTokenizerDistillationLossFn with chunk-averaged KL that
  handles 1:1, 1:many, many:1, and many:many token alignments, matching
  the original TokenAligner.compute_KL_loss_optimized() exactly
  (verified via sanity check: 0.00% difference)
- Fix teacher IPC to send full logits (topk_logits=None) instead of
  topk_logits=0 which produced empty tensors
- Pass global_top_indices through to projection for memory optimization
  (2.3GB -> 125MB for projection output tensor)
- Add cross-tokenizer data processing in training loop (dual tokenization,
  alignment, teacher data dict)
- Add unbuffered stdout for better SLURM log visibility
- Add example config and sanity check script

Made-with: Cursor
Documents architecture, all new/modified files, usage instructions,
configuration reference, and design decisions for the cross-tokenizer
off-policy distillation feature built on NeMo RL v0.5.0.

Made-with: Cursor
…compat

- Switch default teacher from Qwen3-8B to Phi-4-mini-instruct
- Add gold loss (common-vocab KL + uncommon-vocab L1) and xtoken loss modes
- Add CE loss with dynamic loss scaling option
- Replace dense projection with CSR sparse matmul for memory efficiency
- Add MMLU 5-shot evaluation benchmark
- Fix NotRequired import for Python <3.11 compatibility (16 files)
- Add submit_cross_tokenizer.sh sbatch script
- Add sanity check script for alignment and loss verification
- Update LR schedule for 80k step training (warmup 4k, cosine 76k)

Made-with: Cursor
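
The 80k-step schedule from the last bullet (4k warmup, 76k cosine) could look like the sketch below. Linear warmup is an assumption, as are the `peak_lr` and `min_lr` placeholder values; none of these are stated in the commit message.

```python
import math

WARMUP_STEPS = 4_000   # from the commit message
COSINE_STEPS = 76_000  # from the commit message

def lr_at(step, peak_lr=1e-4, min_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay to min_lr.

    peak_lr and min_lr are placeholder values, not taken from the PR.
    """
    if step < WARMUP_STEPS:
        return peak_lr * (step + 1) / WARMUP_STEPS
    progress = min((step - WARMUP_STEPS) / COSINE_STEPS, 1.0)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

With these placeholders the rate peaks at step 4,000, is half the peak at step 42,000, and reaches `min_lr` at step 80,000.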
…parse projection

- Cache CrossTokenizerDistillationLossFn on policy workers at init and
  pass None to train() calls, eliminating repeated Ray serialization of
  the loss function (which includes large sparse matrices) each step.
- Add set_loss_fn() and update_cross_tokenizer_data() to Policy and
  DTensorPolicyWorkerV2 to support per-step cross-tokenizer data updates.
- Optimize sparse token projection by pre-reducing the sparse matrix
  with index_select before projection instead of projecting full vocab
  and slicing afterward.
- Use AutoConfig.from_pretrained() for vocab sizes in sanity check script.

Made-with: Cursor
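
The third bullet relies on the fact that selecting columns of the projection matrix before multiplying gives the same result as multiplying by the full matrix and slicing afterward, while doing less work. The sketch below checks that equivalence on a toy matrix in plain Python. The PR does this with an `index_select` on a sparse tensor; the dict-of-rows representation and function names here are illustrative assumptions.

```python
def project(probs, sparse_rows, num_cols):
    """Dense vector times sparse matrix; sparse_rows[i] is {col: weight} for row i."""
    out = [0.0] * num_cols
    for i, p in enumerate(probs):
        for col, w in sparse_rows[i].items():
            out[col] += p * w
    return out

def select_columns(sparse_rows, keep):
    """Keep only the requested columns and renumber them 0..len(keep)-1."""
    new_index = {col: j for j, col in enumerate(keep)}
    return [
        {new_index[c]: w for c, w in row.items() if c in new_index}
        for row in sparse_rows
    ]

# Toy 3 x 5 projection matrix (source vocab of 3, target vocab of 5).
M = [{0: 1.0}, {1: 0.5, 3: 0.5}, {4: 1.0}]
probs = [0.2, 0.5, 0.3]
keep = [3, 0]  # the only target ids needed downstream

slow = [project(probs, M, 5)[c] for c in keep]             # full projection, then slice
fast = project(probs, select_columns(M, keep), len(keep))  # reduce first, then project
```

The reduced path never materializes the full target-vocabulary output, which is where the memory saving comes from.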
…P rank

Reduced training time with this optimization. Avoid Ray serialization of
the loss function by having each worker construct
CrossTokenizerDistillationLossFn from config + shared filesystem. Shard
teacher_input_ids and aligned_pairs per data-parallel rank instead of
broadcasting the full batch to every worker.

Made-with: Cursor
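
Per-rank sharding of `teacher_input_ids` and `aligned_pairs` can be sketched as below. Contiguous, equal-size shards are an assumption, and `shard_for_rank` is a hypothetical helper, not a function from the PR.

```python
def shard_for_rank(items, dp_rank, dp_size):
    """Return the contiguous slice of a batch owned by one data-parallel rank."""
    if len(items) % dp_size != 0:
        raise ValueError("batch size must be divisible by dp_size")
    per_rank = len(items) // dp_size
    return items[dp_rank * per_rank : (dp_rank + 1) * per_rank]

# Toy global batch of four samples split across two data-parallel ranks.
teacher_input_ids = [[11, 12], [21, 22], [31, 32], [41, 42]]
aligned_pairs = ["pairs0", "pairs1", "pairs2", "pairs3"]

shards = [
    (shard_for_rank(teacher_input_ids, r, 2), shard_for_rank(aligned_pairs, r, 2))
    for r in range(2)
]
```

Each worker then receives 1/dp_size of the batch instead of the whole thing.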
- Add O(n+m) character-offset alignment via two-pointer walk on tokenizer
  offset mappings, with automatic DP fallback for failed samples
- Precompute canonical token ID maps at startup to skip convert_ids_to_tokens
- Add Numba JIT-accelerated DP kernel and banded DP variant
- Add KD preprocessor preserving raw text for teacher tokenization
- Add numba dependency
- Update config: expand arrow data glob, set max_num_epochs=1
- Update submit script: bump max_num_steps=10, rename experiment to raw-text-kd-16node

Made-with: Cursor
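
The two-pointer walk from the first bullet can be illustrated as follows. This is a simplified sketch, assuming both offset lists index the same raw string and contain no special tokens with empty offsets; the function name and return format are hypothetical. Returning `None` stands in for the point where the PR falls back to DP.

```python
def align_by_offsets(student_offsets, teacher_offsets):
    """Group tokens into aligned chunks using character (start, end) offsets.

    Returns a list of (student_token_indices, teacher_token_indices), or None
    if the two tokenizations never reach a shared boundary.
    """
    chunks = []
    i = j = 0
    n, m = len(student_offsets), len(teacher_offsets)
    while i < n and j < m:
        s_chunk, t_chunk = [i], [j]
        s_end, t_end = student_offsets[i][1], teacher_offsets[j][1]
        # Extend whichever side ends earlier until both end at the same character.
        while s_end != t_end:
            if s_end < t_end:
                i += 1
                if i >= n:
                    return None
                s_chunk.append(i)
                s_end = student_offsets[i][1]
            else:
                j += 1
                if j >= m:
                    return None
                t_chunk.append(j)
                t_end = teacher_offsets[j][1]
        chunks.append((s_chunk, t_chunk))
        i += 1
        j += 1
    return chunks if i == n and j == m else None

# "unhappiness": student splits un|happi|ness, teacher splits unhappi|ne|ss.
student = [(0, 2), (2, 7), (7, 11)]
teacher = [(0, 7), (7, 9), (9, 11)]
chunks = align_by_offsets(student, teacher)
```

Each pointer only moves forward, so the walk touches every token once, which gives the O(n+m) bound. The example yields one many:1 chunk followed by one 1:many chunk.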
Introduce CUDA kernel and Python integration module for faster TokenAligner dynamic programming base-case computation.

Made-with: Cursor
… preprocessing with current-step GPU training while keeping alignment behavior unchanged.

Add explicit token-aligner runtime switches (`use_char_offset`, `use_align_fast`, CUDA-DP toggles), clean up dead/duplicated paths, and simplify the step orchestration with typed prefetch payloads and helper extraction for maintainability.

Made-with: Cursor
Set explicit token aligner defaults and document the total-GPUs/2 heuristic for cross_tokenizer_num_workers so large-batch off-policy runs can iterate on stable, reproducible CT pool sizing.

Made-with: Cursor
Align rebased cross-tokenizer distillation code with main-branch APIs and config expectations.
- update stale loss/interface import paths after rebase
- migrate DTensorPolicyWorkerV2 init/checkpoint setup to main-style flow
- add required sequence_packing config keys instead of setup-side fallbacks
- add scoped Phi RoPE buffer repair for post-load meta/non-finite buffer cases

Made-with: Cursor
Align rebased cross-tokenizer distillation code with main-branch APIs and config expectations.
- update stale loss/interface import paths after rebase
- migrate DTensorPolicyWorkerV2 init/checkpoint setup to main-style flow
- add required sequence_packing config keys instead of setup-side fallbacks
- add scoped Phi RoPE buffer repair for post-load meta/non-finite buffer cases
- include cross-tokenizer launch shell scripts used for runs

Made-with: Cursor
…token projection tools

Refactor off-policy cross-tokenizer scripts to use the new data.train/default schema, remove hard-coded dataset paths, and support user-provided Arrow or dataset_path/HF inputs at submit time. Rename callsites to off_policy_distillation_train for consistency and move xtoken projection matrix generation utilities under the utils folder.

Made-with: Cursor
…n training path.

Update DTensor worker and lm/off-policy orchestration for IPC-based teacher-student flow, add shared IPC tensor-handle helpers in ipc_utils, and fix xtoken post-processor microbatch buffer handling/config integration to avoid runtime shape and key errors during policy training.

Made-with: Cursor
…before KL

Prevent loss-magnitude explosions in off-policy distillation when non-IPC teacher top-k paths return raw logits instead of log-probs. Add normalization in off-policy train/validation data preparation so DistillationLoss receives stable top-k log-probs, while keeping shared/common loss code unchanged.

Align llama_off_policy_arrow.yaml with the structured data.train/data.default schema, remove legacy hard-coded dataset fields, and bake in common run settings (teacher instruct variant, checkpoint optimizer save, and policy batching/packing defaults) to reduce CLI override churn.

Made-with: Cursor
@github-actions


This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

@github-actions github-actions Bot added the Stale label Apr 28, 2026
@github-actions

github-actions Bot commented May 6, 2026


This PR was closed because it has been inactive for 7 days since being marked as stale.

@github-actions github-actions Bot closed this May 6, 2026
avenkateshha added a commit that referenced this pull request May 20, 2026
Comments addressed: #3, #5, NVIDIA-NeMo#7, NVIDIA-NeMo#8, NVIDIA-NeMo#9, NVIDIA-NeMo#10, NVIDIA-NeMo#11.

- Rename _load_M -> _get_sparse_projection_matrix and
  _load_dense_projection -> _get_topk_projection (later removed in
  favor of module-level cache helpers below).
- Drop unused alignment_student_spans / alignment_teacher_spans
  from the cross-tokenizer batch payload.
- Remove NRL_XTOKEN_LOSS_DUMP_DIR debug-dump side effect.
- Move Fp32SparseMM, chunk_average_log_probs, valid_chunk_mask to a
  new shared module nemo_rl/algorithms/x_token/utils.py.
- Extract projection-file parsing into utils.parse_projection_file;
  tokenalign.py and loss_functions.py both go through it.
- Move per-instance projection-matrix caches to process-local caches
  in utils.get_sparse_projection_matrix / get_topk_projection. The
  driver no longer holds large CUDA tensors; each Ray worker fills
  its own cache on first loss call.

Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
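
The process-local cache described in the last bullet can be sketched with a module-level memoized loader. Because each Ray worker is its own process, each one would populate its own copy on first use. `get_projection` and its return value are stand-ins and are not the PR's `get_sparse_projection_matrix` or `get_topk_projection`.

```python
import functools

LOAD_CALLS = []  # records how many times the expensive load actually ran

@functools.lru_cache(maxsize=None)
def get_projection(path):
    """Load a projection once per process; later calls reuse the cached object.

    The loader body is a stand-in: a real one would parse the file and
    build tensors on the worker's device.
    """
    LOAD_CALLS.append(path)
    return {"path": path, "matrix": [[1.0, 0.0], [0.0, 1.0]]}

first = get_projection("projection.pt")   # performs the load
second = get_projection("projection.pt")  # served from the cache
```

Keeping the cache at module level rather than on the loss-function instance means the instance stays small and cheap to serialize.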
avenkateshha added a commit that referenced this pull request May 23, 2026
…t API

Address easy-batch items from PR NVIDIA-NeMo#2508 review.

- Move sinkhorn_one_dim / apply_canonicalization_if_enabled /
  clean_model_name_for_filename / project_token_likelihoods to
  nemo_rl/utils/x_token/_shared.py and import from there in the four
  projection-prep CLIs (#5).
- Drop unused sinkhorn (10-iter), debug_projection_map, and
  generate_projection_map from minimal_projection_generator.py (NVIDIA-NeMo#8).
- Move minimal_projection_generator.py's CLI parsing under
  `if __name__ == "__main__":` so the module is importable for the
  P1 dedup harness.
- Rename seq1 / seq2 (and the derivative s1_*/s2_*/used_seq*/joined_seq*/
  seg* names) to student_tokens / teacher_tokens throughout
  TokenAligner._align_single / _align_with_anchors / _align_dp; all call
  sites are internal (NVIDIA-NeMo#11).
- Replace **kwargs in TokenAligner._align_with_anchors with the five
  explicit keyword-only scoring knobs already declared on _align_dp,
  and pass them through named at the _align_single call site (NVIDIA-NeMo#12).
- Remove dead TokenAligner.load_projection_matrix + the private
  _load_projection_components plus the now-orphaned self._projection_*
  attributes; the live projection-load path is
  nemo_rl/algorithms/x_token/utils.py::{get_sparse_projection_matrix,
  get_topk_projection} added in 6336464 (NVIDIA-NeMo#13, NVIDIA-NeMo#14).
- git mv nemo_rl/algorithms/x_token/tokenalign.py -> token_aligner.py
  and update import sites in __init__.py, data/cross_tokenizer_collate.py,
  the four utils.x_token CLIs, and the xtoken-distillation guide; update
  Sphinx :mod: references in algorithms/x_token/utils.py (NVIDIA-NeMo#16).

Incidental: the CLIs previously called TokenAligner._canonical_token,
which was never a class attribute (the helper is module-level); the
new shared helper imports _canonical_token directly so the
use_canonicalization branch isn't broken at runtime.

Signed-off-by: Adithya Hanasoge <avenkateshha@nvidia.com>
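
The `**kwargs` replacement in `_align_with_anchors` follows a common pattern: declare the scoring knobs as keyword-only parameters and forward them by name. The sketch below uses two made-up knobs (`match_score`, `gap_penalty`) and free functions; the commit message does not name the five real parameters.

```python
def align_dp(student_tokens, teacher_tokens, *, match_score=1.0, gap_penalty=-1.0):
    """Stand-in for a DP aligner with explicit keyword-only scoring knobs."""
    return {"match_score": match_score, "gap_penalty": gap_penalty}

def align_with_anchors(student_tokens, teacher_tokens, *, match_score=1.0, gap_penalty=-1.0):
    """Forward the knobs by name instead of accepting **kwargs."""
    return align_dp(
        student_tokens,
        teacher_tokens,
        match_score=match_score,
        gap_penalty=gap_penalty,
    )

result = align_with_anchors(["a"], ["a"], gap_penalty=-2.0)
```

With explicit parameters, a misspelled knob fails immediately with a `TypeError` instead of being silently passed along.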
avenkateshha added a commit that referenced this pull request May 27, 2026
Comments addressed: #3, #5, NVIDIA-NeMo#7, NVIDIA-NeMo#8, NVIDIA-NeMo#9, NVIDIA-NeMo#10, NVIDIA-NeMo#11.

avenkateshha added a commit that referenced this pull request May 27, 2026
…t API

Address easy-batch items from PR NVIDIA-NeMo#2508 review.

avenkateshha added a commit that referenced this pull request Jun 4, 2026
Comments addressed: #3, #5, NVIDIA-NeMo#7, NVIDIA-NeMo#8, NVIDIA-NeMo#9, NVIDIA-NeMo#10, NVIDIA-NeMo#11.

avenkateshha added a commit that referenced this pull request Jun 4, 2026
…t API

Address easy-batch items from PR NVIDIA-NeMo#2508 review.

avenkateshha added a commit that referenced this pull request Jun 7, 2026
Comments addressed: #3, #5, NVIDIA-NeMo#7, NVIDIA-NeMo#8, NVIDIA-NeMo#9, NVIDIA-NeMo#10, NVIDIA-NeMo#11.

avenkateshha added a commit that referenced this pull request Jun 7, 2026
…t API

Address easy-batch items from PR NVIDIA-NeMo#2508 review.
