fix(tests): initialize num_microbatches calculator in vision cudagraph tests#4986

Merged
ko3n1g merged 1 commit into NVIDIA:main from ko3n1g:ko3n1g/fix/3802-vision-cuda-graphs-microbatch-calc on May 26, 2026
@ko3n1g ko3n1g commented May 26, 2026

Claude summary

Fixes #3802. Its duplicate, #3803, was already closed manually.

Bug

After the TE 2.13 bump (#3800) the vision-encoder CUDA graph unit tests started failing with:

megatron/core/num_microbatches_calculator.py:19: AttributeError
E       AttributeError: 'NoneType' object has no attribute 'get'

The call chain is:

TestVisionTECudaGraphHelper.test_create_and_delete_cudagraphs
  VisionTECudaGraphHelper.create_cudagraphs()
    TECudaGraphHelper.create_cudagraphs()
      _get_cuda_graph_input_data()
        get_make_graphed_callables_kwargs() (defined inline)
          get_num_microbatches() (megatron/core/transformer/cuda_graphs.py:2255)

get_num_microbatches() returns _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get(), but this unit test never initializes the global calculator (setup_method initializes only the model-parallel state), so the global is still None when the call is made.
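
The failure mode can be reproduced without Megatron at all. The following is a minimal sketch, not the library code: the module-level global and the accessor are stand-ins that only mirror the shape described above, and `reproduce` is a hypothetical helper added for illustration.

```python
# Stand-in for the module-level global in num_microbatches_calculator.py.
# Assumption: the real module stores a calculator object here once initialized.
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None


def get_num_microbatches():
    """Mirror the accessor's shape: call .get() on the global unconditionally."""
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


def reproduce():
    """Return the error message raised when the global was never initialized."""
    try:
        get_num_microbatches()
    except AttributeError as exc:
        return str(exc)
    return None


message = reproduce()
print(message)
```

Because nothing guards against the uninitialized state, the error surfaces deep inside the call chain rather than at test setup, which is why the traceback points at the calculator module and not at the test.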

The follow-up PR "ci: Skip more tests in test_vision_cuda_graphs for LTS" (#3860) merely marked the two tests @pytest.mark.flaky / @pytest.mark.flaky_in_dev, which masked the failure rather than fixing it. The PP2 variant tracked under #3804 is a separate hang and is left untouched here.

Fix

In tests/unit_tests/transformer/test_vision_cuda_graphs.py:

  1. Initialize the global num_microbatches calculator inside TestVisionTECudaGraphHelper._make_helper with the requested num_microbatches, matching the canonical pattern used in tests/unit_tests/transformer/test_cuda_graphs.py.
  2. Destroy the calculator in teardown_method (and on each new _make_helper call) so the tests remain hermetic.
  3. Drop the @pytest.mark.flaky / @pytest.mark.flaky_in_dev markers from test_create_and_delete_cudagraphs and test_create_cudagraphs_multi_microbatch so CI exercises the real code path again.
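
The init-in-helper, destroy-in-teardown pattern from the steps above can be sketched in isolation. This is an illustrative stand-in under stated assumptions, not the test file itself: `init_calculator`, `destroy_calculator`, `is_initialized`, and `TestHelperSketch` are hypothetical names, and the sketch assumes that initializing an already-initialized calculator is an error, which is why the helper destroys first.

```python
# Stand-ins for the real calculator API; every name here is hypothetical.
_CALCULATOR = None


def init_calculator(num_microbatches):
    """Install a calculator; refuse to overwrite one that already exists."""
    global _CALCULATOR
    if _CALCULATOR is not None:
        raise RuntimeError("calculator is already initialized")
    _CALCULATOR = {"num_microbatches": num_microbatches}


def destroy_calculator():
    """Reset the global so the next test starts from a clean slate."""
    global _CALCULATOR
    _CALCULATOR = None


def is_initialized():
    """Report whether a calculator is currently installed."""
    return _CALCULATOR is not None


def get_num_microbatches():
    """Read the microbatch count from the installed calculator."""
    return _CALCULATOR["num_microbatches"]


class TestHelperSketch:
    def teardown_method(self, method=None):
        # Runs after every test, so no global state leaks into the next one.
        destroy_calculator()

    def _make_helper(self, num_microbatches=1):
        # Destroy first: a test may build more than one helper.
        destroy_calculator()
        init_calculator(num_microbatches)
        return get_num_microbatches

    def test_single_microbatch(self):
        assert self._make_helper()() == 1

    def test_multi_microbatch(self):
        assert self._make_helper(num_microbatches=4)() == 4
```

Destroying in both places is deliberate: the call in `teardown_method` protects other test classes, while the call at the top of `_make_helper` makes the helper safe to invoke repeatedly within one test.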

Minimal example of the fix shape:

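from megatron.core.num_microbatches_calculator import (
    destroy_num_microbatches_calculator,
    init_num_microbatches_calculator,
)

def _make_helper(self, num_microbatches=1):
    # Clear any calculator left over from a previous call so re-initialization is safe.
    destroy_num_microbatches_calculator()
    # With one data-parallel rank, global batch = micro batch * number of microbatches.
    init_num_microbatches_calculator(
        rank=0,
        rampup_batch_size=None,  # no batch-size ramp-up in this test
        global_batch_size=self.micro_batch_size * num_microbatches,
        micro_batch_size=self.micro_batch_size,
        data_parallel_size=1,
        decrease_batch_size_if_needed=False,
    )
    return VisionTECudaGraphHelper(...)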
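
The choice of global_batch_size in the example can be checked with a little arithmetic. The sketch below assumes a constant (no ramp-up) calculator derives the count as global_batch_size // (micro_batch_size * data_parallel_size); `expected_num_microbatches` is a hypothetical helper written for this illustration, not a Megatron function.

```python
def expected_num_microbatches(global_batch_size, micro_batch_size, data_parallel_size):
    """Microbatches per step, assuming the global batch divides evenly."""
    per_microbatch = micro_batch_size * data_parallel_size
    if global_batch_size % per_microbatch != 0:
        raise ValueError(
            "global_batch_size must be a multiple of "
            "micro_batch_size * data_parallel_size"
        )
    return global_batch_size // per_microbatch


micro_batch_size = 2  # arbitrary value chosen for the illustration
results = {}
for requested in (1, 4):
    # Same construction as in the example: micro batch * requested microbatches.
    global_batch_size = micro_batch_size * requested
    results[requested] = expected_num_microbatches(
        global_batch_size, micro_batch_size, data_parallel_size=1
    )

print(results)
```

Under that assumption, building the global batch size as a product guarantees divisibility, so the calculator reports exactly the number of microbatches the test asked for.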

Scope

The scope of this PR is intentionally limited to the PP=1 TestVisionTECudaGraphHelper. The TestVisionTECudaGraphHelperPP2 variant has the same latent uninitialized-calculator issue, but it also hits a separate hang tracked in #3804; it will be addressed when #3804 is fixed.

…h tests

Closes NVIDIA#3802.

`TestVisionTECudaGraphHelper._make_helper` constructs a `VisionTECudaGraphHelper`
and then `create_cudagraphs()` calls into `cuda_graphs.py`'s
`get_make_graphed_callables_kwargs`, which calls `get_num_microbatches()`.
With the global calculator never initialized in this unit test,
`_GLOBAL_NUM_MICROBATCHES_CALCULATOR` is `None` and the call fails with
`AttributeError: 'NoneType' object has no attribute 'get'`.

Initialize the global calculator inside `_make_helper` with the requested
`num_microbatches`, and destroy it in `teardown_method` (and again on the next
`_make_helper` call) so tests are hermetic. This mirrors the canonical pattern
in `tests/unit_tests/transformer/test_cuda_graphs.py`.

With the real bug fixed, drop the `@pytest.mark.flaky` /
`@pytest.mark.flaky_in_dev` masking from `test_create_and_delete_cudagraphs`
and `test_create_cudagraphs_multi_microbatch` so they run in CI again.

Signed-off-by: oliver könig <okoenig@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented May 26, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@ko3n1g

ko3n1g commented May 26, 2026

/ok to test

@tomlifu tomlifu left a comment

LGTM

@ko3n1g ko3n1g added this pull request to the merge queue May 26, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/26461817940

@ko3n1g ko3n1g removed this pull request from the merge queue due to a manual request May 26, 2026
@ko3n1g ko3n1g added this pull request to the merge queue May 26, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/26464000672

Merged via the queue into NVIDIA:main with commit 6ce6fac May 26, 2026
236 of 242 checks passed
@ko3n1g ko3n1g deleted the ko3n1g/fix/3802-vision-cuda-graphs-microbatch-calc branch May 26, 2026 18:35
Victarry added a commit to yanring/Megatron-LM that referenced this pull request May 27, 2026
* origin/main: (50 commits)
  Drain predecessor reduce-scatter at dispatch time (NVIDIA#4940)
  ci: Add allow_failure flag to gpt and moe recipes that are failing in nightlies (NVIDIA#4905)
  fix(tests): initialize num_microbatches calculator in vision cudagraph tests (NVIDIA#4986)
  test: re-enable test_pp2_create_cudagraphs_first_stage on TE 2.15+ (NVIDIA#4985)
  ci: Add support for MBridge job gating based on PR labels  (NVIDIA#4926)
  test(ci): re-enable 8experts2parallel_multi_dist_optimizer_instances_1node (NVIDIA#4984)
  test: re-enable paged stashing MoE tests (NVIDIA#4978)
  Fix elastification unwrap_model import (NVIDIA#4972)
  Avoid offsetting functional test master port (NVIDIA#4973)
  test: enable NVTE_CUTEDSL_FUSED_GROUPED_MLP via pytest fixture (NVIDIA#4931)
  chore(beep boop 🤖): Bump  (main) (2026-05-25)
  test(release): add release goldens for deepseekv3/nemotron3 and set tp2pp2 exit-interval (NVIDIA#4932)
  Fix `get_batch` return order to ignore BlendedDataset provenance fields (NVIDIA#4952)
  ci: restore perf test torchrun logs (NVIDIA#4951)
  Various training utils (NVIDIA#4872)
  ci: Update training script paths in BERT and T5 (NVIDIA#4939)
  [MXFP8/FP4-param-gather] Post processing after forced param AG in eval (NVIDIA#4562)
  Fix mxfp8 param gather numerical issue when DP overlap is off (NVIDIA#4800)
  Add TEFusedDenseMLP for Dense+Grouped GEMM fusion on SM100+ (NVIDIA#4318) (NVIDIA#4786)
  Fix paged stashing test submodules lookup (NVIDIA#4925)
  ...

# Conflicts:
#	megatron/training/training.py
janEbert pushed a commit to janEbert/Megatron-LM that referenced this pull request Jun 2, 2026
…h tests (NVIDIA#4986)

Signed-off-by: oliver könig <okoenig@nvidia.com>
mathemakitten pushed a commit to mathemakitten/Megatron-LM that referenced this pull request Jun 12, 2026
…h tests (NVIDIA#4986)

Signed-off-by: oliver könig <okoenig@nvidia.com>
Development

Successfully merging this pull request may close these issues.

TE2.13: AttributeError: 'NoneType' object has no attribute 'get'

3 participants