[WIP/NO_MERGE] Prototype RegularizedShortcut #4549
datumbox wants to merge 10 commits into pytorch:main
Conversation
jamesr66a
left a comment
I think overall this looks OK. If I understand correctly, the procedure is:
- Iterate through the named modules in the module hierarchy, and for each module that's part of the `block_types` of interest:
  a. Add the shortcut module
  b. Trace the module and search for a residual connection (i.e. an `add` node with two inputs, one of which is a placeholder)
  c. Replace the residual connection with the shortcut module
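In code, the detection part of the procedure above might look roughly like this. This is a minimal sketch using `torch.fx`, not the PR's actual code; `Block` and `find_residual_adds` are illustrative names I made up:

```python
import operator

import torch
import torch.fx
from torch import nn


class Block(nn.Module):
    """Toy residual block (illustrative stand-in for e.g. a ResNet BasicBlock)."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)

    def forward(self, x):
        return self.conv(x) + x


def find_residual_adds(module):
    """Return the add nodes fed directly by a placeholder, i.e. residual connections."""
    graph = torch.fx.symbolic_trace(module).graph
    placeholders = {n for n in graph.nodes if n.op == "placeholder"}
    return [
        n
        for n in graph.nodes
        if n.op == "call_function"
        and n.target is operator.add
        and any(a in placeholders for a in n.args)
    ]


print(len(find_residual_adds(Block())))
```

Tracing `self.conv(x) + x` produces a `call_function` node with target `operator.add` whose second argument is the input placeholder, so the helper finds exactly one residual connection here.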
datumbox
left a comment
> I think overall this looks OK. If I understand correctly, the procedure is...
@jamesr66a Thanks a lot for reviewing. Your description of the approach is correct.
I was worried that looping through `named_modules`, independently tracing the graphs of the submodules, and then overwriting the original modules would be problematic. Just to be safe, I highlight below the bits that concerned me. If you have any thoughts on how to improve it, I'm happy to adopt them.
torchvision/prototype/ops/_utils.py
Outdated

```python
if isinstance(m, block_types):
    # Add the Layer directly on the submodule prior to tracing
    # (workaround for https://github.com/pytorch/pytorch/issues/66197)
    m.add_module(_MODULE_NAME, RegularizedShortcut(regularizer_layer))
```
Ideally I wanted to create the Layer on the fly and attach it directly to the graph, but pytorch/pytorch#66197 prevents me from doing this. Here I attach it to the module just before tracing as a workaround. Any concerns?
It will be removed once the issue is fixed.
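For context, here is a rough sketch of what a `RegularizedShortcut` wrapper of this shape could look like. This is my own reconstruction from the discussion, not the PR's actual implementation:

```python
import torch
from torch import nn


class RegularizedShortcut(nn.Module):
    """Sketch: applies a regularizer (e.g. StochasticDepth) to the residual
    branch before adding it back onto the shortcut/identity input."""

    def __init__(self, regularizer_layer):
        super().__init__()
        self.regularizer = regularizer_layer()

    def forward(self, input, result):
        # `input` is the shortcut value (always passed first, see the diff
        # above); `result` is the output of the residual branch
        return input + self.regularizer(result)


# With nn.Identity as the regularizer the wrapper reduces to a plain add
shortcut = RegularizedShortcut(nn.Identity)
out = shortcut(torch.ones(2), torch.full((2,), 2.0))
```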
torchvision/prototype/ops/_utils.py
Outdated

```python
with graph.inserting_after(node):
    # Always put the shortcut value first
    args = node.args if node.args[0] == input else node.args[::-1]
    node.replace_all_uses_with(graph.call_module(_MODULE_NAME, args))
```
Calling the previously created module by name. Hopefully this will be replaced with something like the following:
```python
fn_impl_traced = torch.fx.symbolic_trace(RegularizedShortcut(regularizer_layer))
args = node.args if node.args[0] == input else node.args[::-1]
fn_impl_output_node = fn_impl_traced(*map_arg(args, Proxy))
node.replace_all_uses_with(fn_impl_output_node.node)
```

```python
for node in graph.nodes:
    # The isinstance() won't work if the model has already been traced before because it loses
    # the class info of submodules. See https://github.com/pytorch/pytorch/issues/66335
    if node.op == "call_module" and isinstance(model.get_submodule(node.target), block_types):
```
@jamesr66a We just figured out that FX-traced models lose the class information of their submodules. This means that for a model that has been traced before, we can't use `isinstance()` to identify its block type. Is this intentional or a bug?
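A minimal repro of what we're seeing (the `Block`/`Model` names are illustrative):

```python
import torch
import torch.fx
from torch import nn


class Block(nn.Module):
    """Stand-in for a BasicBlock-style residual module."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)

    def forward(self, x):
        return self.conv(x) + x


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()

    def forward(self, x):
        return self.block(x)


model = Model()
traced = torch.fx.symbolic_trace(model)

# The original hierarchy keeps the class of the submodule...
print(isinstance(model.get_submodule("block"), Block))
# ...but tracing inlines Block and rebuilds "block" as a plain container
# module, so the same isinstance() check fails on the traced model
print(isinstance(traced.get_submodule("block"), Block))
```

Tracing inlines non-leaf modules into the graph; the resulting `GraphModule` only reconstructs the attribute paths its `call_module` nodes reference (e.g. `block.conv`), creating plain `nn.Module` containers along the way, which is where the class information is lost.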
This is an early prototype utility based on FX.
The target is to detect residual connections in arbitrary model architectures and modify the network to add regularization blocks (such as `StochasticDepth`).

Example usage:
Output:
Before
After addition
After deletion
Also tested with:
Affected by pytorch/pytorch#66197 and pytorch/pytorch#66335