Static shape inference for ONNX models where standard tools fail.
Tested on 138 VNN-COMP 2024 models with 100% success rate. Provides more accurate shape information than ONNX's built-in inference.
ONNX's built-in onnx.shape_inference.infer_shapes handles most models correctly, but fails in several critical scenarios:
- Models with inconsistent ONNX versions or opset mismatches
- Non-standard conversions from PyTorch or other frameworks
- Dynamic shape operations where shape computations depend on data
- Shape operator chains (`Shape → Gather → Slice → Concat → Reshape`)
- Models with custom shape manipulations (Vision Transformers, GANs)
ShapeONNX goes beyond ONNX's capabilities through advanced static shape computation:
- Shape Tensor Tracking: Propagates actual shape values (e.g., `[1, 48, 2, 2]`) where ONNX only tracks tensor metadata (e.g., `[4]`)
- Static Resolution: Resolves shapes ONNX marks as dynamic (`-1`) to concrete values when statically computable
- Operator Chain Analysis: Resolves complex `Shape → Gather → Slice → Concat` patterns to static constants
- Explicit Shape Propagation: Distinguishes shape tensors from data tensors for accurate downstream inference
- Verification-Ready: Provides the precise static shapes required by neural network verification tools
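As a concrete illustration, the folding described above can be sketched in plain Python. This is illustrative only — the function name `fold_shape_chain` and the fixed gather indices are hypothetical, not part of ShapeONNX's API:

```python
def fold_shape_chain(input_shape: list[int]) -> list[int]:
    """Toy sketch: fold a Shape -> Gather -> Concat chain to a static constant."""
    # Shape: the *value* of the shape tensor is the input's shape, e.g. [1, 48, 2, 2]
    shape_value = list(input_shape)
    # Gather(indices=[0, 1]): pick the first two dimensions -> [1, 48]
    gathered = [shape_value[i] for i in (0, 1)]
    # Concat with a constant [-1]: build a reshape target -> [1, 48, -1]
    return gathered + [-1]

print(fold_shape_chain([1, 48, 2, 2]))  # [1, 48, -1]
```

Because every step operates on known constant values, the whole chain collapses to a static reshape target at analysis time.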
ShapeONNX provides more accurate shape information than ONNX's built-in inference:
| Scenario | ONNX Result | ShapeONNX Result |
|---|---|---|
| Shape operation output | `[4]` (tensor metadata) | `[1, 48, 2, 2]` (actual values) |
| Slice of shape tensor | `[2]` (tensor type) | `[1, 48]` (sliced values) |
| Concat of shapes | `[3]` (1D array) | `[1, 48, -1]` (reshape target) |
| ConstantOfShape | `[-1, -1, -1]` (dynamic) | `[1, 1, 48]` (concrete shape) |
| Batch dimensions | `-1` (dynamic) | `1` (static when determinable) |
Why this matters: Neural network verification tools need exact static shapes for layer-by-layer analysis. ShapeONNX resolves shapes ONNX marks as dynamic to concrete values when they're statically computable.
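For example, a Reshape target containing `-1` is statically resolvable whenever the input's total element count is known. A minimal sketch of that resolution (the helper name is hypothetical, not part of ShapeONNX's API):

```python
import math

def resolve_reshape_target(input_shape: list[int], target: list[int]) -> list[int]:
    """Replace a single -1 in a reshape target with its statically computed value."""
    total = math.prod(input_shape)                    # total elements of the input
    known = math.prod(d for d in target if d != -1)   # product of the known dims
    return [total // known if d == -1 else d for d in target]

# [1, 48, 2, 2] has 192 elements, so the -1 must be 192 / 48 = 4
print(resolve_reshape_target([1, 48, 2, 2], [1, 48, -1]))  # [1, 48, 4]
```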
Neural network verification tools require precise static shapes for:
- Layer-by-layer bound propagation
- Memory allocation for symbolic execution
- Constraint generation for SMT solvers
- Model optimization and fusion (SlimONNX)
When ONNX shape inference fails or returns dynamic shapes, verification pipelines break. ShapeONNX fills this gap by providing robust static shape inference for the complex models encountered in verification research.
- Superior Shape Inference: More accurate than ONNX's built-in inference for shape tensors
- Advanced Shape Tracking: Propagates actual shape values through operator chains
- Shape Operator Chains: Resolves `Shape → Gather → Slice → Concat → Reshape` patterns
- Dynamic to Static: Converts shapes ONNX marks as dynamic to concrete static values
- Explicit Shape Propagation: Distinguishes shape tensors from data tensors
- 51 Operators: Comprehensive coverage including Sign, Conv1d, and DivCst
- Fast Performance: Single-pass forward propagation with O(1) work per operator
- Pure Python: No C/C++ dependencies, easy integration
- Production Ready: Tested on 138 VNN-COMP 2024 models with ONNX consistency validation
ShapeONNX is essential for:
- Neural Network Verification: Tools requiring static shapes (α,β-CROWN, ERAN, Marabou)
- Model Optimization: Pre-optimization shape resolution (SlimONNX)
- Shape-Dependent Transformations: Operations requiring known tensor dimensions
- Complex Model Analysis: Understanding shape propagation in non-standard models
- Python 3.11 or higher
- onnx 1.16.0
- numpy 1.26.4
Important: ShapeONNX is not available on PyPI. Local installation from source is required.
ShapeONNX must be installed locally from the source repository:
```shell
# Clone the repository
git clone <repository-url>
cd shapeonnx

# Install dependencies
pip install onnx==1.16.0 numpy==1.26.4

# Install ShapeONNX in editable mode for development
pip install -e .

# Optional: Install development tools
pip install -e ".[dev]"  # Includes ruff, mypy, pytest, pytest-cov
```

- ONNX 1.16.0: Tested opset range 17-21
- NumPy 1.26.4: Required for Python 3.11+ compatibility
- Python 3.11+: Required for modern type hint syntax (using `|` for unions)
Models should be converted to opset version 21 using `onnx.version_converter` for maximum compatibility.
```python
import onnx
from shapeonnx import infer_onnx_shape
from shapeonnx.utils import (
    get_initializers,
    get_input_nodes,
    get_output_nodes,
    convert_constant_to_initializer,
)

# Load and prepare model
model = onnx.load("model.onnx")
model = onnx.version_converter.convert_version(model, target_version=21)

# Extract model components
initializers = get_initializers(model)
input_nodes = get_input_nodes(model, initializers, has_batch_dim=True)
output_nodes = get_output_nodes(model, has_batch_dim=True)

# Convert Constant nodes to initializers (required preprocessing)
nodes = convert_constant_to_initializer(list(model.graph.node), initializers)

# Infer shapes
shapes = infer_onnx_shape(
    input_nodes,
    output_nodes,
    nodes,
    initializers,
    has_batch_dim=True,
    verbose=False,
)

# Access inferred shapes
for tensor_name, shape in shapes.items():
    print(f"{tensor_name}: {shape}")
```

Main shape inference function.
```python
def infer_onnx_shape(
    input_nodes: list[ValueInfoProto],
    output_nodes: list[ValueInfoProto],
    nodes: list[NodeProto],
    initializers: dict[str, TensorProto],
    has_batch_dim: bool = True,
    verbose: bool = False,
) -> dict[str, list[int]]
```

Parameters:
- `input_nodes` (list[ValueInfoProto]): Model input value infos
- `output_nodes` (list[ValueInfoProto]): Model output value infos
- `nodes` (list[NodeProto]): Model computation nodes (Constant nodes must be converted to initializers)
- `initializers` (dict[str, TensorProto]): Model initializers (weights and constants)
- `has_batch_dim` (bool): Whether the model has a batch dimension (default: True)
- `verbose` (bool): Print debug information during inference (default: False)

Returns: `dict[str, list[int]]` - Dictionary mapping tensor names to inferred shapes

Note: Constant nodes must be converted to initializers before calling this function, using `convert_constant_to_initializer()`.
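Conceptually, that preprocessing moves each Constant node's value into the initializer dictionary and drops the node from the graph. A toy sketch of the idea using plain dicts instead of `NodeProto` objects (illustrative only, not the real implementation):

```python
def fold_constants(nodes: list[dict], initializers: dict) -> list[dict]:
    """Toy model of Constant-to-initializer conversion on dict-based 'nodes'."""
    kept = []
    for node in nodes:
        if node["op_type"] == "Constant":
            # Move the constant's value into the initializer table
            initializers[node["output"]] = node["value"]
        else:
            kept.append(node)  # all non-Constant nodes survive
    return kept

nodes = [
    {"op_type": "Constant", "output": "c0", "value": [1, 48, -1]},
    {"op_type": "Reshape", "output": "y"},
]
inits: dict = {}
remaining = fold_constants(nodes, inits)
print(inits)                                   # {'c0': [1, 48, -1]}
print([n["op_type"] for n in remaining])       # ['Reshape']
```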
Extract shapes from model input/output nodes.
```python
def extract_io_shapes(
    nodes: list[ValueInfoProto],
    has_batch_dim: bool,
) -> dict[str, list[int]]
```

Parameters:
- `nodes` (list[ValueInfoProto]): Input or output value infos
- `has_batch_dim` (bool): Whether tensors have a batch dimension

Returns: `dict[str, list[int]]` - Dictionary mapping names to shapes
Convert Constant nodes to initializers (required preprocessing step).
```python
def convert_constant_to_initializer(
    nodes: list[NodeProto],
    initializers: dict[str, TensorProto],
) -> list[NodeProto]
```

Parameters:
- `nodes` (list[NodeProto]): Model nodes
- `initializers` (dict[str, TensorProto]): Initializer dictionary (modified in place)

Returns: `list[NodeProto]` - Nodes with Constant nodes removed
Extract initializers from model.
```python
def get_initializers(model: ModelProto) -> dict[str, TensorProto]
```

Extract input nodes with proper shape formatting.
```python
def get_input_nodes(
    model: ModelProto,
    initializers: dict[str, TensorProto],
    has_batch_dim: bool,
) -> list[ValueInfoProto]
```

Extract output nodes with proper shape formatting.
```python
def get_output_nodes(
    model: ModelProto,
    has_batch_dim: bool,
) -> list[ValueInfoProto]
```

ShapeONNX supports 51 operators across 10 categories:
- Add, Sub, Mul, Div, DivCst, Pow, Neg
- Relu, LeakyRelu, Sigmoid, Tanh, Clip, Sin, Cos, Sign
- Conv, Conv1d, ConvTranspose, MaxPool, AveragePool, GlobalAveragePool
- BatchNormalization
- Reshape, Transpose, Squeeze, Unsqueeze, Flatten, Expand
- Slice, Split, Gather, Concat
- Shape, ConstantOfShape, Range
- ReduceMean, ReduceSum, ArgMax
- Equal, Where, Max, Min
- MatMul, Gemm
- Cast, Dropout, Pad, Resize, Scatter, ScatterElements, ScatterND, Softmax, Floor
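For elementwise operators such as Add and Mul, ONNX output shapes follow NumPy-style broadcasting. A sketch of that rule (the helper `broadcast_shapes` is illustrative, not ShapeONNX's code):

```python
from itertools import zip_longest

def broadcast_shapes(a: list[int], b: list[int]) -> list[int]:
    """NumPy-style broadcasting: align shapes from the right, pad with 1s."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"incompatible dims {x} and {y}")
        out.append(max(x, y))
    return out[::-1]

print(broadcast_shapes([1, 48, 2, 2], [48, 1, 1]))  # [1, 48, 2, 2]
```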
- Immutable Context: Frozen dataclass for shape inference context
- Pure Functions: All shape inference functions are stateless with explicit inputs
- Direct Dictionary Access: Minimal abstraction for performance
- Full Type Hints: Complete type annotations using Python 3.11+ syntax
- Single-Pass Forward Propagation: O(1) complexity per operator
- Pre-Converted Initializers: Integer tensors converted once at initialization
- Efficient Operator Dispatch: Dictionary-based operator function mapping
- Minimal Memory Allocations: Shape lists reused where possible
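Dictionary-based operator dispatch can be sketched as follows (the handler names and `OP_TABLE` are illustrative, not ShapeONNX's internals):

```python
import math

def infer_add(input_shapes: list[list[int]]) -> list[list[int]]:
    # Elementwise op: same shape as the first input (broadcasting omitted here)
    return [list(input_shapes[0])]

def infer_flatten(input_shapes: list[list[int]]) -> list[list[int]]:
    # Flatten with axis=1: keep dim 0, collapse the rest
    s = input_shapes[0]
    return [[s[0], math.prod(s[1:])]]

# One dict lookup per node: O(1) dispatch to the right shape function
OP_TABLE = {"Add": infer_add, "Flatten": infer_flatten}

def dispatch(op_type: str, input_shapes: list[list[int]]) -> list[list[int]]:
    return OP_TABLE[op_type](input_shapes)

print(dispatch("Flatten", [[1, 48, 2, 2]]))  # [[1, 192]]
```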
Benchmark: 140 VNN-COMP 2024 models processed in approximately 6.5 seconds on Intel i5-12400F.
```
shapeonnx/
├── __init__.py      # Public API exports
├── infer_shape.py   # Main shape inference engine and ShapeInferenceContext
├── onnx_attrs.py    # ONNX attribute extraction utilities
└── utils.py         # Helper functions (get_initializers, input/output extraction, etc.)
```
```python
import onnx
from shapeonnx import infer_onnx_shape
from shapeonnx.utils import (
    get_initializers,
    get_input_nodes,
    get_output_nodes,
    convert_constant_to_initializer,
)

# Load model
model = onnx.load("resnet18.onnx")

# Prepare components
initializers = get_initializers(model)
input_nodes = get_input_nodes(model, initializers, has_batch_dim=True)
output_nodes = get_output_nodes(model, has_batch_dim=True)
nodes = convert_constant_to_initializer(list(model.graph.node), initializers)

# Infer shapes
shapes = infer_onnx_shape(
    input_nodes, output_nodes, nodes, initializers,
    has_batch_dim=True, verbose=True,
)

# Print all tensor shapes
for name, shape in sorted(shapes.items()):
    print(f"{name}: {shape}")
```

```python
import onnx
from shapeonnx import infer_onnx_shape
from shapeonnx.utils import (
    get_initializers,
    get_input_nodes,
    get_output_nodes,
    convert_constant_to_initializer,
)

# Load and prepare model
model = onnx.load("model.onnx")
model = onnx.version_converter.convert_version(model, target_version=21)

initializers = get_initializers(model)
input_nodes = get_input_nodes(model, initializers, has_batch_dim=True)
output_nodes = get_output_nodes(model, has_batch_dim=True)
nodes = convert_constant_to_initializer(list(model.graph.node), initializers)

# Infer shapes for optimization
shapes = infer_onnx_shape(
    input_nodes, output_nodes, nodes, initializers,
    has_batch_dim=True,
)

# Use shapes for optimization decisions
for node in nodes:
    for input_name in node.input:
        if input_name in shapes:
            input_shape = shapes[input_name]
            # Make optimization decisions based on shape
            if len(input_shape) == 2:
                # Can apply matrix-specific optimizations
                pass
```

```python
import onnx
from shapeonnx import infer_onnx_shape
from shapeonnx.utils import (
    get_initializers,
    get_input_nodes,
    get_output_nodes,
    convert_constant_to_initializer,
)

# Model with Shape → Gather → Add → Reshape pattern
model = onnx.load("dynamic_reshape_model.onnx")
initializers = get_initializers(model)
input_nodes = get_input_nodes(model, initializers, has_batch_dim=True)
output_nodes = get_output_nodes(model, has_batch_dim=True)
nodes = convert_constant_to_initializer(list(model.graph.node), initializers)

# ShapeONNX resolves shape chains to static values
shapes = infer_onnx_shape(
    input_nodes, output_nodes, nodes, initializers,
    has_batch_dim=True, verbose=True,
)

# Dynamic reshape operations now have static target shapes
print("Resolved static shapes for all tensors")
```

ShapeONNX has been extensively tested on models from VNN-COMP 2024:
- Total Models Tested: 138 diverse neural networks
- Success Rate: 100% (all models successfully processed)
- Model Types: CNNs, ResNets, VGG, Vision Transformers, GANs, Graph Neural Networks
- Opset Coverage: Opset 17-21
- ONNX Consistency: Validated against ONNX reference with special handling for shape tensors
The comprehensive pytest-based test suite includes:
- Shape Inference Tests (`test_shapeonnx.py`): Validates shape inference on all 138 models
- Baseline Tests (`test_shapeonnx_regression.py`):
  - Creates and verifies baselines for regression detection
  - Compares with the ONNX reference implementation
  - Handles shape tensor differences (ShapeONNX tracks values, ONNX tracks metadata)
```shell
cd shapeonnx
pytest tests/test_shapeonnx.py -v             # Run shape inference tests
pytest tests/test_shapeonnx_regression.py -v  # Run regression tests with ONNX comparison
```

Expected output: 414 passed (138 models × 3 test types)
The test suite demonstrates ShapeONNX's superior capabilities:
- Shape Tensor Tracking: Correctly infers `[1, 48, 2, 2]` for Shape operations where ONNX only knows `[4]`
- Static Resolution: Resolves `ConstantOfShape` outputs to `[1, 1, 48]` where ONNX shows `[-1, -1, -1]`
- Operator Chains: Resolves `Shape → Slice → Concat → Reshape` chains to concrete target shapes
- Dynamic to Static: Converts batch dimensions and other dynamic shapes to concrete values when statically determinable
Hardware: Intel i5-12400F (6 cores, 12 threads)
Results:
- 138 VNN-COMP 2024 models: ~9.6 seconds total (shape inference)
- Average per model: ~70 milliseconds
- Complex models (Vision Transformers, GANs): <200ms
- Simple models (ACAS Xu, TLL): <10ms
- Full test suite (590 tests): ~1.5 seconds
Memory: Typical peak memory usage under 500MB for largest models.
Note: ShapeONNX's comprehensive shape tensor tracking adds minimal overhead while providing significantly more accurate shape information than ONNX's built-in inference.
- Constant nodes must be converted to initializers before shape inference
- Asymmetric padding in Conv/Pool operations not supported
- Control flow operators (If, Loop, Scan) not supported
- Some operators have limited attribute support
- Assumes static input shapes (a dynamic batch size is handled via the `has_batch_dim` flag)
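One plausible way a `has_batch_dim` flag interacts with a dynamic leading dimension can be sketched as follows. This is assumed behavior for illustration (the helper `normalize_io_shape` is hypothetical, not ShapeONNX's API):

```python
def normalize_io_shape(dims: list, has_batch_dim: bool) -> list[int]:
    """Map symbolic/unknown dims (None) to -1; treat a dynamic leading
    dim as a batch dimension of size 1 when has_batch_dim is set."""
    shape = [-1 if d is None else int(d) for d in dims]
    if has_batch_dim and shape and shape[0] == -1:
        shape[0] = 1  # dynamic batch resolved to a concrete size of 1
    return shape

print(normalize_io_shape([None, 3, 32, 32], has_batch_dim=True))   # [1, 3, 32, 32]
print(normalize_io_shape([None, 3, 32, 32], has_batch_dim=False))  # [-1, 3, 32, 32]
```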
- SlimONNX: ONNX model optimization. Uses ShapeONNX for shape-dependent optimizations like constant folding and redundant operation removal.
- TorchVNNLIB: VNN-LIB to tensor converter for neural network verification.
- VNN-COMP: International Verification of Neural Networks Competition.
See CONTRIBUTING.md for development setup, testing procedures, code quality standards, pull request guidelines, and instructions for adding new operators.
MIT License. See LICENSE file for details.