tool mask support #688

Merged
danielhanchen merged 1 commit into unslothai:main from Datta0:tool_mask_support
May 27, 2026

Conversation

@Datta0 Datta0 commented May 22, 2026

Collaborator

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces the `align_completion_tool_mask` function in `unsloth_zoo/rl_replacements.py` to align tool/environment masks with the repacked loss mask. It also updates the `grpo_accumulated_loss` function to accept an optional `tool_mask` and apply this alignment to the `completion_mask`. I have no feedback to provide.
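
For readers without the diff open, the idea of aligning a tool mask with a completion mask can be sketched roughly as follows. This is an illustration only, not the code merged in this PR: the function name `align_tool_mask_sketch`, the `pad_value` parameter, and the convention that `1` marks a model-generated token while `0` marks a tool/environment token are all assumptions made for the example.

```python
def align_tool_mask_sketch(completion_mask, tool_mask, pad_value=1):
    """Hypothetical sketch; NOT the unsloth_zoo implementation.

    Assumed convention: 1 = token counts toward the loss,
    0 = token is padding (completion_mask) or tool output (tool_mask).
    """
    aligned = []
    for comp_row, tool_row in zip(completion_mask, tool_mask):
        width = len(comp_row)
        # Truncate or right-pad the tool mask row to the completion width.
        row = list(tool_row[:width]) + [pad_value] * (width - len(tool_row))
        # A token contributes to the loss only if both masks keep it.
        aligned.append([c * t for c, t in zip(comp_row, row)])
    return aligned


completion_mask = [[1, 1, 1, 1, 0]]  # last position is padding
tool_mask = [[1, 0, 0, 1]]           # positions 1-2 came from a tool call
aligned = align_tool_mask_sketch(completion_mask, tool_mask)
print(aligned)
```

Under these assumptions, tool-emitted tokens and padding both drop out of the loss, while the remaining model-generated tokens are kept. How the merged function handles repacking, device placement, or shape mismatches is not shown in this thread.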

@Datta0 Datta0 marked this pull request as ready for review May 25, 2026 08:13
@danielhanchen danielhanchen merged commit bea8b48 into unslothai:main May 27, 2026
11 checks passed
danielhanchen pushed a commit to mmathew23/unsloth-zoo that referenced this pull request May 27, 2026
Merges 5 main-side mlx fixes (unslothai#673 zero-token CCE, unslothai#679 + unslothai#692 LoRA save
metadata, unslothai#682 invalid label NaN-poisoning, unslothai#688 tool mask). All 13
conflict regions in unsloth_zoo/mlx/utils.py resolved to keep PR unslothai#684's
behavior where it conflicts on semantics:

  - half-open `<` length mask (PR unslothai#684 fix) wins over main's inclusive `<=`
  - `if labels is None` branch preserved (PR unslothai#684 generality) alongside
    main's `_normalize_cce_label_dtype` dtype widening
  - `_get_image_token_ids` legacy wrapper kept alongside main's new
    `_normalize_cce_label_dtype` / `_normalize_numpy_cce_labels`
  - `_mask_label_token_ids` calls `_normalize_cce_label_dtype` first so
    image masking honors main's uint-widening contract
  - HEAD's `_expand_token_replacements` dropped; main's three-function
    split (`_normalize_numpy_cce_labels` + `_expand_image_token_sequences`
    + `_expand_token_runs`) is canonical; duplicate HEAD wrappers removed
  - `_collate_vlm_prompt_completion_batch` reads back the masked labels
    in int64 so image + attention masking survives without narrowing
  - prompt-completion VLM collator routes through `_apply_vlm_label_masks`
    after dtype normalisation so ignore_token_ids and wide invalid ids
    both reach runtime CCE intact
  - `_to_mx_vlm_batch` uses main's `_normalize_cce_label_dtype` for labels
    while keeping PR unslothai#684's token_type_ids / mm_token_type_ids handling
  - `_unsloth_*` prefix filter preserved so the new collated_position_ids
    flag and main's raw-input-ids carrier both get stripped

152 MLX tests pass post-merge.
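
The first bullet above contrasts a half-open `<` length mask with an inclusive `<=` one. A minimal sketch of the difference, assuming zero-based positions compared against a per-sequence length (the helper name `length_mask_sketch` is hypothetical and the actual code in `unsloth_zoo/mlx/utils.py` is not reproduced here):

```python
def length_mask_sketch(seq_len, lengths, inclusive=False):
    """Hypothetical helper: build a 0/1 mask per sequence.

    Half-open (default): keep positions i with i < length.
    Inclusive: keep positions i with i <= length, one extra slot.
    """
    rows = []
    for n in lengths:
        if inclusive:
            rows.append([1 if i <= n else 0 for i in range(seq_len)])
        else:
            rows.append([1 if i < n else 0 for i in range(seq_len)])
    return rows


half_open = length_mask_sketch(5, [3])
inclusive = length_mask_sketch(5, [3], inclusive=True)
print(half_open, inclusive)
```

With a length of 3, the half-open form keeps exactly three positions, whereas the inclusive form keeps a fourth, which under this assumption would be the first padding slot.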

2 participants