
[quant][fx][bc-breaking] Add required example_inputs argument to prepare_fx and prepare_qat_fx#76496

Closed
jerryzh168 wants to merge 12 commits into gh/jerryzh168/790/base from gh/jerryzh168/790/head

Conversation

@jerryzh168
Contributor

@jerryzh168 jerryzh168 commented Apr 27, 2022

Stack from ghstack (oldest at bottom):

Summary:
FX Graph Mode Quantization needs to know whether an fx node is a floating point Tensor before it can decide whether to
insert an observer/fake_quantize module, since we only insert observer/fake_quantize modules for floating point Tensors.
Currently we support this with some hacks, such as the NON_OBSERVABLE_ARG_DICT rules (https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/utils.py#L496), but this approach is fragile and we do not plan to maintain it long term in the PyTorch code base.

As we discussed in the design review, we need to ask users to provide example args and example keyword args
so that we can infer the type in a more robust way. This PR starts by changing the prepare_fx and prepare_qat_fx APIs to require users to provide
example arguments through example_inputs. Note that this API does not support kwargs. Supporting kwargs could make #76496 (comment) simpler, but
that case is rare, and even then positional arguments are a workable substitute. Also, torch.jit.trace (https://pytorch.org/docs/stable/generated/torch.jit.trace.html) and ShapeProp (https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py#L140) take only positional args, so we use a single example_inputs argument for now.

If needed, we can extend the API with an optional example_kwargs, e.g. when there are many arguments to forward and it makes more sense to
pass them by keyword.
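As a rough sketch of what such an extension could look like (purely hypothetical: this PR only adds example_inputs, and prepare_fx_sketch below is not a real PyTorch function):

```python
# Hypothetical signature sketch: how the API could later accept keyword
# examples alongside positional ones. Not the real prepare_fx signature.
def prepare_fx_sketch(model, qconfig_dict, example_inputs, example_kwargs=None):
    example_kwargs = dict(example_kwargs or {})
    # Type inference would run the model once on the concrete examples,
    # e.g. model(*example_inputs, **example_kwargs), and record which
    # values flowing through the graph are floating point Tensors.
    return example_inputs, example_kwargs
```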

BC-breaking Note:
Before:
m = resnet18(...)
m = prepare_fx(m, qconfig_dict)

After:
m = resnet18(...)
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),))
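To illustrate why example inputs help, here is a minimal self-contained sketch (not the actual PyTorch implementation; FakeTensor and the helper names are hypothetical stand-ins for torch.Tensor and its is_floating_point() method) of how concrete example values let the pass decide which inputs should get observers:

```python
# Illustrative sketch only: deciding which positional inputs need an
# observer/fake_quantize module, given concrete example values.
from dataclasses import dataclass

@dataclass
class FakeTensor:
    dtype: str  # e.g. "float32", "int64"; stand-in for torch dtypes

def is_floating_point(value):
    """Only floating point Tensors should receive observers/fake_quant."""
    return isinstance(value, FakeTensor) and value.dtype.startswith("float")

def observer_plan(example_inputs):
    """Return one bool per positional input: insert an observer or not."""
    return tuple(is_floating_point(arg) for arg in example_inputs)

# A float activation gets an observer; an int64 index tensor and a plain
# Python scalar do not.
plan = observer_plan((FakeTensor("float32"), FakeTensor("int64"), 3))
```

With rule tables like NON_OBSERVABLE_ARG_DICT, this decision had to be guessed per-operator; with example_inputs it can be read directly off the values.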

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestQuantizeFxModels

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: D35984526

…_fx and prepare_qat_fx

Summary:
FX Graph Mode Quantization needs to know whether an fx node is a floating point Tensor before it can decide whether to
insert observer/fake_quantize module or not, since we only insert observer/fake_quantize module for floating point Tensors.
Currently we have some hacks to support this by defining some rules like `NON_OBSERVABLE_ARG_DICT` (https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/utils.py#L496), but this approach is fragile and we do not plan to maintain it long term in the pytorch code base.

As we discussed in the design review, we'd need to ask users to provide sample args and sample keyword args
so that we can infer the type in a more robust way. This PR starts by changing the `prepare_fx` and `prepare_qat_fx` APIs to require users to provide
sample positional arguments through `sample_args` or sample keyword arguments through `sample_kwargs`;
`sample_args` and `sample_kwargs` can't both be None.

BC-breaking Note:
Before:
m = resnet18(...)
m = prepare_fx(m, qconfig_dict)

After:
m = resnet18(...)
m = prepare_fx(m, qconfig_dict, sample_args=(torch.randn(1, 3, 224, 224),))

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestQuantizeFxModels

Reviewers:

Subscribers:

Tasks:

Tags:

@facebook-github-bot
Contributor

facebook-github-bot commented Apr 27, 2022


❌ 12 New Failures

As of commit a9b9293 (more details on the Dr. CI page):

  • 12/12 failures introduced in this PR

🕵️ 11 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / linux-bionic-py3.7-clang9 / test (default, 2, 2, linux.2xlarge) (1/11)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-05-11T03:51:21.6493229Z FAIL [1.338s]: test_correct_module_names (__main__.TestPublicBindings)
2022-05-11T03:51:21.5430687Z   test_correct_module_names (__main__.TestPublicBindings)
2022-05-11T03:51:21.5930945Z An API is considered public, if  its  `__module__` starts with `torch.` ... FAIL (0.053s)
2022-05-11T03:51:21.5931419Z     test_correct_module_names failed - num_retries_left: 1
2022-05-11T03:51:21.5957490Z   test_correct_module_names (__main__.TestPublicBindings)
2022-05-11T03:51:21.6453024Z An API is considered public, if  its  `__module__` starts with `torch.` ... FAIL (0.052s)
2022-05-11T03:51:21.6453616Z     test_correct_module_names failed - num_retries_left: 0
2022-05-11T03:51:21.6485742Z   test_no_new_bindings (__main__.TestPublicBindings)
2022-05-11T03:51:21.6491580Z This test aims to stop the introduction of new JIT bindings into torch._C ... ok (0.004s)
2022-05-11T03:51:21.6492634Z 
2022-05-11T03:51:21.6492806Z ======================================================================
2022-05-11T03:51:21.6493229Z FAIL [1.338s]: test_correct_module_names (__main__.TestPublicBindings)
2022-05-11T03:51:21.6493724Z An API is considered public, if  its  `__module__` starts with `torch.`
2022-05-11T03:51:21.6494232Z ----------------------------------------------------------------------
2022-05-11T03:51:21.6494692Z Traceback (most recent call last):
2022-05-11T03:51:21.6495140Z   File "test_public_bindings.py", line 394, in test_correct_module_names
2022-05-11T03:51:21.6495577Z     self.assertTrue(not failure_list, msg)
2022-05-11T03:51:21.6496627Z AssertionError: False is not true : All the APIs below do not meet our guidelines for public API from https://github.com/pytorch/pytorch/wiki/Public-API-definition-and-documentation.
2022-05-11T03:51:21.6497431Z Make sure that everything that is public is expected (in particular that the module has a properly populated `__all__` attribute) and that everything that is supposed to be public does look public (it does not start with `_` and has a `__module__` that is properly populated).
2022-05-11T03:51:21.6497901Z 
2022-05-11T03:51:21.6497977Z Full list:
2022-05-11T03:51:21.6498252Z # torch.ao.quantization.fx.prepare.ExampleInputs:

See GitHub Actions build pull / linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 1, 4, linux.4xlarge.nvidia.gpu) (2/11)

Same test_correct_module_names (TestPublicBindings) failure as in (1/11).

See GitHub Actions build pull / linux-xenial-py3.7-gcc5.4 / test (default, 1, 2, linux.2xlarge) (3/11)

Same test_correct_module_names (TestPublicBindings) failure as in (1/11).

See GitHub Actions build pull / linux-xenial-py3.7-gcc5.4 / test (distributed, 1, 1, linux.2xlarge) (4/11)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-05-11T03:06:52.8894154Z RuntimeError: distributed/fsdp/test_fsdp_mixed_precision failed!
2022-05-11T03:06:52.2069755Z Executing ['/opt/conda/bin/python', 'distributed/fsdp/test_fsdp_mixed_precision.py', '-v', '--import-slow-tests', '--import-disabled-tests'] ... [2022-05-11 03:06:52.206655]
2022-05-11T03:06:52.7769452Z Traceback (most recent call last):
2022-05-11T03:06:52.7770062Z   File "distributed/fsdp/test_fsdp_mixed_precision.py", line 20, in <module>
2022-05-11T03:06:52.7770460Z     from torch.distributed.fsdp.wrap import default_auto_wrap_policy
2022-05-11T03:06:52.7771164Z ImportError: cannot import name 'default_auto_wrap_policy' from 'torch.distributed.fsdp.wrap' (/opt/conda/lib/python3.7/site-packages/torch/distributed/fsdp/wrap.py)
2022-05-11T03:06:52.8888695Z Traceback (most recent call last):
2022-05-11T03:06:52.8889249Z   File "test/run_test.py", line 1072, in <module>
2022-05-11T03:06:52.8891505Z     main()
2022-05-11T03:06:52.8892036Z   File "test/run_test.py", line 1050, in main
2022-05-11T03:06:52.8893528Z     raise RuntimeError(err_message)
2022-05-11T03:06:52.8894154Z RuntimeError: distributed/fsdp/test_fsdp_mixed_precision failed!

See GitHub Actions build Lint / lintrunner (5/11)

Step: "Run lintrunner on all files" (full log | diagnosis details | 🔁 rerun)

2022-05-11T03:10:04.8479456Z         14  |ExampleInputs = Tuple[Any, ...]
2022-05-11T03:10:04.8479768Z         15  |
2022-05-11T03:10:04.8480046Z     >>> 16  |__all__ = [
2022-05-11T03:10:04.8480364Z         17  |    "Pattern",
2022-05-11T03:10:04.8480691Z         18  |    "NodePattern",
2022-05-11T03:10:04.8481025Z         19  |    "QuantizerCls",

2022-05-11T03:10:04.8490329Z You can reproduce these results locally by using `lintrunner`.
2022-05-11T03:10:04.8490901Z See https://github.com/pytorch/pytorch/wiki/lintrunner for setup instructions.
2022-05-11T03:10:04.8504313Z ##[error]Process completed with exit code 1.

See GitHub Actions build pull / linux-xenial-py3.7-clang7-asan / test (default, 4, 5, linux.2xlarge) (6/11)

Same test_correct_module_names (TestPublicBindings) failure as in (1/11).

See GitHub Actions build pull / linux-bionic-rocm5.1-py3.7 / test (default, 2, 2, linux.rocm.gpu) (7/11)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-05-11T08:54:51.5889873Z RuntimeError: test_prims failed!
2022-05-11T08:54:49.9190673Z 
2022-05-11T08:54:49.9190804Z FAILED (errors=1, expected failures=3)
2022-05-11T08:54:49.9190809Z 
2022-05-11T08:54:49.9190921Z Generating XML reports...
2022-05-11T08:54:49.9191330Z Generated XML report: test-reports/python-unittest/test_prims/TEST-TestPrimsCUDA-20220511085448.xml
2022-05-11T08:54:51.5876109Z Traceback (most recent call last):
2022-05-11T08:54:51.5877018Z   File "test/run_test.py", line 1072, in <module>
2022-05-11T08:54:51.5883229Z     main()
2022-05-11T08:54:51.5883916Z   File "test/run_test.py", line 1050, in main
2022-05-11T08:54:51.5889086Z     raise RuntimeError(err_message)
2022-05-11T08:54:51.5889873Z RuntimeError: test_prims failed!
2022-05-11T08:54:54.1954378Z 
2022-05-11T08:54:54.1954744Z real	89m9.261s
2022-05-11T08:54:54.1955374Z user	112m50.849s
2022-05-11T08:54:54.1955991Z sys	20m22.934s

See GitHub Actions build pull / linux-bionic-py3.7-clang9 / test (crossref, 2, 2, linux.2xlarge) (8/11)

Same test_correct_module_names (TestPublicBindings) failure as in (1/11).

See GitHub Actions build pull / linux-bionic-rocm5.1-py3.7 / test (default, 1, 2, linux.rocm.gpu) (9/11)

Same test_correct_module_names (TestPublicBindings) failure as in (1/11).

See GitHub Actions build pull / linux-xenial-cuda11.3-py3.7-gcc7 / test (distributed, 2, 2, linux.8xlarge.nvidia.gpu) (10/11)

Same default_auto_wrap_policy ImportError in distributed/fsdp/test_fsdp_mixed_precision as in (4/11).

See GitHub Actions build pull / linux-xenial-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge) (11/11)

Same test_correct_module_names (TestPublicBindings) failure as in (1/11).

🕵️‍♀️ 1 failure not recognized by patterns:

The following CI failures may be due to changes from the PR
Job Step Action
GitHub Actions pull / linux-xenial-cuda11.3-py3.7-gcc7 / test (distributed, 1, 2, linux.8xlarge.nvidia.gpu) Unknown 🔁 rerun

This comment was automatically generated by Dr. CI (expand for details).


jerryzh168 added a commit that referenced this pull request Apr 27, 2022
…_fx and prepare_qat_fx

ghstack-source-id: d8c9c37
Pull Request resolved: #76496
@jerryzh168
Contributor Author

@jerryzh168 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

… to prepare_fx and prepare_qat_fx"

@jerryzh168 jerryzh168 changed the title [quant][fx][bc-breaking] Add required sample_args argument to prepare_fx and prepare_qat_fx [quant][fx][bc-breaking] Add required example_inputs argument to prepare_fx and prepare_qat_fx Apr 28, 2022
…ent to prepare_fx and prepare_qat_fx"

Summary:
FX Graph Mode Quantization needs to know whether an FX node is a floating point Tensor before it can decide whether to
insert an observer/fake_quantize module, since we only insert observer/fake_quantize modules for floating point Tensors.
Currently we support this with some hacks, defining rules like `NON_OBSERVABLE_ARG_DICT` (https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/utils.py#L496), but this approach is fragile and we do not plan to maintain it long term in the PyTorch code base.

As we discussed in the design review, we need to ask users to provide sample args and sample keyword args
so that we can infer the type in a more robust way. This PR starts by changing the `prepare_fx` and `prepare_qat_fx` APIs to require the user to provide
example arguments through `example_inputs`. Note that this API does not support kwargs. Kwargs could make #76496 (comment) simpler, but
they will be rare, and even then we can still work around them with positional arguments. Also, `torch.jit.trace` (https://pytorch.org/docs/stable/generated/torch.jit.trace.html) and ShapeProp (https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py#L140) take only positional args, so we'll just use a single `example_inputs` argument for now.

If needed, we can extend the API with an optional `example_kwargs`, e.g. when `forward` takes many arguments and it makes more sense to
pass them by keyword.
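The idea behind requiring example inputs can be sketched in plain Python (no torch; the `infer_observable_args` helper and the float-vs-int check are hypothetical stand-ins for the real per-node Tensor-type inference and observer insertion):

```python
# Sketch: decide which forward() arguments deserve an "observer" by running
# the user-provided example inputs and checking which values are floating
# point. In real FX quantization the check is per-node on Tensors; here
# plain Python floats stand in for floating point Tensors.

def infer_observable_args(forward, example_inputs):
    """Return the indices of positional args that are floating point."""
    # Run once so we know the example inputs are actually valid for forward().
    forward(*example_inputs)
    return [i for i, a in enumerate(example_inputs) if isinstance(a, float)]

def forward(x, num_repeats):
    # x is data (float); num_repeats is a structural int argument that
    # should never be observed/quantized.
    return [x * 2.0] * num_repeats

observable = infer_observable_args(forward, (1.5, 3))
print(observable)  # → [0]
```

This is the robustness the rule-based `NON_OBSERVABLE_ARG_DICT` approach lacks: the types come from running real inputs, not from a hand-maintained table.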

BC-breaking Note:
Before:
m = resnet18(...)
m = prepare_fx(m, qconfig_dict)

After:
m = resnet18(...)
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),))
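A fuller end-to-end sketch of the new workflow (a tiny model stands in for resnet18; the qconfig choice and module paths reflect the `torch.ao.quantization.quantize_fx` API at the time of this PR — newer releases prefer a `QConfigMapping` over a plain qconfig dict):

```python
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class Small(nn.Module):
    """Tiny stand-in for resnet18, just to show the calling convention."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return self.conv(x)

m = Small().eval()
# Honor whatever quantized engine this build supports (fbgemm/x86/qnnpack).
backend = torch.backends.quantized.engine
qconfig_dict = {"": get_default_qconfig(backend)}
# The new required argument: a tuple of example positional inputs.
example_inputs = (torch.randn(1, 3, 8, 8),)
m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
m(*example_inputs)   # calibration pass(es) with representative data
m = convert_fx(m)    # produce the quantized model
out = m(*example_inputs)
```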

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestQuantizeFxModels

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D35984526](https://our.internmc.facebook.com/intern/diff/D35984526)

[ghstack-poisoned]
jerryzh168 added a commit that referenced this pull request Apr 28, 2022
…are_fx and prepare_qat_fx


ghstack-source-id: 81e9c11
Pull Request resolved: #76496
@jerryzh168
Contributor Author

@jerryzh168 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@jerryzh168 jerryzh168 requested review from andrewor14 and vkuzo May 3, 2022 21:39
@jerryzh168
Contributor Author

changed to single example_inputs argument

…ent to prepare_fx and prepare_qat_fx"


Differential Revision: [D35984526](https://our.internmc.facebook.com/intern/diff/D35984526)

[ghstack-poisoned]
jerryzh168 added a commit that referenced this pull request May 4, 2022
…are_fx and prepare_qat_fx


ghstack-source-id: c815e7e
Pull Request resolved: #76496
jerryzh168 added a commit that referenced this pull request May 4, 2022
…ple_inputs argument to prepare_fx and prepare_qat_fx"


Differential Revision: [D35984526](https://our.internmc.facebook.com/intern/diff/D35984526)

[ghstack-poisoned]
…ent to prepare_fx and prepare_qat_fx"


Differential Revision: [D35984526](https://our.internmc.facebook.com/intern/diff/D35984526)

[ghstack-poisoned]
@jerryzh168 jerryzh168 closed this May 16, 2022
jerryzh168 added a commit to jerryzh168/ClassyVision-1 that referenced this pull request May 16, 2022

Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: bf9db33515941c0732affdd4a3d7812a07768117
jerryzh168 added a commit to jerryzh168/mobile-vision that referenced this pull request May 16, 2022

Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: a6aae9a41060990c20422051e44be73bd01be88c
jerryzh168 added a commit to jerryzh168/d2go that referenced this pull request May 16, 2022

Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: fc3ff97ed5d16442dc5434ccec34a889ccd53129
jerryzh168 added a commit to jerryzh168/benchmark that referenced this pull request May 16, 2022

Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: 153f71963bc984d0a66be4980ce0e5a5096020e6
jerryzh168 added a commit to jerryzh168/ClassyVision-1 that referenced this pull request May 17, 2022
…#77608)

Summary:
X-link: pytorch/pytorch#77608

X-link: meta-pytorch/fx2trt#76

X-link: facebookresearch/d2go#249

X-link: fairinternal/ClassyVision#104

X-link: pytorch/benchmark#916

Pull Request resolved: facebookresearch#791

X-link: facebookresearch/mobile-vision#68


Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: ef5536ff98a3e621ab0d10341940dcb4a2dfcd32
jerryzh168 added a commit to jerryzh168/d2go that referenced this pull request May 17, 2022
Summary:
X-link: pytorch/pytorch#77608

X-link: meta-pytorch/fx2trt#76

Pull Request resolved: facebookresearch#249

X-link: fairinternal/ClassyVision#104

X-link: pytorch/benchmark#916

X-link: facebookresearch/ClassyVision#791

X-link: facebookresearch/mobile-vision#68


Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: 2fc9c06805d443fc1478d530232cdbcfeef39f67
jerryzh168 added a commit to jerryzh168/mobile-vision that referenced this pull request May 17, 2022
Summary:
X-link: pytorch/pytorch#77608

X-link: meta-pytorch/fx2trt#76

X-link: facebookresearch/d2go#249

X-link: fairinternal/ClassyVision#104

X-link: pytorch/benchmark#916

X-link: facebookresearch/ClassyVision#791

Pull Request resolved: facebookresearch#68


Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: 2a9df6332f24650b26dfbc4c754b9156d38ea890
jerryzh168 added a commit to jerryzh168/benchmark that referenced this pull request May 17, 2022
…#77608)

Summary:
X-link: pytorch/pytorch#77608

X-link: meta-pytorch/fx2trt#76

X-link: facebookresearch/d2go#249

X-link: fairinternal/ClassyVision#104

Pull Request resolved: pytorch#916

X-link: facebookresearch/ClassyVision#791

X-link: facebookresearch/mobile-vision#68


Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: 7e1ce6dc13a1ecc4d46939c8e3b3f3721248c727
jerryzh168 added a commit to jerryzh168/ClassyVision-1 that referenced this pull request May 19, 2022
…#77608)

Summary:
X-link: pytorch/pytorch#77608

X-link: meta-pytorch/fx2trt#76

X-link: facebookresearch/d2go#249

X-link: fairinternal/ClassyVision#104

X-link: pytorch/benchmark#916

Pull Request resolved: facebookresearch#791

X-link: facebookresearch/mobile-vision#68


Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: 58c1e0afa7421ce79c164a31e88bb7dc4541f42b
jerryzh168 added a commit to jerryzh168/d2go that referenced this pull request May 19, 2022
Summary:
X-link: pytorch/pytorch#77608

X-link: meta-pytorch/fx2trt#76

Pull Request resolved: facebookresearch#249

X-link: fairinternal/ClassyVision#104

X-link: pytorch/benchmark#916

X-link: facebookresearch/ClassyVision#791

X-link: facebookresearch/mobile-vision#68


Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: 7150e372404a9a6a9352163b7dce8963a7a3293b
jerryzh168 added a commit to jerryzh168/mobile-vision that referenced this pull request May 19, 2022
Pull Request resolved: facebookresearch#68
fbshipit-source-id: c01860fe846684bb1e781dac19a7b2d89d004329
jerryzh168 added a commit to jerryzh168/benchmark that referenced this pull request May 19, 2022
Pull Request resolved: pytorch#916
fbshipit-source-id: 7abfc1c5c57633e7a7e38060d9552e45659cb2a1
jerryzh168 added a commit to jerryzh168/ClassyVision-1 that referenced this pull request May 19, 2022
Pull Request resolved: facebookresearch#791
fbshipit-source-id: bc7b108b768293a74561825b2df95d84fb4822ee
jerryzh168 added a commit to jerryzh168/d2go that referenced this pull request May 19, 2022
Pull Request resolved: facebookresearch#249
fbshipit-source-id: 5b7837005a34a095b331dbca7d6a8c2d6fa5ee51