Skip to content

Commit 7c0a4e7

Browse files
ajyu authored and facebook-github-bot committed
[static runtime] convert to->to_copy (#53524)
Summary: Pull Request resolved: #53524. Adds an aten::to -> static_runtime::to_copy mapping in the ReplaceWithCopy pass so that it plays well with AliasDb. Test Plan: Run the benchmark with the CastedBatchOneHot fusion off (https://www.internalfb.com/intern/diff/view-version/123230476/) on the adindexer and adfinder models. Reviewed By: hlu1 Differential Revision: D26887050 fbshipit-source-id: 3f2fb9e27783bcdeb91c8b4181575f059317aff1
1 parent 1e99281 commit 7c0a4e7

2 files changed

Lines changed: 43 additions & 30 deletions

File tree

torch/csrc/jit/runtime/static/ops.cpp

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -784,36 +784,39 @@ REGISTER_OPERATOR_FUNCTOR(aten::pow, aten_pow, [](Node* n) -> SROperator {
784784
};
785785
});
786786
// out variant takes precedence over native
787-
REGISTER_OPERATOR_FUNCTOR(aten::to, aten_to, [](Node* n) -> SROperator {
788-
return [](ProcessedNode* p_node) {
789-
// support 4- or 5-arg for adindexer/adfinder models
790-
DCHECK(p_node->inputs().size() >= 4);
791-
const auto& in0_t = p_node->Input(0).toTensor();
792-
auto in2_i = p_node->Input(2).toBool(); // non_blocking
793-
// ignore input 3 (copy)
794-
if (p_node->Output(0).isNone()) {
795-
auto in1_i = p_node->Input(1).toScalarType();
796-
c10::optional<c10::MemoryFormat> in4_o = c10::nullopt;
797-
if (p_node->inputs().size() > 4 && p_node->Input(4).isInt()) {
798-
in4_o = p_node->Input(4).toOptional<c10::MemoryFormat>();
799-
}
800-
if (in4_o.value_or(c10::MemoryFormat::Preserve) ==
801-
c10::MemoryFormat::Preserve) {
802-
if (in0_t.is_non_overlapping_and_dense()) {
803-
in4_o = c10::nullopt;
804-
} else {
805-
in4_o = in0_t.suggest_memory_format();
787+
REGISTER_OPERATOR_FUNCTOR(
788+
static_runtime::to_copy,
789+
aten_to_copy,
790+
[](Node* n) -> SROperator {
791+
return [](ProcessedNode* p_node) {
792+
// support 4- or 5-arg for adindexer/adfinder models
793+
DCHECK(p_node->inputs().size() >= 4);
794+
const auto& in0_t = p_node->Input(0).toTensor();
795+
auto in2_i = p_node->Input(2).toBool(); // non_blocking
796+
// ignore input 3 (copy)
797+
if (p_node->Output(0).isNone()) {
798+
auto in1_i = p_node->Input(1).toScalarType();
799+
c10::optional<c10::MemoryFormat> in4_o = c10::nullopt;
800+
if (p_node->inputs().size() > 4 && p_node->Input(4).isInt()) {
801+
in4_o = p_node->Input(4).toOptional<c10::MemoryFormat>();
802+
}
803+
if (in4_o.value_or(c10::MemoryFormat::Preserve) ==
804+
c10::MemoryFormat::Preserve) {
805+
if (in0_t.is_non_overlapping_and_dense()) {
806+
in4_o = c10::nullopt;
807+
} else {
808+
in4_o = in0_t.suggest_memory_format();
809+
}
810+
}
811+
// See Note [Explicit nullopt MemoryFormat argument]
812+
p_node->Output(0) = at::detail::empty_cpu(
813+
{0}, in1_i, in0_t.layout(), in0_t.device(), c10::nullopt, in4_o);
806814
}
807-
}
808-
// See Note [Explicit nullopt MemoryFormat argument]
809-
p_node->Output(0) = at::detail::empty_cpu(
810-
{0}, in1_i, in0_t.layout(), in0_t.device(), c10::nullopt, in4_o);
811-
}
812-
auto& out_t = p_node->Output(0).toTensor();
813-
fastResizeToZero(out_t);
814-
at::native::to_copy_out(out_t, in0_t, in2_i);
815-
};
816-
});
815+
auto& out_t = p_node->Output(0).toTensor();
816+
fastResizeToZero(out_t);
817+
at::native::to_copy_out(out_t, in0_t, in2_i);
818+
};
819+
});
817820

818821
// Out variants for view ops are registered to a separate registry because
819822
// their outputs (views) can't participate in memory reuse.

torch/csrc/jit/runtime/static/passes.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,14 @@ TORCH_LIBRARY_FRAGMENT(static_runtime, m) {
332332
at::native::copy_(out, self);
333333
return out.permute(dims);
334334
});
335+
m.def(
336+
"static_runtime::to_copy(Tensor self, ScalarType dtype, bool non_blocking, bool copy) -> Tensor",
337+
[](at::Tensor self, at::ScalarType dtype, bool non_blocking, bool copy)
338+
-> at::Tensor {
339+
at::Tensor out = at::empty_like(self);
340+
at::native::copy_(out, self);
341+
return out.to(dtype, non_blocking, copy);
342+
});
335343
}
336344

337345
void ReplaceWithCopy(std::shared_ptr<torch::jit::Graph>& graph) {
@@ -357,7 +365,9 @@ void ReplaceWithCopy(std::shared_ptr<torch::jit::Graph>& graph) {
357365
{c10::Symbol::fromQualString("aten::permute"),
358366
c10::Symbol::fromQualString("static_runtime::permute_copy")},
359367
{c10::Symbol::fromQualString("aten::narrow"),
360-
c10::Symbol::fromQualString("aten::narrow_copy")}};
368+
c10::Symbol::fromQualString("aten::narrow_copy")},
369+
{c10::Symbol::fromQualString("aten::to"),
370+
c10::Symbol::fromQualString("static_runtime::to_copy")}};
361371
std::vector<std::pair<Node*, Node*>> replacement;
362372
for (auto* n : graph->nodes()) {
363373
if (!supported.count(n->kind())) {

0 commit comments

Comments (0)