
Edge case in nnx.tabulate when nnx.Module stores empty dictionary #4889

@DBraun

Description


System information

Name: flax
Version: 0.11.1
---
Name: jax
Version: 0.7.0
---
Name: jaxlib
Version: 0.7.0
  • Python version: 3.11

Steps to reproduce:

import jax.numpy as jnp
from flax import nnx


class Model(nnx.Module):
    def __init__(self):
        self.foo = {}  # Fails in nnx.tabulate with depth >= 1
        # self.foo = None  # Also fails in nnx.tabulate with depth >= 1
        # self.foo = 1  # Works
        # self.foo = {"bar": 1}  # Works

    def subroutine(self, foo, x):
        return x
    def __call__(self, x):
        return self.subroutine(self.foo, x)


model = Model()
inputs = jnp.zeros((1, 1024))

# This works
output = model(inputs)  # Shape: (1, 1024)

# This works with depth=0
print(nnx.tabulate(model, inputs, depth=0))  # ✓

# This fails with depth >= 1
print(nnx.tabulate(model, inputs, depth=1))  # ✗ AssertionError in flax.nnx.summary's _unflatten_to_simple_structure

Output:

                   Model Summary
┏━━━━━━┳━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ path ┃ type  ┃ inputs          ┃ outputs         ┃
┡━━━━━━╇━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│      │ Model │ float32[1,1024] │ float32[1,1024] │
├──────┼───────┼─────────────────┼─────────────────┤
│      │       │                 │           Total │
└──────┴───────┴─────────────────┴─────────────────┘

             Total Parameters: 0 (0 B)

Traceback (most recent call last):
  File "/mnt/c/Users/admin/tmp/nnx_tabulate_issue.py", line 29, in <module>
    print(nnx.tabulate(model, inputs, depth=1))  # ✗ AssertionError in flax.nnx.summary's _unflatten_to_simple_structure
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/summary.py", line 386, in tabulate
    inputs_repr += _as_yaml_str(input_args)
                   ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/summary.py", line 535, in _as_yaml_str
    value = _maybe_pytree_to_dict(value)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/summary.py", line 497, in _maybe_pytree_to_dict
    return _unflatten_to_simple_structure(path_leaves)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/summary.py", line 523, in _unflatten_to_simple_structure
    assert path[-1] == len(cursor)
           ^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
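The comments in the reproduction ({} and None fail, 1 and {"bar": 1} work) match how JAX flattens pytrees. An empty dict and None are pytree nodes with no children, so they contribute no leaves. By contrast, 1 and {"bar": 1} each contribute one leaf. The snippet below uses only `jax.tree_util.tree_flatten_with_path` to show which positions of the `(foo, x)` arguments passed to `subroutine` end up holding leaves. The assumption that `nnx.tabulate` flattens method inputs in a similar way is read off the traceback (`_maybe_pytree_to_dict` calling `_unflatten_to_simple_structure`), not confirmed from Flax's source. `is_empty_container` is a hypothetical helper, not part of Flax.

```python
import jax.numpy as jnp
from jax.tree_util import tree_flatten_with_path

x = jnp.zeros((1, 1024))


def is_empty_container(node):
    # Hypothetical predicate: treat None and {} as leaves instead of empty nodes.
    return node is None or (isinstance(node, dict) and not node)


def top_level_indices(foo, keep_empty=False):
    """Positions in the (foo, x) argument tuple that contribute at least one leaf."""
    path_leaves, _ = tree_flatten_with_path(
        (foo, x), is_leaf=is_empty_container if keep_empty else None
    )
    # Every path starts with a SequenceKey because the root is a tuple.
    return [path[0].idx for path, _ in path_leaves]


for foo in ({}, None, 1, {"bar": 1}):
    print(
        f"{foo!r:12} default={top_level_indices(foo)} "
        f"keep_empty={top_level_indices(foo, keep_empty=True)}"
    )
```

With the default flattening, `{}` and `None` leave index 0 empty, so the only leaf sits at index 1. Passing an `is_leaf` predicate that treats empty containers as leaves keeps both positions occupied.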

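The failing check, `assert path[-1] == len(cursor)`, suggests that the rebuild step appends list items one at a time and expects indices 0, 1, 2, ... with no gaps. For `({}, x)`, JAX's flattening yields a single leaf at index 1 because the empty dict contributes no leaves. The first index the rebuild sees is therefore 1 while the list is still empty. The sketch below is a hypothetical, simplified rebuild in plain Python, not Flax's actual `_unflatten_to_simple_structure`. It treats int path entries as list indices and any other key as a dict key. `rebuild` and its `pad_gaps` flag are illustrative names. In strict mode it raises the same kind of AssertionError. With `pad_gaps=True` it fills the missing position with a `None` placeholder instead.

```python
from typing import Any, Hashable, Sequence

# A path is a tuple of keys: ints index into lists, anything else keys a dict.
Path = tuple[Hashable, ...]


def _container_for(key: Hashable) -> Any:
    return [] if isinstance(key, int) else {}


def rebuild(path_leaves: Sequence[tuple[Path, Any]], pad_gaps: bool = False) -> Any:
    """Hypothetical sketch: rebuild nested lists/dicts from (path, leaf) pairs."""
    root: Any = None
    for path, leaf in path_leaves:
        if not path:  # the whole tree is a single leaf
            return leaf
        if root is None:
            root = _container_for(path[0])
        cursor = root
        for depth, key in enumerate(path):
            is_last = depth == len(path) - 1
            child = leaf if is_last else _container_for(path[depth + 1])
            if isinstance(key, int):
                if pad_gaps:
                    # Fill positions whose subtree produced no leaves, e.g. {} or None.
                    cursor.extend([None] * (key - len(cursor)))
                if key == len(cursor):
                    cursor.append(child)
                # Strict mode: an index past the end means an earlier sibling had no leaves.
                assert key < len(cursor), f"index {key}, but list has {len(cursor)} items"
                cursor = cursor[key]
            else:
                cursor = cursor.setdefault(key, child)
    return root


# ({}, x) flattens to one leaf at index 1; a string stands in for the array.
path_leaves = [((1,), "float32[1,1024]")]

try:
    rebuild(path_leaves)
except AssertionError as err:
    print("strict:", err)

print("padded:", rebuild(path_leaves, pad_gaps=True))
```

Padding with `None` is only one possible choice. It keeps later positions aligned, but it cannot recover whether the missing entry was `{}` or `None`.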