Skip to content

[JIT] Add tracing to custom op and simplify tracer overall #10212

Closed
goldsborough wants to merge 3 commits intopytorch:masterfrom
goldsborough:custom-op-tracing
Closed

[JIT] Add tracing to custom op and simplify tracer overall #10212
goldsborough wants to merge 3 commits intopytorch:masterfrom
goldsborough:custom-op-tracing

Conversation

@goldsborough
Copy link
Contributor

@goldsborough goldsborough commented Aug 3, 2018

This PR adds tracing infrastructure for custom operators. It also simplifies the tracer overall, and changes the codegen to do more metaprogramming there instead of via C++ (which was necessary for the custom op tracing).

To give an example of the tracer/metaprogramming change, what used to look like this in VariableType.cpp:

jit::tracer::PreTraceInfo trace_info;
  if (jit::tracer::isTracing()) {
    trace_info = jit::tracer::preRecordTrace(jit::aten::index_select, "self", self, "dim", dim, "index", index);
  }

is now simply the inlined version of preRecordTrace, minus C++ metaprogramming:

torch::jit::Node* node = nullptr;
  if (jit::tracer::isTracing()) {
    auto& graph = jit::tracer::getTracingState()->graph;
    node = graph->create(jit::aten::index_select_out, /*outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "result", result);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "dim", dim);
    jit::tracer::addInputs(node, "index", index);
    graph->appendNode(node);
  }

@zdevito @apaszke

@goldsborough goldsborough added the oncall: jit Add this issue/PR to JIT oncall triage queue label Aug 3, 2018
@goldsborough goldsborough changed the title [JIT] Add tracing to custom op and simplify tracer overall [WIP][JIT] Add tracing to custom op and simplify tracer overall Aug 3, 2018
@goldsborough goldsborough force-pushed the custom-op-tracing branch 4 times, most recently from dce382e to 52fd54d Compare August 6, 2018 20:58
@goldsborough goldsborough changed the title [WIP][JIT] Add tracing to custom op and simplify tracer overall [JIT] Add tracing to custom op and simplify tracer overall Aug 6, 2018
@goldsborough
Copy link
Contributor Author

goldsborough commented Aug 6, 2018

The base diff has landed, so this PR is now ready for review.

@apaszke take a look at the tracer changes if you can, you're likely most familiar with that code

@goldsborough goldsborough force-pushed the custom-op-tracing branch 2 times, most recently from 560aecd to e09859d Compare August 6, 2018 23:40
Copy link
Contributor

@zdevito zdevito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good!

@goldsborough
Copy link
Contributor Author

@pytorchbot retest this please

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

goldsborough has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

PenghuiCheng pushed a commit to PenghuiCheng/pytorch that referenced this pull request Aug 10, 2018
Summary:
This PR adds tracing infrastructure for custom operators. It also simplifies the tracer overall, and changes the codegen to do more metaprogramming there instead of via C++ (which was necessary for the custom op tracing).

To give an example of the tracer/metaprogramming change, what used to look like this in `VariableType.cpp`:

```
jit::tracer::PreTraceInfo trace_info;
  if (jit::tracer::isTracing()) {
    trace_info = jit::tracer::preRecordTrace(jit::aten::index_select, "self", self, "dim", dim, "index", index);
  }
```

is now simply the inlined version of `preRecordTrace`, minus C++ metaprogramming:

```
torch::jit::Node* node = nullptr;
  if (jit::tracer::isTracing()) {
    auto& graph = jit::tracer::getTracingState()->graph;
    node = graph->create(jit::aten::index_select_out, /*outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "result", result);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "dim", dim);
    jit::tracer::addInputs(node, "index", index);
    graph->appendNode(node);
  }
```

zdevito apaszke
Pull Request resolved: pytorch#10212

Differential Revision: D9199615

Pulled By: goldsborough

fbshipit-source-id: cd4b603c1dc01340ead407228e109c99bdba2cfc
goodlux pushed a commit to goodlux/pytorch that referenced this pull request Aug 15, 2018
Summary:
This PR adds tracing infrastructure for custom operators. It also simplifies the tracer overall, and changes the codegen to do more metaprogramming there instead of via C++ (which was necessary for the custom op tracing).

To give an example of the tracer/metaprogramming change, what used to look like this in `VariableType.cpp`:

```
jit::tracer::PreTraceInfo trace_info;
  if (jit::tracer::isTracing()) {
    trace_info = jit::tracer::preRecordTrace(jit::aten::index_select, "self", self, "dim", dim, "index", index);
  }
```

is now simply the inlined version of `preRecordTrace`, minus C++ metaprogramming:

```
torch::jit::Node* node = nullptr;
  if (jit::tracer::isTracing()) {
    auto& graph = jit::tracer::getTracingState()->graph;
    node = graph->create(jit::aten::index_select_out, /*outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "result", result);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "dim", dim);
    jit::tracer::addInputs(node, "index", index);
    graph->appendNode(node);
  }
```

zdevito apaszke
Pull Request resolved: pytorch#10212

Differential Revision: D9199615

Pulled By: goldsborough

fbshipit-source-id: cd4b603c1dc01340ead407228e109c99bdba2cfc
@ezyang ezyang added the merged label Jun 26, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants