🐛 Describe the bug
Custom passes are defined as either a CustomGraphPassType or an optional Callable, so at the moment only one pass can be registered per custom-pass option:
post_grad_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
post_grad_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
joint_custom_pre_pass: Optional[Callable[[torch.fx.Graph], None]] = None
joint_custom_post_pass: Optional[Callable[[torch.fx.Graph], None]] = None
pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
For example, if joint_custom_pre_pass is assigned twice, the second assignment silently overwrites the first, with no warning.
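The overwrite can be reproduced with a minimal sketch. To keep it self-contained, it assumes a hypothetical SimpleNamespace stand-in for the Inductor config and a dummy object in place of torch.fx.Graph. pass_a and pass_b are illustrative names only.

```python
from types import SimpleNamespace
from typing import Callable, List, Optional

# Hypothetical stand-in for the config module: a single-slot option,
# mirroring Optional[Callable[[torch.fx.Graph], None]].
config = SimpleNamespace(joint_custom_pre_pass=None)

calls: List[str] = []

def pass_a(graph: object) -> None:
    calls.append("pass_a")

def pass_b(graph: object) -> None:
    calls.append("pass_b")

# Two independent components each register "their" pass.
config.joint_custom_pre_pass = pass_a
config.joint_custom_pre_pass = pass_b  # silently replaces pass_a

# Only whatever is currently stored in the slot can be invoked.
hook: Optional[Callable[[object], None]] = config.joint_custom_pre_pass
if hook is not None:
    hook(object())  # dummy object stands in for torch.fx.Graph

print(calls)  # ['pass_b']: pass_a never runs
```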
Hence, it would be better to extend each custom-pass option to accept a list of passes.
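One possible shape for the proposal is sketched below. It is not an existing PyTorch API: run_custom_passes, GraphPass and PassOption are hypothetical names, and a dummy object again stands in for torch.fx.Graph. The option accepts None, a single callable (the current behavior), or a list of callables, and runs every registered pass in registration order.

```python
from typing import Callable, List, Sequence, Union

# Hypothetical types: a pass takes a graph and mutates it in place.
GraphPass = Callable[[object], None]
PassOption = Union[None, GraphPass, Sequence[GraphPass]]

def run_custom_passes(option: PassOption, graph: object) -> None:
    """Run every registered pass on `graph`, in registration order."""
    if option is None:
        return
    # A single callable keeps today's one-pass behavior working unchanged.
    passes: Sequence[GraphPass] = [option] if callable(option) else option
    for graph_pass in passes:
        graph_pass(graph)

calls: List[str] = []

def pass_a(graph: object) -> None:
    calls.append("pass_a")

def pass_b(graph: object) -> None:
    calls.append("pass_b")

# With a list, a second registration appends instead of overwriting.
joint_custom_pre_pass: List[GraphPass] = []
joint_custom_pre_pass.append(pass_a)
joint_custom_pre_pass.append(pass_b)

run_custom_passes(joint_custom_pre_pass, object())
print(calls)  # ['pass_a', 'pass_b']
```

Still accepting a single callable is a design choice meant to keep existing assignments of one pass working while allowing several components to register passes side by side.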
Versions
PyTorch: main branch
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov