-
Notifications
You must be signed in to change notification settings - Fork 795
Edge case in nnx.tabulate when nnx.Module stores empty dictionary #4889
Copy link
Closed
Description
System information
Name: flax
Version: 0.11.1
---
Name: jax
Version: 0.7.0
---
Name: jaxlib
Version: 0.7.0
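- Python version: 3.11 (note: the traceback below was produced from a `python3.12` site-packages path, so the interpreter actually used may have been 3.12)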
Steps to reproduce:
import jax.numpy as jnp
from flax import nnx
class Model(nnx.Module):
    """Minimal module reproducing the ``nnx.tabulate`` AssertionError.

    The module stores an empty dict on ``self.foo`` and forwards it as a
    positional argument to ``subroutine``. ``nnx.tabulate(..., depth=0)``
    succeeds, but ``depth >= 1`` fails inside
    ``flax.nnx.summary._unflatten_to_simple_structure`` while rendering input
    arguments (presumably the inputs recorded for the ``subroutine`` call --
    confirm against ``flax/nnx/summary.py``).
    """

    def __init__(self):
        # Variants of the stored value and how nnx.tabulate(depth=1) behaves:
        self.foo = {}  # This fails with tabulate
        # self.foo = None  # This fails with tabulate
        # self.foo = 1  # This works
        # self.foo = {"bar": 1}  # This works

    def subroutine(self, foo, x):
        """Return ``x`` unchanged.

        ``foo`` is unused. Passing ``self.foo`` through this call is what
        exposes the tabulate bug.
        """
        return x

    def __call__(self, x):
        """Forward ``x`` through ``subroutine`` alongside the stored ``foo``."""
        return self.subroutine(self.foo, x)
# Reproduction: plain forward pass and depth=0 tabulate work; depth>=1 fails.
model = Model()
inputs = jnp.zeros((1, 1024))

# This works
output = model(inputs)  # Shape: (1, 1024)

# This works with depth=0
print(nnx.tabulate(model, inputs, depth=0))  # ✓

# This fails with depth >= 1
print(nnx.tabulate(model, inputs, depth=1))  # ✗ AssertionError in flax.nnx.summary's _unflatten_to_simple_structure

# Output:
Model Summary
┏━━━━━━┳━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ path ┃ type ┃ inputs ┃ outputs ┃
┡━━━━━━╇━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ │ Model │ float32[1,1024] │ float32[1,1024] │
├──────┼───────┼─────────────────┼─────────────────┤
│ │ │ │ Total │
└──────┴───────┴─────────────────┴─────────────────┘
Total Parameters: 0 (0 B)
Traceback (most recent call last):
File "/mnt/c/Users/admin/tmp/nnx_tabulate_issue.py", line 29, in <module>
print(nnx.tabulate(model, inputs, depth=1)) # ✗ AssertionError in flax.nnx.summary's _unflatten_to_simple_structure
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/summary.py", line 386, in tabulate
inputs_repr += _as_yaml_str(input_args)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/summary.py", line 535, in _as_yaml_str
value = _maybe_pytree_to_dict(value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/summary.py", line 497, in _maybe_pytree_to_dict
return _unflatten_to_simple_structure(path_leaves)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/summary.py", line 523, in _unflatten_to_simple_structure
assert path[-1] == len(cursor)
^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
Reactions are currently unavailable
Metadata
Assignees
Labels
No labels