[JIT][tracer] Sanity checks for tracing #10841
jamesr66a wants to merge 10 commits into pytorch:master
Conversation
Force-pushed from 0f4a2cb to 538be05
facebook-github-bot left a comment:
jamesr66a has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
```
 %3 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
 %4 : Double(3, 7) = aten::mm(%3, %2), scope: TracedModule/TracedModule[TracedModule1][mod]
-%5 : Long() = prim::Constant[value={1}](), scope: TracedModule
+%5 : Double() = prim::Constant[value={1}](), scope: TracedModule
```
```
 @torch.jit.trace(torch.rand(3, 4))
 def traced_fn(x):
-    return pm(x) + 1
+    return pm(x) + 1.0
```
test/test_jit.py (outdated):
```
def foo(x):
    return torch.dropout(x, p=0.5, train=False)

np.testing.assert_allclose(foo(input), input)
```
torch/csrc/jit/ir.cpp (outdated):
```
if (kind() == aten::dropout && is_constant(attr::train) && !get<bool>(attr::train).value()) {
  return false;
}
// batch_norm with training = False is deterministic
```
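The logic above can be sketched in plain Python: a node counts as nondeterministic only if its kind is in a known set of random ops, with carve-outs for ops such as dropout whose `train=False` mode is the identity function and therefore deterministic. All names below are illustrative stand-ins, not the actual JIT API.

```python
# Minimal stand-in for Node::isNondeterministic, modeling a node as a
# (kind, attrs) pair. The op set mirrors the schemas listed in this PR.
NONDETERMINISTIC_OPS = {
    "aten::dropout", "aten::rand", "aten::normal",
    "aten::poisson", "aten::rrelu", "aten::rrelu_with_noise",
}

def is_nondeterministic(kind, attrs):
    if kind not in NONDETERMINISTIC_OPS:
        return False
    # dropout with a constant train=False flag is deterministic:
    # it passes its input through unchanged.
    if kind == "aten::dropout" and attrs.get("train") is False:
        return False
    return True
```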
torch/csrc/jit/ir.cpp (outdated):
```
} // namespace

bool Node::isNondeterministic() const {
  if (nondeterminstic_aten_ops().count(kind()) == 0) {
```
torch/jit/__init__.py (outdated):
```
try:
    mod_tensor_val = n_mod.t('value')
    check_tensor_val = n_check.t('value')
except RuntimeError:
```
```
tensor_compare_errors = None
# Check Tensor-valued constant nodes
for n_mod, n_check in zip(mod_canonicalized.nodes(), check_canonicalized.nodes()):
    if n_mod.kind() == n_check.kind() and n_mod.kind() == 'prim::Constant':
```
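The loop above pairs up nodes from the original and the check trace and compares tensor-valued constants. The same pass can be sketched with graph nodes modeled as simple `(kind, value)` tuples; everything here is an illustrative stand-in (the real check compares tensor payloads with a tolerance via numpy):

```python
def compare_constants(mod_nodes, check_nodes):
    """Return error strings for paired constant nodes whose values differ.

    Nodes are modeled as (kind, value) tuples; a differing
    prim::Constant usually means computation escaped the tracer.
    """
    errors = []
    for i, ((kind_m, val_m), (kind_c, val_c)) in enumerate(
            zip(mod_nodes, check_nodes)):
        if kind_m == kind_c == "prim::Constant" and val_m != val_c:
            errors.append(
                "node %d: constant differed across invocations: %r vs %r"
                % (i, val_m, val_c))
    return errors
```

This mirrors the arange example later in this PR, where the baked-in constant `[0, 1, 2]` becomes `[0, 1, 2, 3]` under a differently-shaped check input.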
torch/jit/__init__.py (outdated):
```
try:
    traced_outs = wrap_non_iterable(module(*_clone_inputs(inputs)))
    # TODO: multi-level nested compare of results.
    traced_outs = [out for out in traced_outs if isinstance(out, torch.Tensor)]
```
torch/jit/__init__.py (outdated):
```
    iter(x)
except TypeError:
    x = [x]
return x
```
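The fragment above is the tail of a small helper; the name `wrap_non_iterable` appears in this PR's diff, but the full body shown here is a reconstruction and may differ from the actual implementation:

```python
def wrap_non_iterable(x):
    # Wrap a single return value in a list so callers can always
    # iterate over outputs; iterables pass through unchanged.
    try:
        iter(x)
    except TypeError:
        x = [x]
    return x
```

This lets the output-comparison code treat single-tensor and tuple-returning functions uniformly.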
torch/jit/__init__.py (outdated):
```
check_inputs = kwargs.pop('check_inputs', None)
disable_checks = kwargs.pop('disable_checks', False)
check_tolerance = kwargs.pop('check_tolerance', 1e-7)
```
zdevito left a comment:
I don't have anything to add beyond what Adam said, except for a nit. We should consider setting the tolerance low enough that we don't cause anyone false positives. It would be a pretty bad experience to realize the only reason tracing failed was due to numeric nondeterminism.
```
void DecayTypes(const std::shared_ptr<Graph>& graph) {
  for (Value * input : graph->inputs()) {
    if (input->type()->cast<TensorType>()) {
      input->setType(DynamicType::get());
```
torch/csrc/jit/python_ir.cpp (outdated):
```
  ss << n;
  return ss.str();
})
.def("sourceLocation", [](Node & n) {
```
torch/csrc/jit/python_ir.cpp (outdated):
```
std::stringstream ss;
if (auto sl = n.getSourceLocation())
  sl->highlight(ss);
return ss.str();
```
torch/jit/__init__.py (outdated):
```
# Check the traced module against a set of user-provided validation inputs
def check_trace(check_inputs, func, executor_options, module, check_tolerance):
```
torch/jit/__init__.py (outdated):
```
def check_trace(check_inputs, func, executor_options, module, check_tolerance):
    for inputs in check_inputs:
        check_mod = TopLevelTracedModule(func, **executor_options)
        check_mod._create_method_from_trace('forward', func, _clone_inputs(inputs))
```
Force-pushed from addf644 to 4300eaf
```
prim::NoneGenerator,
[](Node* node) {
  return [](Stack& stack) {
    stack.push_back(at::Tensor());
```
```
except Exception as e:
    msg = 'Encountered an exception while running checking trace with check inputs.\nException:\n' \
        + indent(str(e))
    raise TracingCheckError(*graph_diagnostic_info(), extra_msg=msg)
```
torch/csrc/jit/ir.cpp (outdated):
```
"aten::normal(Tensor mean, float std, *, Tensor generator) -> Tensor",
"aten::poisson(Tensor self, Tensor generator) -> Tensor",
"aten::rrelu(Tensor self, Scalar lower, Scalar upper, int training, Tensor generator) -> Tensor",
"aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, int training, Tensor generator) -> Tensor"
```
zdevito left a comment:
Probably should add aten::rand to the non-deterministic things. I am approving anyway so that it won't block landing.
@pytorchbot retest this please
Summary:
TODO: integrate into torch.onnx.export -- separate PR
*Problem:* We have a facility to trace PyTorch operations in Python code, but there are several failure modes where the trace is not representative of the actual underlying computation:
* The tracer encountered dynamic control flow
* Some computation escaped the tracer, and appeared as a Constant tensor node in the graph
* Some stateful function was traced, e.g. someone did an optimization in Python by memoizing function outputs
*Objective*: In an ideal world, this whole process would be automated and the user could trust that the system will magically capture the intended semantics from the program. Realistically speaking, we will likely have to settle for a human-in-the-loop error-reporting system, allowing the user to identify problems and modify the source code to allow for tracing.
*Stage 1* (this PR): Output-level checking & graph diff. torch.jit.trace gains a kwarg 'check_inputs', which is a list of tuples of input arguments. We will iterate through the list and trace the function again for each set of check inputs. We'll also interpret the original trace with these inputs and compare output values and graphs, printing a diff of the graph if there is a difference.
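The re-trace-and-compare loop described above can be sketched independently of the JIT. In this toy stand-in (every name below is illustrative, not the real API), "tracing" just records one op name per loop iteration, so a function whose loop count depends on the input shape yields different traces for different check inputs, which is exactly the failure mode `check_inputs` catches:

```python
def trace_ops(fn, shape):
    # "Trace" fn by recording the ops it emits for this input shape.
    ops = []
    fn(shape, ops)
    return ops

def foo(shape, ops):
    # Dynamic control flow: loop count depends on the input's first dim.
    for _ in range(shape[0]):
        ops.append("aten::neg")

def check_trace(fn, traced_shape, check_shapes):
    reference = trace_ops(fn, traced_shape)
    for shape in check_shapes:
        if trace_ops(fn, shape) != reference:
            raise ValueError("Graphs differed across invocations!")
```

A check input with the same first dimension passes; one with a different first dimension raises, mirroring the neg-loop example below.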
Examples:
```
@torch.jit.trace(torch.rand(3, 4), check_inputs=[(torch.rand(4, 5),)])
def foo(x):
y = torch.arange(0, x.shape[0]).float()
return x + y.unsqueeze(1)
```
```
torch.jit.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
Graph diff:
graph(%0 : Dynamic) {
- %1 : Dynamic = prim::Constant[value= 0 1 2 [ CPULongType{3} ]]()
? ^
+ %1 : Dynamic = prim::Constant[value= 0 1 2 3 [ CPULongType{4} ]]()
? +++ ^
%2 : int = prim::Constant[value=0]()
%3 : Dynamic = aten::_cast_Float(%1, %2)
%4 : int = prim::Constant[value=1]()
%5 : Dynamic = aten::unsqueeze(%3, %4)
%6 : int = prim::Constant[value=1]()
%7 : Dynamic = aten::add(%0, %5, %6)
return (%7);
}
Node diff:
- %1 : Dynamic = prim::Constant[value= 0 1 2 [ CPULongType{3} ]]()
? ^
+ %1 : Dynamic = prim::Constant[value= 0 1 2 3 [ CPULongType{4} ]]()
? +++ ^
Trace source location:
dank.py(5): foo
/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py(402): wrapper
dank.py(3): <module>
Check source location:
dank.py(5): foo
/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py(281): check_trace
/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py(408): wrapper
dank.py(3): <module>
ERROR: Tensor-valued Constant nodes differed in value across invocations. This often indicates that the tracer has encountered untraceable code.
Node:
%1 : Dynamic = prim::Constant[value= 0 1 2 [ CPULongType{3} ]]()
Source Location:
dank.py(5): foo
/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py(402): wrapper
dank.py(3): <module>
Comparison exception:
Not equal to tolerance rtol=1e-07, atol=0
(shapes (3,), (4,) mismatch)
x: array([0, 1, 2])
y: array([0, 1, 2, 3])
```
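The `-`/`+`/`?` lines in the graph diff above are the format produced by Python's `difflib.ndiff`. A minimal sketch of generating such a diff from two printed graphs (contents abbreviated from the example above):

```python
import difflib

# Two printed graphs that differ in one baked-in constant node.
graph_a = [
    "graph(%0 : Dynamic) {",
    "  %1 : Dynamic = prim::Constant[value= 0 1 2 [ CPULongType{3} ]]()",
    "}",
]
graph_b = [
    "graph(%0 : Dynamic) {",
    "  %1 : Dynamic = prim::Constant[value= 0 1 2 3 [ CPULongType{4} ]]()",
    "}",
]

# ndiff yields lines prefixed "  " (common), "- "/"+ " (removed/added),
# and "? " (intraline change markers such as the ^ and + above).
diff_lines = list(difflib.ndiff(graph_a, graph_b))
```

Whether the checker uses `ndiff` exactly like this is an assumption from the output format; the point is that a line-level diff of canonicalized graph dumps is enough to localize the offending node.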
==
```
@torch.jit.trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
def foo(x):
y = x.data
return x + y
```
```
torch.jit.TracingCheckError: Tracing failed sanity checks!
ERROR: Traced function outputs do not match the Python function outputs.
ERROR: Tensor-valued Constant nodes differed in value across invocations. This often indicates that the tracer has encountered untraceable code.
Node:
%1 : Dynamic = prim::Constant[value=<Tensor>]()
Source Location:
dank.py(6): foo
/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py(402): wrapper
dank.py(3): <module>
Comparison exception:
Not equal to tolerance rtol=1e-07, atol=0
(mismatch 100.0%)
x: array([0.397137, 0.956105, 0.169478, 0.560292, 0.392568, 0.108441,
0.97645 , 0.34412 , 0.951246, 0.793061, 0.557595, 0.770245],
dtype=float32)
y: array([0.243178, 0.315964, 0.972041, 0.0215 , 0.927751, 0.457512,
0.951092, 0.97883 , 0.048688, 0.118066, 0.779345, 0.271272],
dtype=float32)
```
==
```
import torch
@torch.jit.trace(torch.rand(3, 4), check_inputs=[(torch.rand(4, 4),)])
def foo(x):
for _ in range(x.size(0)):
x = torch.neg(x)
return x
```
```
torch.jit.TracingCheckError: Tracing failed sanity checks!
ERROR: Traced function outputs do not match the Python function outputs.
ERROR: Graphs differed across invocations!
Graph diff:
graph(%0 : Dynamic) {
%1 : Dynamic = aten::neg(%0)
%2 : Dynamic = aten::neg(%1)
%3 : Dynamic = aten::neg(%2)
+ %4 : Dynamic = aten::neg(%3)
- return (%3);
? ^
+ return (%4);
? ^
}
```
==
```
import torch
def foo(x):
if not hasattr(foo, 'cache'):
foo.cache = torch.neg(x)
return x + foo.cache
traced = torch.jit.trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])(foo)
```
```
torch.jit.TracingCheckError: Tracing failed sanity checks!
ERROR: Traced function outputs do not match the Python function outputs.
ERROR: Graphs differed across invocations!
Graph diff:
graph(%0 : Dynamic) {
- %1 : Dynamic = aten::neg(%0)
+ %1 : Dynamic = prim::Constant[value=<Tensor>]()
%2 : int = prim::Constant[value=1]()
%3 : Dynamic = aten::add(%0, %1, %2)
return (%3);
}
Node diff:
- %1 : Dynamic = aten::neg(%0)
+ %1 : Dynamic = prim::Constant[value=<Tensor>]()
Trace source location:
test.py(5): foo
/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py(402): wrapper
test.py(8): <module>
Check source location:
test.py(6): foo
/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py(281): check_trace
/Users/jamesreed/onnx-fairseq/pytorch/torch/jit/__init__.py(408): wrapper
test.py(8): <module>
```
The following two examples show instances where program semantics are lost in the Python -> trace transformation, and repeated invocation does not give us useful debug information. Further design is underway for catching these scenarios.
```
import torch
@torch.jit.trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
def foo(x):
for i in range(3):
x[i, :] = torch.zeros(4)
return x
```
```
torch.jit.TracingCheckError: Tracing failed sanity checks!
ERROR: Traced function outputs do not match the Python function outputs.
Exception:
Not equal to tolerance rtol=1e-07, atol=0
(mismatch 100.0%)
x: array([0.830221, 0.915481, 0.940281, 0.555241], dtype=float32)
y: array([0., 0., 0., 0.], dtype=float32)
```
==
```
import torch
@torch.jit.trace(torch.rand(3, 4), check_inputs=[(torch.rand(5, 6),)])
def foo(x):
x.view(-1).add_(-x.view(-1))
return x
```
```
torch.jit.TracingCheckError: Tracing failed sanity checks!
ERROR: Traced function outputs do not match the Python function outputs.
Exception:
Not equal to tolerance rtol=1e-07, atol=0
(mismatch 100.0%)
x: array([0.734441, 0.445327, 0.640592, 0.30076 , 0.891674, 0.124771],
dtype=float32)
y: array([0., 0., 0., 0., 0., 0.], dtype=float32)
```
Pull Request resolved: pytorch#10841
Differential Revision: D9499945
Pulled By: jamesr66a
fbshipit-source-id: 1f842a32d0b0645259cc43b29700b86d99c59a45