4. API Guide
The model_explorer package provides the following APIs to visualize models and to create and visualize custom node data with Model Explorer from Python code. Make sure to install it first by following the installation guide.
model_explorer provides convenient APIs to quickly visualize models from files or from a PyTorch module, and a lower-level API to visualize models from multiple sources.
Usage:
```python
import model_explorer

model_explorer.visualize('/path/to/model/file')
```

API reference:
visualize( model_paths=[], host='localhost', port=8080, extensions: list[str] = [], node_data: Union[NodeDataInfo, list[NodeDataInfo]] = [], colab_height=850, reuse_server: bool = False, reuse_server_host: str = DEFAULT_HOST, reuse_server_port: Union[int, None] = None)
Starts the Model Explorer local server and visualizes the models by the given paths.
When you've passed multiple models to model_paths, the visualization page will initially show the largest subgraph from the first model. You can easily switch between models and their subgraphs using the model graph selector in the top-right corner.
Args:
- model_paths: str|list[str]: A model path or a list of model paths to visualize.
- host: str: The host of the server. Defaults to localhost.
- port: int: The port of the server. Defaults to 8080.
- extensions: list[str]: A list of extension names to run with Model Explorer.
- node_data: list[NodeDataInfo]|NodeDataInfo: The node data, or a list of node data, to display. Example: node_data={'name': 'my node data', 'node_data': node_data_json_str}.
- colab_height: int: The height of the embedded iframe when running in Colab. Defaults to 850.
- reuse_server: bool: Whether to reuse the current server and browser tab(s) for the visualization.
- reuse_server_host: str: The host of the server to reuse. Defaults to localhost.
- reuse_server_port: int: The port of the server to reuse. If unspecified, it tries to find a running server on ports 8080 through 8099.
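The reuse_server_port description implies a scan over ports 8080 to 8099. The sketch below illustrates that idea; it is not Model Explorer's actual implementation. The hypothetical helper find_running_server returns the first port in a range that accepts a TCP connection. It only checks that something is listening and does not verify that the listener is a Model Explorer server.

```python
import socket
from typing import Optional


def find_running_server(host: str = "localhost",
                        first_port: int = 8080,
                        last_port: int = 8099,
                        timeout: float = 0.2) -> Optional[int]:
    """Returns the first port in [first_port, last_port] that accepts TCP connections.

    Hypothetical helper: a successful connection only shows that *something*
    is listening; it does not check that the listener is Model Explorer.
    """
    for port in range(first_port, last_port + 1):
        try:
            # create_connection raises OSError (e.g. ConnectionRefusedError)
            # when nothing is listening on the port.
            with socket.create_connection((host, port), timeout=timeout):
                return port
        except OSError:
            continue
    return None
```

Called with no arguments, find_running_server() scans localhost on ports 8080 to 8099. It returns None when nothing in that range is listening.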
Visualizing PyTorch models requires a slightly different approach due to their lack of a standard serialization format. Model Explorer offers a specialized API to visualize PyTorch models directly, using the ExportedProgram from torch.export.export.
Usage:
```python
import model_explorer
import torch
import torchvision

# Prepare a PyTorch model and its inputs.
model = torchvision.models.mobilenet_v2().eval()
inputs = (torch.rand([1, 3, 224, 224]),)
ep = torch.export.export(model, inputs)

# Visualize.
model_explorer.visualize_pytorch('mobilenet', exported_program=ep)
```

API reference:
visualize_pytorch(
name,
exported_program,
settings={'const_element_count_limit': 16},
host='localhost', port=8080, colab_height=850)
Starts the Model Explorer local server and visualizes the given PyTorch ExportedProgram.
Args:
- name: str: The name of the model, for display purposes.
- exported_program: torch.export.ExportedProgram: The ExportedProgram returned by torch.export.export.
- settings: Dict: Key-value pairs of settings. Currently the only supported key is const_element_count_limit, which controls how many values the adapter returns for a constant.
- See the visualize function above for host, port, extensions, node_data, and colab_height.
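The const_element_count_limit setting caps how many values of a constant the adapter reports. The sketch below illustrates that kind of cap on a nested list. The helpers preview_const and flatten are hypothetical. Keeping the first N elements in row-major order is an assumption, not a description of the adapter's internals.

```python
from itertools import islice
from typing import Any, Iterator, List


def flatten(values: Any) -> Iterator[Any]:
    """Yields the scalars of an arbitrarily nested list/tuple in row-major order."""
    if isinstance(values, (list, tuple)):
        for item in values:
            yield from flatten(item)
    else:
        yield values


def preview_const(values: Any, const_element_count_limit: int = 16) -> List[Any]:
    """Returns at most `const_element_count_limit` elements of a constant.

    Assumption: the first N elements in row-major order are the ones kept.
    """
    return list(islice(flatten(values), const_element_count_limit))


weights = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
print(preview_const(weights, const_element_count_limit=4))  # [0.1, 0.2, 0.3, 0.4]
```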
Sometimes you want to load models from files and a PyTorch model into Model Explorer at the same time. To do this, use the following lower-level APIs. The basic steps are:
- Create a config object and add models to it.
- Pass it to the visualize_from_config API.
Usage:
```python
import model_explorer
import torch
import torchvision

# Prepare a PyTorch model and its inputs.
model = torchvision.models.mobilenet_v2().eval()
inputs = (torch.rand([1, 3, 224, 224]),)
# add_model_from_pytorch takes an ExportedProgram, so export the model first.
ep = torch.export.export(model, inputs)

# Create a Model Explorer config.
config = model_explorer.config()

# Add a model file path and the exported PyTorch model to the config.
(config
 .add_model_from_path('/path/to/model')
 .add_model_from_pytorch('mobilenet', exported_program=ep))

# Visualize with the config.
model_explorer.visualize_from_config(config=config)
```

API reference:
- config() -> ModelExplorerConfig
  Creates a new Model Explorer config object.
- ModelExplorerConfig.add_model_from_path(self, path) -> ModelExplorerConfig
  Adds a model path to the config.
  Args:
  - path: str: The model file path to add.
- ModelExplorerConfig.add_model_from_pytorch(self, name, exported_program, settings={'const_element_count_limit': 16}) -> ModelExplorerConfig
  Adds a PyTorch model, given as a torch.export.ExportedProgram, to the config. After this method is called, Model Explorer invokes its internal adapter to convert the ExportedProgram into Model Explorer graphs. This can take some time, depending on the complexity of the model.
  Args:
  - name: str: The name of the model, for display purposes.
  - See visualize_pytorch above for the exported_program and settings parameters.
- ModelExplorerConfig.set_reuse_server(self, server_host: str = 'localhost', server_port: Union[int, None] = None) -> ModelExplorerConfig
  Reuses an existing server instead of starting a new one.
  Args:
  - server_host: str: The host of the server to reuse.
  - server_port: int|None: The port of the server to reuse. If unspecified, it tries to find a running server on ports 8080 through 8099.
- visualize_from_config(config=None, host='localhost', port=8080, colab_height=850)
  Starts the visualization from the given config.
  Args:
  - config: ModelExplorerConfig|None: The object that stores the models to visualize.
  - See the visualize function above for host, port, and colab_height.
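add_model_from_path, add_model_from_pytorch, and set_reuse_server each return the ModelExplorerConfig, so calls can be chained into a single expression. The toy class below shows the pattern: each method updates the object and then returns self. ToyConfig is a hypothetical stand-in, not part of model_explorer, and it performs no conversion or visualization.

```python
from typing import List, Optional, Tuple


class ToyConfig:
    """Hypothetical stand-in for ModelExplorerConfig; not part of model_explorer."""

    def __init__(self) -> None:
        self.model_sources: List[Tuple[str, str]] = []
        self.reuse_server: Optional[Tuple[str, Optional[int]]] = None

    def add_model_from_path(self, path: str) -> "ToyConfig":
        self.model_sources.append(("path", path))
        return self  # returning self is what enables chaining

    def add_model_from_pytorch(self, name: str, exported_program: object) -> "ToyConfig":
        # The real method converts the program into graphs; the toy only records the name.
        self.model_sources.append(("pytorch", name))
        return self

    def set_reuse_server(self, server_host: str = "localhost",
                         server_port: Optional[int] = None) -> "ToyConfig":
        self.reuse_server = (server_host, server_port)
        return self


config = (ToyConfig()
          .add_model_from_path("/path/to/model")
          .add_model_from_pytorch("mobilenet", exported_program=None)
          .set_reuse_server())
print(config.model_sources)  # [('path', '/path/to/model'), ('pytorch', 'mobilenet')]
```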
model_explorer provides APIs to create custom node data and visualize it in a model graph. For more info about how custom node data works, see the user guide.
We provide a set of data classes to help you build custom node data. The comments in node_data_builder.py serve as the official documentation.
From a high level, the custom node data has the following structure:
- ModelNodeData: The top-level container that stores all the data for a model. It consists of one or more GraphNodeData objects indexed by graph IDs.
- GraphNodeData: Holds the data for a specific graph within the model. It includes:
  - results: Stores the custom node values, indexed by either node IDs or output tensor names.
  - thresholds or gradient: Color configurations that associate each node value with a node background color or label color, giving the data a visual representation.
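With a gradient, the minimum value in results maps to the color at stop 0 and the maximum maps to the color at stop 1. Values in between get an interpolated color. The sketch below reproduces that arithmetic in plain Python under three assumptions:

- Colors are RGB tuples rather than the CSS color names that ndb.GradientItem accepts.
- Interpolation is linear per channel.
- A series where every value is equal maps to stop 0.

Model Explorer's renderer may differ in these details. The names colorize, value_to_stop, and interpolate_color are hypothetical.

```python
from typing import Dict, List, Tuple

RGB = Tuple[int, int, int]


def value_to_stop(value: float, min_value: float, max_value: float) -> float:
    """Normalizes a value to [0, 1]: the minimum maps to 0, the maximum to 1."""
    if max_value == min_value:
        return 0.0  # assumption: if all values are equal, use stop 0
    return (value - min_value) / (max_value - min_value)


def interpolate_color(stop: float, gradient: List[Tuple[float, RGB]]) -> RGB:
    """Linearly interpolates between the two gradient items that bracket `stop`.

    `gradient` is a non-empty list of (stop, rgb) pairs sorted by stop.
    """
    if stop <= gradient[0][0]:
        return gradient[0][1]
    for (s0, c0), (s1, c1) in zip(gradient, gradient[1:]):
        if stop <= s1:
            t = (stop - s0) / (s1 - s0) if s1 > s0 else 0.0
            return tuple(round(a + (b - a) * t) for a, b in zip(c0, c1))
    return gradient[-1][1]


def colorize(results: Dict[str, float],
             gradient: List[Tuple[float, RGB]]) -> Dict[str, RGB]:
    """Maps each node value to a background color after min/max normalization."""
    lo, hi = min(results.values()), max(results.values())
    return {key: interpolate_color(value_to_stop(v, lo, hi), gradient)
            for key, v in results.items()}


YELLOW, RED = (255, 255, 0), (255, 0, 0)
colors = colorize({'node_id1': 100, 'node_id2': 200, 'any/output/tensor/name/': 300},
                  [(0.0, YELLOW), (1.0, RED)])
print(colors)
```

Here the middle value (200) lands at stop 0.5, halfway between yellow and red, which gives (255, 128, 0).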
Usage:
```python
from model_explorer import node_data_builder as ndb

# Populate values for the main graph in a model.
main_graph_results: dict[str, ndb.NodeDataResult] = {}
main_graph_results['node_id1'] = ndb.NodeDataResult(value=100)
main_graph_results['node_id2'] = ndb.NodeDataResult(value=200)
main_graph_results['any/output/tensor/name/'] = ndb.NodeDataResult(value=300)

# Create a gradient color mapping.
#
# The minimum value in `main_graph_results` maps to the color with stop=0.
# The maximum value in `main_graph_results` maps to the color with stop=1.
# Other values map to an interpolated color in between.
gradient: list[ndb.GradientItem] = [
    ndb.GradientItem(stop=0, bgColor='yellow'),
    ndb.GradientItem(stop=1, bgColor='red'),
]

# Construct the data for the main graph.
main_graph_data = ndb.GraphNodeData(
    results=main_graph_results, gradient=gradient)

# Construct the data for the model.
model_data = ndb.ModelNodeData(graphsData={'main': main_graph_data})

# You can save the data to a JSON file.
model_data.save_to_file('path/to/file.json')
```

You can visualize custom node data from data class objects (see the previous section), from a JSON string, or from a JSON file (see save_to_file above). The basic steps are:
- Create a config object and add various custom node data sources to it.
- Pass it to the visualize_from_config API.
Usage:
```python
import model_explorer
from model_explorer import node_data_builder as ndb

# Create a `ModelNodeData` as shown in the previous section.
model_node_data = ...
# A JSON string with the same content as a file written by `ModelNodeData.save_to_file`.
model_node_data_json_str = ...

# Create a config.
config = model_explorer.config()

# Add a model and custom node data to it.
(config
 .add_model_from_path('/path/to/a/model')
 # Add node data from a JSON file.
 # A node data JSON file can be generated by calling `ModelNodeData.save_to_file`.
 .add_node_data_from_path('/path/to/node_data.json')
 # Add node data from a data class object.
 .add_node_data('my data', model_node_data)
 # Add node data from a JSON string (the content written by `ModelNodeData.save_to_file`).
 .add_node_data('my data 2', model_node_data_json_str))

# Visualize.
model_explorer.visualize_from_config(config)
```
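add_node_data_from_path reads a JSON file, while add_node_data also accepts the same content as a JSON string. The hypothetical helper below sketches how a single entry point could accept either form. It treats the input as a path when such a file exists and otherwise parses it as JSON. The helper is not part of model_explorer, and the payload is arbitrary JSON rather than the real node data schema.

```python
import json
import os
import tempfile
from typing import Any


def load_node_data_source(source: str) -> Any:
    """Parses node data given either a JSON file path or a JSON string.

    Hypothetical helper (not part of model_explorer): `source` is treated as a
    path when a file with that name exists, otherwise as a JSON string.
    """
    if os.path.isfile(source):
        with open(source, "r", encoding="utf-8") as f:
            return json.load(f)
    return json.loads(source)


# The same (arbitrary) payload loaded from a JSON string and from a file.
payload = {"example": [1, 2, 3]}
from_string = load_node_data_source(json.dumps(payload))
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "node_data.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f)
    from_file = load_node_data_source(path)
print(from_string == from_file)  # True
```

Guessing from the input alone is ambiguous: a string that happens to name an existing file is read as a path. Separate entry points, such as add_node_data_from_path alongside add_node_data, avoid that ambiguity.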