Adding limited support for aten::Int #870
Conversation
There was a problem hiding this comment.
There are some changes that do not conform to C++ style guidelines:
diff --git a/workspace/core/lowering/passes/remove_unnecessary_casts.cpp b/tmp/changes.txt
index 7f6cc85..064a2cb 100644
--- a/workspace/core/lowering/passes/remove_unnecessary_casts.cpp
+++ b/tmp/changes.txt
@@ -1,5 +1,5 @@
-#include "torch/csrc/jit/passes/subgraph_rewrite.h"
#include "torch/csrc/jit/ir/constants.h"
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
#include "core/util/prelude.h"
@@ -10,7 +10,6 @@ namespace core {
namespace lowering {
namespace passes {
-
// Presumably this is safe since torch::jit::EraseNumberTypesOnBlock exists which just
// removes prim::TensorToNum, aten::Float, aten::Int and prim::NumToTensor nodes outright
void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph) {
@@ -77,8 +76,8 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
if (user->output()->uses().size() == 1) {
auto potential_cast = user->output()->uses()[0].user;
// The downstream user is aten::Int
- if (potential_cast->kind() == c10::Symbol::fromQualString("aten::Int")
- || potential_cast->kind() == c10::Symbol::fromQualString("aten::Float")) {
+ if (potential_cast->kind() == c10::Symbol::fromQualString("aten::Int") ||
+ potential_cast->kind() == c10::Symbol::fromQualString("aten::Float")) {
LOG_GRAPH("Downstream user is aten::Int/aten::Float");
auto arg = use.offset;
@@ -88,13 +87,11 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
LOG_GRAPH("Input " << k << " is a Tensor");
if (user->inputs()[k]->node()->kind() == c10::Symbol::fromQualString("prim::NumToTensor")) {
auto num_to_tensor = user->inputs()[k]->node();
-
- LOG_GRAPH("Found a prim::NumToTensor / aten::[Int/Float] pair with an intermediate operation:\n "
- << *(*it)
- << *num_to_tensor
- << *user
- << *potential_cast);
-
+
+ LOG_GRAPH(
+ "Found a prim::NumToTensor / aten::[Int/Float] pair with an intermediate operation:\n "
+ << *(*it) << *num_to_tensor << *user << *potential_cast);
+
// Replace the Tensor Constant with a scalar constant
LOG_GRAPH("Deleting 0-dim Tensor: " << **it);
torch::jit::WithInsertPoint gaurd(*it);
@@ -126,19 +123,16 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
// has a different schema than the original
case c10::aten::add:
new_node = g->create(
- user->kind(),
- torch::jit::ArrayRef<torch::jit::Value*>({user->inputs()[0], user->inputs()[1]}),
- 1);
+ user->kind(),
+ torch::jit::ArrayRef<torch::jit::Value*>({user->inputs()[0], user->inputs()[1]}),
+ 1);
new_node->insertAfter(user);
new_node->outputs()[0]->setType(c10::IntType::get());
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
user->destroy();
break;
default:
- new_node = g->create(
- user->kind(),
- user->inputs(),
- 1);
+ new_node = g->create(user->kind(), user->inputs(), 1);
new_node->insertAfter(user);
new_node->outputs()[0]->setType(c10::IntType::get());
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
@@ -148,7 +142,7 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
LOG_GRAPH("New intermediate operation: " << *new_node);
LOG_GRAPH(new_node->schema());
-
+
// Delete aten::Int
LOG_GRAPH("Deleting aten::[Int/Float]: " << *potential_cast);
potential_cast->output()->replaceAllUsesWith(potential_cast->inputs()[0]);
@@ -163,12 +157,11 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
}
}
}
- }
+ }
}
LOG_ERROR("Post removing single use 0-dim Tensor operations: " << *g);
}
-
} // namespace passes
} // namespace lowering
} // namespace core
diff --git a/workspace/tests/core/lowering/test_remove_unnecessary_casts.cpp b/tmp/changes.txt
index ef370a8..62f913e 100644
--- a/workspace/tests/core/lowering/test_remove_unnecessary_casts.cpp
+++ b/tmp/changes.txt
@@ -102,8 +102,7 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsIntCorrectly) {
auto first_op = *(sg->block()->nodes().begin());
torch::jit::WithInsertPoint guard(first_op);
- torch::jit::Value* r = sg->insertConstant(
- c10::scalar_to_tensor(8), c10::nullopt, first_op->scope());
+ torch::jit::Value* r = sg->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, first_op->scope());
r->copyMetadata(first_op->output());
r->setType(c10::TensorType::get());
first_op->output()->replaceAllUsesWith(r);
@@ -141,8 +140,7 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloatCorrectly) {
auto first_op = *(sg->block()->nodes().begin());
torch::jit::WithInsertPoint guard(first_op);
- torch::jit::Value* r = sg->insertConstant(
- c10::scalar_to_tensor(8.0), c10::nullopt, first_op->scope());
+ torch::jit::Value* r = sg->insertConstant(c10::scalar_to_tensor(8.0), c10::nullopt, first_op->scope());
r->copyMetadata(first_op->output());
r->setType(c10::TensorType::get());
first_op->output()->replaceAllUsesWith(r);
ERROR: Some files do not conform to style guidelines
Likely fixes #732 as well. |
This commit adds a pass to lower out aten::[Int/Float/Bool], aten::NumToTensor pairs without exception. We are assuming this is safe, as there are similar passes in PyTorch for ONNX lowering; however, the scope of this rule is intentionally limited to avoid possible cases where it is not safe. Therefore it should not be expected that all aten::Int issues will be solved with this change, and the operator itself remains a limitation of TorchTRT. Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
0D Tensors
Now we remove selected, more complex aten::Int cases found in
models such as BERT, like the following:
```
graph(%0: int):
%1: Tensor = prim::Constant[value={8}]()
%2: int = prim::Constant[value=1]()
%3: Tensor = prim::NumToTensor(%0)
%4: Tensor = aten::add(%1, %3, %2)
%5: int = aten::Int(%4)
%6: int = aten::add(%5, %5)
return (%6)";
graph(%0: int):
%1: int = prim::Constant[value=8]()
%4: int = aten::add(%1, %0)
%6: int = aten::add(%4, %4)
return (%6)";
```
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
30ee238 to
8139da9
Compare
Lower logging level on debug info Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /workspace/tests/modules/hub.py (original)
+++ /workspace/tests/modules/hub.py (reformatted)
@@ -191,7 +191,6 @@
conditional_script_model = torch.jit.script(conditional_model)
torch.jit.save(conditional_script_model, "conditional_scripted.jit.pt")
-
enc = BertTokenizer.from_pretrained("bert-base-uncased")
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_trt_intercompatibility.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
Reformatting /workspace/tests/py/test_api.py
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_qat_trt_accuracy.py
Reformatting /workspace/tests/py/test_to_backend_api.py
ERROR: Some files do not conform to style guidelinesThere was a problem hiding this comment.
There are some changes that do not conform to C++ style guidelines:
diff --git a/workspace/tests/cpp/test_default_input_types.cpp b/tmp/changes.txt
index 752f51e..a79ddaf 100644
--- a/workspace/tests/cpp/test_default_input_types.cpp
+++ b/tmp/changes.txt
@@ -116,4 +116,5 @@ TEST_P(CppAPITests, InputsRespectUserSettingFP32WeightsFP16In) {
INSTANTIATE_TEST_SUITE_P(
CompiledModuleForwardIsCloseSuite,
CppAPITests,
- testing::Values(PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat} /*unused*/, 2e-5})));
+ testing::Values(
+ PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat} /*unused*/, 2e-5})));
ERROR: Some files do not conform to style guidelinesThere was a problem hiding this comment.
There are some changes that do not conform to C++ style guidelines:
diff --git a/workspace/tests/cpp/test_default_input_types.cpp b/tmp/changes.txt
index 752f51e..a79ddaf 100644
--- a/workspace/tests/cpp/test_default_input_types.cpp
+++ b/tmp/changes.txt
@@ -116,4 +116,5 @@ TEST_P(CppAPITests, InputsRespectUserSettingFP32WeightsFP16In) {
INSTANTIATE_TEST_SUITE_P(
CompiledModuleForwardIsCloseSuite,
CppAPITests,
- testing::Values(PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat} /*unused*/, 2e-5})));
+ testing::Values(
+ PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat} /*unused*/, 2e-5})));
ERROR: Some files do not conform to style guidelinesThere was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /workspace/tests/modules/hub.py (original)
+++ /workspace/tests/modules/hub.py (reformatted)
@@ -191,7 +191,6 @@
conditional_script_model = torch.jit.script(conditional_model)
torch.jit.save(conditional_script_model, "conditional_scripted.jit.pt")
-
enc = BertTokenizer.from_pretrained("bert-base-uncased")
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_trt_intercompatibility.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
Reformatting /workspace/tests/py/test_api.py
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_qat_trt_accuracy.py
Reformatting /workspace/tests/py/test_to_backend_api.py
ERROR: Some files do not conform to style guidelinesSigned-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
af8d22d to
83ae991
Compare
|
|
||
| std::string set_attr_pattern = R"IR( | ||
| graph(%self, %0): | ||
| None = prim::SetAttr[name="_has_warned"](%self, %0) |
There was a problem hiding this comment.
Can you add a comment about this specific attribute _has_warned? Where is it used? etc.
| namespace lowering { | ||
| namespace passes { | ||
|
|
||
| void RemoveSetAttrs(const torch::jit::Module& mod, std::string method_name) { |
There was a problem hiding this comment.
Where is this lowering pass used ?
There was a problem hiding this comment.
There was a version of transformers that had this, and it was breaking the conversion process since setattr does not have a schema. But later versions don't use this, so I removed it from the set of active passes.
Description
This PR adds support for aten::Int / prim::NumToTensor in a few limited cases.
prim::NumToTensor -> aten::Int, and prim::NumToTensor -> X -> aten::Int in cases where the tensors used are single-use and can safely be fused. Fixes #513, Fixes #707
Partially: #867, #829, #785, #711, #660
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: