🐛 Bug
When scan_layers is given modules that contain no model weights, it raises a ValueError:
File "/workspaces/pytorch/xla/torch_xla/experimental/scan_layers.py", line 93, in scan_layers
final_carry, _ = scan(
File "/workspaces/pytorch/xla/torch_xla/experimental/scan.py", line 156, in scan
raise ValueError(f"`xs` {xs} is an empty PyTree.")
ValueError: `xs` ({}, {}) is an empty PyTree.
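The `({}, {})` in the message is presumably the pair of stacked parameters and stacked buffers that scan_layers passes to scan as `xs`. The pure-Python sketch below shows how weightless layers would produce exactly that value. It does not use torch_xla, and every helper in it (`tree_leaves`, `stack_layers`, `check_xs`) is a hypothetical stand-in, not the library's implementation.

```python
# Hypothetical stand-ins: each "layer" is represented only by a dict of
# parameters and a dict of buffers (name -> float). Not torch_xla code.

def tree_leaves(tree):
    """Flatten nested dicts/tuples into a list of leaves (lists are leaves)."""
    if isinstance(tree, dict):
        return [leaf for v in tree.values() for leaf in tree_leaves(v)]
    if isinstance(tree, tuple):
        return [leaf for v in tree for leaf in tree_leaves(v)]
    return [tree]

def stack_layers(dicts):
    """Stack per-layer dicts into one dict mapping each key to a list of values."""
    keys = dicts[0].keys() if dicts else ()
    return {k: [d[k] for d in dicts] for k in keys}

def check_xs(xs):
    """Mimic the emptiness check seen in the traceback."""
    if not tree_leaves(xs):
        raise ValueError(f"`xs` {xs} is an empty PyTree.")

weighted_params = [{"weight": 1.0}, {"weight": 2.0}]  # two layers with weights
weightless_params = [{}, {}]                          # two layers, no weights
no_buffers = [{}, {}]                                 # neither has buffers

# Layers with weights: the stacked pytree has leaves, so the check passes.
xs_ok = (stack_layers(weighted_params), stack_layers(no_buffers))
check_xs(xs_ok)

# Weightless layers: both stacked dicts are empty, so the check raises.
xs_empty = (stack_layers(weightless_params), stack_layers(no_buffers))
try:
    check_xs(xs_empty)
    error_message = None
except ValueError as e:
    error_message = str(e)

print(error_message)
```

Under these assumptions, the error comes from having nothing to stack along the layer dimension, not from anything the layers compute.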
To Reproduce
Run the test in 1. The test fails because the output of fake_fa_wrapper triggers raise ValueError(f"`xs` {xs} is an empty PyTree.").
Expected behavior
The output of fake_fa_wrapper should not trigger this ValueError.