
[pytorch] move tracing logic to a separate dispatch backend#38467

Closed
ljk53 wants to merge 15 commits into gh/ljk53/139/base from gh/ljk53/139/head

Conversation

@ljk53
Contributor

@ljk53 ljk53 commented May 14, 2020

Stack from ghstack:

This PR moves tracing logic out of the generated VariableType kernels, to associate it with a new dedicated dispatch key Tracer.
It also toggles the dispatch key set at various places to keep the semantics unchanged - see the inline [Tracing Mode Switches] note.

Sample generated code:

```
Tensor & __ilshift___Tensor(Tensor & self, const Tensor & other) {
  #if !defined(PYTORCH_DISABLE_TRACING)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::__ilshift__");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  #endif
  static auto op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::__ilshift__", "Tensor");
  c10::Dispatcher::singleton().redispatch<Tensor &, Tensor &, const Tensor &>(op, c10::DispatchKey::Tracer, self, other);
  #if !defined(PYTORCH_DISABLE_TRACING)
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, self);
  }
  #endif
  return self;
}
```
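The redispatch call in the kernel above re-enters the dispatcher starting strictly below the Tracer key, so the tracing wrapper hands control to the next kernel in line (autograd, then the backend) instead of recursing into itself. A toy Python sketch of that idea - the `Dispatcher` class, the key ordering, and all names here are illustrative, not PyTorch's actual c10 implementation:

```python
# Toy dispatcher illustrating "redispatch": calling the next kernel
# strictly below a given dispatch key. The key ordering mirrors the
# concept only; this is not the real c10 dispatcher.
DISPATCH_ORDER = ["Tracer", "Autograd", "CPU"]

class Dispatcher:
    def __init__(self):
        self.kernels = {}  # (op, key) -> callable

    def register(self, op, key, fn):
        self.kernels[(op, key)] = fn

    def call(self, op, *args):
        # Start from the highest-priority key.
        return self._dispatch(op, 0, *args)

    def redispatch(self, op, current_key, *args):
        # Skip past current_key: only consider keys strictly below it.
        start = DISPATCH_ORDER.index(current_key) + 1
        return self._dispatch(op, start, *args)

    def _dispatch(self, op, start, *args):
        for key in DISPATCH_ORDER[start:]:
            fn = self.kernels.get((op, key))
            if fn is not None:
                return fn(self, *args)
        raise KeyError(f"no kernel for {op}")

d = Dispatcher()
trace_log = []

def tracer_kernel(dispatcher, x, y):
    trace_log.append(("aten::add", x, y))  # record the graph node
    # Hand off to the next kernel below the Tracer key.
    return dispatcher.redispatch("aten::add", "Tracer", x, y)

d.register("aten::add", "Tracer", tracer_kernel)
d.register("aten::add", "CPU", lambda _, x, y: x + y)

result = d.call("aten::add", 1, 2)  # hits Tracer, which redispatches to CPU
```

Because the tracing kernel redispatches rather than calling itself back through the top of the dispatcher, tracing stays a thin layer over whatever kernels sit below it.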

Differential Revision: [D21570684](https://our.internmc.facebook.com/intern/diff/D21570684/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D21570684/)!

[ghstack-poisoned]
ljk53 added a commit that referenced this pull request May 14, 2020
ghstack-source-id: 104090205
Pull Request resolved: #38467
@dr-ci

dr-ci Bot commented May 14, 2020

💊 CI failures summary and remediations

As of commit d85a22c (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



…ceType"

[ghstack-poisoned]
ljk53 added a commit that referenced this pull request May 14, 2020
Pull Request resolved: #38467


ghstack-source-id: 104108948

@ljk53 ljk53 changed the title [WIP][pytorch] move tracing logic from VariableType to TraceType [pytorch] move tracing logic from VariableType to TraceType May 14, 2020
@ljk53 ljk53 requested review from bhosmer, ezyang and smessmer May 14, 2020 17:48
@ljk53 ljk53 marked this pull request as draft May 14, 2020 19:02
…d (part 1: codegen only)"

[ghstack-poisoned]
@ljk53 ljk53 changed the title [pytorch] move tracing logic from VariableType to TraceType [pytorch] move tracing logic to a separate dispatch backend (part 1: codegen only) May 14, 2020
@ljk53 ljk53 marked this pull request as ready for review May 14, 2020 19:37
@ljk53
Contributor Author

ljk53 commented May 14, 2020

This PR can be reviewed now - it only generates the new TraceType_*.cpp files but doesn't build them yet.
The integration is in the next PR.

…d (part 1: codegen only)"

[ghstack-poisoned]
@ezyang
Contributor

ezyang commented May 15, 2020

This looks reasonable, though I only skimmed the codegen bits. You'll probably want to land this all in one go with the second part, right?


@bhosmer bhosmer left a comment


Looks good! Couple minor comments inline.

Comment thread: tools/autograd/gen_variable_type.py (Outdated)

# TraceType templates
TRACE_DISPATCH_FIND_OP = CodeTemplate("""\
static auto op = c10::Dispatcher::singleton().findSchema({"aten::${operator_name}", "${overload_name}"});

Can we factor things to eliminate the duplication between TRACE_DISPATCH_FIND_OP / TRACE_DISPATCH_CALL_OP here and PROFILE_DISPATCH_UNBOXED above? I know that takes you outside the new functionality, but it seems worth doing.

Oh also! I think you can replace the findSchema()/TORCH_INTERNAL_ASSERT() with findSchemaOrThrow() - see e.g. here

Contributor Author

Switched to findSchemaOrThrow() and merged TRACE_DISPATCH_FIND_OP & TRACE_DISPATCH_CALL_OP. I didn't touch PROFILE_DISPATCH_UNBOXED because I was told @ilia-cher is going to remove it (calling profiler from dispatcher instead).

TRACE_DISPATCH_FIND_OP / TRACE_DISPATCH_CALL_OP were originally separate because "auto result = " may need to be inserted before TRACE_DISPATCH_CALL_OP when the op doesn't return void or mutate self. Now I pass "auto result = " into the CodeTemplate - however, I found that it unconditionally strips trailing whitespace from the value here, so the output looks like "auto result =c10::Dispatcher::singleton()..." - correct but ugly :(
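The trailing-whitespace issue described above can be illustrated with a toy substitution function - a stand-in for, not the actual, CodeTemplate - that strips trailing whitespace from substituted values:

```python
# Toy substitution mimicking the behavior described above: a template
# engine that strips trailing whitespace from substituted values.
# Illustrative stand-in only; not PyTorch's actual CodeTemplate class.
def substitute(template: str, **env: str) -> str:
    out = template
    for name, value in env.items():
        # The rstrip() here is what eats the trailing space.
        out = out.replace("${" + name + "}", value.rstrip())
    return out

line = substitute(
    "${assign_return_value}c10::Dispatcher::singleton().redispatch(...)",
    assign_return_value="auto result = ",
)
# The space after "=" is stripped, producing the ugly-but-correct:
# "auto result =c10::Dispatcher::singleton().redispatch(...)"
```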

base_type_call=base_type_call)

call += wrap_output(tie_return_values(), 'tmp')
call += wrap_output(tie_return_values, 'tmp')

I don't know if this is too fussy, but I think the code would be easier to follow if tie_return_values was now passed into emit_call() explicitly. (Actually FWIW I think having format_return_variables() gather those three bits of codegen together might be making the logistics a bit higher-friction than if they were still independent, though ofc factoring to make it/them callable from both tracing and variable is totally the right thing to do, and I see why you brought them together.)

Totally up to you if you want to change any of this - basically just nits.

Contributor Author

Passed in tie_return_values explicitly.

Yeah - I grouped them together to share these util methods...

    arguments = declaration['arguments']
    returns = declaration['returns']
    func = declaration['derivative']
    name = declaration['name']
    inplace = declaration['inplace']
    is_out_fn = name.endswith('_out')
    modifies_arguments = inplace or is_out_fn
    returns_void = len(returns) == 0
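The refactor discussed in this thread - passing `tie_return_values` into `emit_call()` explicitly rather than closing over it - might look roughly like this sketch (the helper bodies and the sample call are illustrative, not the actual gen_variable_type.py code):

```python
# Sketch of the suggested refactor: emit_call receives the
# tie_return_values helper as an explicit parameter instead of
# reaching for shared state. Names and bodies are illustrative.
def tie_return_values() -> str:
    return "auto [result0, result1]"

def wrap_output(tie_fn, var: str) -> str:
    # Receives the callable, not its evaluated result, matching the
    # wrap_output(tie_return_values, 'tmp') form in the diff above.
    return f"{tie_fn()} = {var};"

def emit_call(base_type_call: str, tie_fn) -> str:
    call = f"auto tmp = {base_type_call};\n"
    call += wrap_output(tie_fn, "tmp")
    return call

code = emit_call("TypeDefault::add(self, other)", tie_return_values)
```

Passing the callable keeps the evaluation site in one place and makes `emit_call`'s dependencies visible in its signature.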


#include "torch/csrc/autograd/function.h"

// ${generated_comment}

Might be worth adding a (non-generated) comment that briefly describes what these kernels do?

Contributor Author

Added a link to torch/csrc/jit/OVERVIEW.md.

@ezyang
Contributor

ezyang commented May 19, 2020

Can you please post some examples of the generated code?

ljk53 added 2 commits May 20, 2020 00:45
…d (part 1: codegen only)"

…d (part 1: codegen only)"

@ljk53
Copy link
Copy Markdown
Contributor Author

ljk53 commented May 20, 2020

Can you please post some examples of the generated code?

Added to the PR description.

ljk53 added 2 commits May 26, 2020 11:10
…d (part 1: codegen only)"

…d (part 1: codegen only)"

@ezyang
Contributor

ezyang commented May 27, 2020

A few notes from looking at the generated code:

  • I think the jit::tracer::isTracing() check is now unnecessary: we should just bypass the tracing key entirely when tracing is disabled, so these calls don't need to be in a conditional. If getting this PR to pass all CI was challenging you can save this refactor for later, but we should definitely get rid of it at some point, and it may be easier to do so now while you have the patient open. EDIT: Actually, this PR is probably not testing this codepath in any substantive manner, so....
  • Similarly, PYTORCH_DISABLE_TRACING also seems to be in the wrong place now. @iseeyuan, correct me if I'm wrong, but this toggle exists to strip the tracing code for federated learning on mobile; to me, this suggests we should just avoid registering the tracing wrappers entirely when the flag is on.
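The "bypass the tracing key entirely" idea corresponds to thread-local dispatch-key exclusion, the mechanism behind guards like the NoTracerDispatchMode that appears in later revisions of the generated code. A toy Python sketch of the exclusion-guard idea - illustrative only, not the real c10::impl guard classes:

```python
# Toy sketch of a thread-local dispatch-key exclusion guard: while the
# guard is alive, the excluded key is skipped during dispatch.
# Illustrative only; not the real c10 implementation.
import threading

_tls = threading.local()

def excluded_keys():
    return getattr(_tls, "excluded", set())

class ExcludeDispatchKeyGuard:
    def __init__(self, key):
        self.key = key

    def __enter__(self):
        _tls.excluded = excluded_keys() | {self.key}
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _tls.excluded = excluded_keys() - {self.key}

def dispatch(op_kernels):
    # op_kernels: list of (key, fn) pairs in priority order.
    for key, fn in op_kernels:
        if key not in excluded_keys():
            return fn()
    raise KeyError("no kernel")

kernels = [("Tracer", lambda: "traced"), ("CPU", lambda: "cpu")]
with ExcludeDispatchKeyGuard("Tracer"):
    inside = dispatch(kernels)   # Tracer excluded, falls through to CPU
outside = dispatch(kernels)      # Tracer active again
```

With exclusion expressed this way, a global "tracing disabled" mode amounts to keeping the Tracer key permanently excluded (or, as suggested above, never registering the tracing wrappers at all).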


# TraceType templates
TRACE_DISPATCH_UNBOXED = CodeTemplate("""\
static auto op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::${operator_name}", "${overload_name}");
Contributor

No action needed for this PR, but we really need to come up with a better idiom for redispatching in these situations.

This PR moves tracing logic out of the generated VariableType kernels, to associate it with a new dedicated dispatch key Tracer.
It also toggles the dispatch key set at various places to keep the semantics unchanged - see the inline [Tracing Mode Switches] note.

Sample generated code:
```
Tensor & __ilshift___Tensor(Tensor & self, const Tensor & other) {
  #if !defined(PYTORCH_DISABLE_TRACING)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::__ilshift__");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  #endif
  static auto op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::__ilshift__", "Tensor");
  {
    at::tracer::impl::NoTracerDispatchMode tracer_guard;
    c10::Dispatcher::singleton().call<Tensor &, Tensor &, const Tensor &>(op, self, other);
  }
  #if !defined(PYTORCH_DISABLE_TRACING)
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, self);
  }
  #endif
  return self;
}
```



[ghstack-poisoned]
ljk53 added a commit that referenced this pull request Jun 3, 2020
Pull Request resolved: #38467


ghstack-source-id: 105161832

@ezyang ezyang closed this Jun 3, 2020
@ljk53
Contributor Author

ljk53 commented Jun 3, 2020

@ezyang this PR has not been landed yet - I expected you & @bhosmer to do another review, since I merged the 2nd PR into this PR per your request :)

@ljk53 ljk53 reopened this Jun 3, 2020
ljk53 added a commit that referenced this pull request Jun 3, 2020
Pull Request resolved: #38467


ghstack-source-id: 105215150

zasdfgbnm added a commit that referenced this pull request Jun 5, 2020
* [ONNX] Bump up ONNX submodule to a82c6a7010e2e332d8f74ad5b0c726fd47c85376 (#39372)

Summary:
Pull Request resolved: #39372

We only bump the submodule in OSS to unblock some work.

Test Plan: ci

Reviewed By: hl475

Differential Revision: D21830800

fbshipit-source-id: fb4a716992efcd71926f7bba24a7c24422c17e38

* Add rpc.functions.async_execution decorator for rpc_sync/rpc_async (#39216)

Summary:
Pull Request resolved: #39216

The `rpc.functions.async_execution` decorator specifies that the
wrapped function is guaranteed to return a `torch.futures.Future`.
The decorator adds a `_wrapped_async_rpc_function` attribute to
the wrapper function. The caller retrieves this information and
then sets `isAsyncFunction` argument accordingly which is later
added to PythonCall RPC message as a field. On the callee side,
if the PythonCall carries an asynchronous function, it will cast
the function's return value to a jit::PythonFutureWrapper object,
and then install response creation and communication as a callback
on that jit::PythonFutureWrapper.

For applications, this feature is useful when a function needs to
wait for IO or additional signaling. In those cases, marking the
user function as `rpc.functions.async_execution` will prevent it
from blocking one thread on callee for too long.
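The decorator mechanism described above boils down to tagging the wrapped function with an attribute (`_wrapped_async_rpc_function`, per the summary) that the RPC layer later inspects. A minimal sketch of that tagging pattern - the Future plumbing is omitted, and `slow_add` is a made-up example function:

```python
import functools

# Minimal sketch of the attribute-tagging pattern described above:
# the decorator marks the function so the RPC caller can detect that
# its return value is a Future. The Future machinery itself, and the
# real torch.distributed.rpc integration, are omitted.
def async_execution(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    # The RPC caller checks this attribute to set isAsyncFunction.
    wrapper._wrapped_async_rpc_function = fn
    return wrapper

@async_execution
def slow_add(x, y):
    return x + y  # real code would return a torch.futures.Future

is_async = hasattr(slow_add, "_wrapped_async_rpc_function")
```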

Test Plan: Imported from OSS

Reviewed By: rohan-varma

Differential Revision: D21779962

fbshipit-source-id: 6b6aa698bf6f91dad6ed2a7ee433df429b59e941

* Fix index overflow in ConvTranspose3d [attempt 2] (#39198)

Summary:
Fixes #32866, resubmit of #38970

The memory error in the issue is caused by an int overflow in col2vol. This version uses mixed 32-bit and 64-bit index calculations to lift the maximum indexable size without compromising the performance of ConvTranspose3d, vs. a 20-30% regression with pure 64-bit indexing.

This requires that input.numel() <= UINT_MAX and channels * kernel.numel() <= UINT_MAX; otherwise it raises an error. Previously, the code would crash or give incorrect results unless input.numel() * kernel.numel() <= INT_MAX.

Note that the test is a minimised reproducer for the issue.
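The overflow described above occurs when the linear index computed in col2vol exceeds INT_MAX under 32-bit arithmetic. A small Python sketch simulating signed 32-bit wraparound with illustrative sizes (Python ints never overflow, so the wrap is done by hand):

```python
# Simulate 32-bit signed arithmetic to show how an index product like
# input.numel() * kernel.numel() can overflow. Sizes are illustrative.
INT_MAX = 2**31 - 1

def as_int32(x: int) -> int:
    """Wrap x into the signed 32-bit range (C signed overflow is UB,
    but this is the wraparound commonly observed in practice)."""
    x &= 0xFFFFFFFF
    return x - 2**32 if x > INT_MAX else x

input_numel = 64 * 3 * 128 * 128 * 128   # ~4.0e8 elements, fits in int32
kernel_numel = 3 * 3 * 3                  # 27

index64 = input_numel * kernel_numel      # correct 64-bit product
index32 = as_int32(index64)               # what 32-bit math would yield

overflowed = index64 > INT_MAX and index32 != index64
```

Here the product exceeds INT_MAX even though each factor fits comfortably in 32 bits, which is why the fix constrains each factor separately rather than their product.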
Pull Request resolved: #39198

Differential Revision: D21817836

Pulled By: ezyang

fbshipit-source-id: b9adfe9f9dd00f04435be132966b33ac6b9efbef

* [TensorPipe] Re-enable RPC tests (#39406)

Summary:
Pull Request resolved: #39406

For now, just the RPC test (no dist autograd or dist optimizer).

I removed the skipping decorator from all the tests except those that explicitly use the ProcessGroup options.

Includes #39027.
ghstack-source-id: 105159974

Test Plan: Ran the tests several hundred times, in various build modes. Saw some flakes, but at a rate of about 0.1%

Differential Revision: D21716069

fbshipit-source-id: 9d2a99e112049a63745772c18e7a58266ed8e74e

* Added FPGA DispatchKey, DeviceType, Backend (#38938)

Summary:
ezyang,

I have added the changes to DispatchKey, DeviceType, Backend to support the out-of-tree FPGA.

cc. tataetae
Pull Request resolved: #38938

Differential Revision: D21748955

Pulled By: ezyang

fbshipit-source-id: fe76d9730818205961430d2a0e00727b5c547b32

* Leak safety in RReLU (#39347)

Summary:
Fixes gh-38966

If `THCTensor_(resizeAs)` fails to allocate, then these `free`s will never be reached. So, instead I use a wrapped tensor to do cleanup automatically.
Pull Request resolved: #39347

Differential Revision: D21838933

Pulled By: ezyang

fbshipit-source-id: 8c74ecdd720d6712a33ddef6126ea545761a269b

* [TensorPipe] Re-enable dist autograd tests (#39440)

Summary:
Pull Request resolved: #39440

After the RPC tests, re-enable the second test suite: dist autograd.
ghstack-source-id: 105165393

Test Plan: Ran the tests, several times each, in different build configs.

Differential Revision: D21858974

fbshipit-source-id: 409377d564c36fecae51b9e4c776d94187b434a2

* [TensorPipe] Re-enable dist optimizer tests (#39441)

Summary:
Pull Request resolved: #39441

This is the last test suite to be enabled for TensorPipe.
ghstack-source-id: 105166757

Test Plan: Ran the tests, hundreds of times each, in different build modes.

Differential Revision: D21858975

fbshipit-source-id: ee0a7e64b77b4b1974f031207031cc14afb3a8c2

* `as_strided` : add size and stride length check (#39301)

Summary:
Fixes #39281
Pull Request resolved: #39301

Differential Revision: D21849082

Pulled By: gchanan

fbshipit-source-id: 5d30ef10767c4d35c6cb59c5e6a9acbfe0270a40

* add_observer: respect device affinity for ReLU (#39337)

Summary:
Pull Request resolved: #39337

In #39031 we made fake quantize respect device affinity of the
original module. However, that PR only handled modules with parameters
or buffers, and did not work properly for `ReLU`.

Fixing the logic to also work for `ReLU` by passing the parent's
device when adding observers.

Test Plan:
```
python test/test_quantization.py TestDistributed.test_device_affinity
```

Imported from OSS

Differential Revision: D21821243

fbshipit-source-id: cc6abda3694b80ce8ba0440dc6c1b5b58f3c0066

* Fix `torch.backends.cudnn` mypy error (#38947)

Summary:
Fix #38410

![image](https://user-images.githubusercontent.com/6421097/82724121-74b26880-9c99-11ea-9b63-e92de2dccdf2.png)
Pull Request resolved: #38947

Differential Revision: D21765290

Pulled By: ezyang

fbshipit-source-id: 5d2b25f039a653c609d60cdaac4a7ac5812ae291

* .github: Add initial target specifier config (#39378)

Summary:
Pull Request resolved: #39378

Will initially only contain a label to trigger builds for binary tests

Signed-off-by: Eli Uriegas <eliuriegas@fb.com>

Test Plan: Imported from OSS

Differential Revision: D21864091

Pulled By: seemethere

fbshipit-source-id: f69467ccc797b6b320dc8b7f2d50a8601c172a1f

* [torch] remove integer conversion resulted in a change of sign warning (#38968)

Summary:
Pull Request resolved: #38968

As title

Reviewed By: glaringlee

Differential Revision: D21711684

fbshipit-source-id: c340360b29849fe9ab0e7be376918c92ba3629be

* Replace torch.allClose with self.assertEqual (#39424)

Summary: Pull Request resolved: #39424

Reviewed By: Krovatkin

Differential Revision: D21854870

Pulled By: ailzhang

fbshipit-source-id: eb68f1775596e4c963169033444d6d6f4f818d4f

* Do not call optimizations within freezing API (#38499)

Summary:
This patch removes call to run optimizations within freezing API.
Only dead code elimination is invoked to clean up the frozen module.
Pull Request resolved: #38499

Reviewed By: eellison

Differential Revision: D21579607

Pulled By: bzinodev

fbshipit-source-id: a6231754fea89296a3dcf07b5e37a1c43cb8d5dd

* Update key_padding_mask arg docs in MHA module (#39321)

Summary: Pull Request resolved: #39321

Reviewed By: zhangguanheng66

Differential Revision: D21825488

Pulled By: Nayef211

fbshipit-source-id: 41ee09e683c4ae838cfd488a342088d591e806e4

* [TensorExpr] Fix two bugs in Rfactor (#39268)

Summary:
The two bugs were:
* Non-reduction axes were not added when inserting the new ReduceOp, meaning if a reduction with non-reduce axes was rfactored we'd produce bad outputs. There were no tests of Rfactor with non-reduce axis so I modified a test to do this.
* The new statements were always prepended to the block, meaning writes to a buffer could be reordered after the usage of that buffer. This mostly happened in the case where we rfactor a previously rfactored reduction. There was a test of this, but since it only tested rfactoring the outer reduction axis there was never any other statements at the insertion point (the tests of the insertion point argument also do this). I added a new test which covers various rfactor-axis cases.

Also cleaned up tests, removed some helper code we don't need etc.
Pull Request resolved: #39268

Differential Revision: D21864489

Pulled By: nickgg

fbshipit-source-id: d314d20997a8472ec96b72f7a9068d6da6d2399c

* Revert D21579607: [pytorch][PR] Do not call optimizations within freezing API

Test Plan: revert-hammer

Differential Revision:
D21579607

Original commit changeset: a6231754fea8

fbshipit-source-id: 277011605eedee1c3b44fbaf877233b239adf56b

* [TensorExpr] some cleanups / fixes for LoopOptions (#39408)

Summary:
Mainly, fix a bug in the HashProvider where it would not include LoopOptions in the hash, meaning two loops would be seen as identical even if they were bound to different thread/block axes. Also added symbolic names for the different axis options.
Pull Request resolved: #39408

Differential Revision: D21864494

Pulled By: nickgg

fbshipit-source-id: 9c28729984e7a3375e026c78294c9f75b9015123

* Adds TestCase.compare_with_numpy (#39179)

Summary:
Cut from #38994.

This is a helper function for comparing torch and NumPy behavior. It updates the existing and increasingly popular _np_compare function and moves it to be a method on TestCase.
Pull Request resolved: #39179

Differential Revision: D21855082

Pulled By: mruberry

fbshipit-source-id: edca3b78ae392d32243b02bf61960898b6ba590f

* Enable some test cases in `test_memory_format_operators` (#38648)

Summary:
Re-enable some test cases in `test_memory_format_operators` since their corresponding issue has been fixed.
Pull Request resolved: #38648

Differential Revision: D21689085

Pulled By: VitalyFedyunin

fbshipit-source-id: 0aa09e0bf31ba98c8ad0191ac3afd31dda0f1d42

* Retry/skip test on URLError rather than on HTTPError (#39477)

Summary:
`HTTPError` is raised when the server is overloaded, while `URLError` is
raised when the network is not available.
And since `HTTPError` is an extension of `URLError`, catching `URLError` catches both exceptions.
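The exception hierarchy this relies on can be checked directly against the standard library (an illustration only, not code from the PR; `classify` is a hypothetical helper):

```python
from urllib.error import URLError, HTTPError

# HTTPError subclasses URLError, so `except URLError` catches both.
assert issubclass(HTTPError, URLError)

def classify(exc: URLError) -> str:
    # An HTTPError means the server responded (e.g. overloaded / 5xx),
    # so the test is worth retrying. A bare URLError means the network
    # itself failed, so the test should be skipped instead.
    return "retry" if isinstance(exc, HTTPError) else "skip"

print(classify(URLError("no network")))  # skip
```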
Pull Request resolved: #39477

Differential Revision: D21873560

Pulled By: malfet

fbshipit-source-id: 11806671b768705465f562087521ad4887fd20f7

* LayerNorm Fake FP16 Op debug (#39476)

Summary:
LayerNorm Fake FP16 Op debug.
Still seeing output mismatches.
Pull Request resolved: #39476

Differential Revision: D21871748

Pulled By: hyuen

fbshipit-source-id: ab308e3acff9ce21de41b0f006cbee767983f8e4

* Selective build on Training, query based. (#39452)

Summary:
Pull Request resolved: #39452

Selective build works on training.
* VariableType_?.cpp are now selectively generated based on the operator list.
* Add a flag in pt_operator_library, "train". If it's True, an extra flag of "pt_train_operator_library" will be added to the labels. A query for "pt_train_operator_library" will be done to aggregate the training operators. With this flag we limit the generated VariableType to used training operators only, to conserve the code size. The models for inference only have train = False by default.
* For testing purpose, caffe2/fb/pytorch_trainer is created. It's based on full jit but the operators are selectively built.
* smartkeyboard_debug_model is used for test. Since the static code analysis is not applied for VariableType yet, the operators are manually added based on debugging error messages.
* At build stage, make selective build optional for training code-gen library.
The reason is that to make fb4a built, the generated VariableType.cpp needs to depend on torch_mobile_train. Torch_mobile_train is not needed for apps with inference only. In those cases training can be turned off to remove the dependency on torch_mobile_train to save size. It can also be used as a switch to check size regression introduced by training.
ghstack-source-id: 105190037

(Note: this ignores all push blocking failures!)

Test Plan:
Training:
```
buck run -c pt.build_from_deps_query=1 -c pt.selective_build=0 -c pt.static_dispatch=0 xplat/caffe2/fb/pytorch_trainer:trainer ~/models/papaya/keyboard/smartkeyboard_debug_model.pt
```

Inference, with and without the new query-based feature:
```
buck run -c pt.build_from_deps_query=1 -c pt.selective_build=0 -c pt.static_dispatch=0 xplat/caffe2/fb/lite_predictor:lite_predictor_bi -- --model=/home/myuan/models/pytext/BI/bi_pytext_0512.bc --input_dims "1,4" --input_type int64 --pytext_len=4
```
```
buck run xplat/caffe2/fb/lite_predictor:lite_predictor_bi -- --model=/home/myuan/models/pytext/BI/bi_pytext_0512.bc --input_dims "1,4" --input_type int64 --pytext_len=4
```

Reviewed By: ljk53

Differential Revision: D21459302

fbshipit-source-id: df71a46d74f8c7448cbf51990804104f1384594f

* Remove copy_imag and copy_real methods (#39065)

Summary: Pull Request resolved: #39065

Test Plan: Imported from OSS

Differential Revision: D21803939

Pulled By: anjali411

fbshipit-source-id: c7313c527eb6b54d49ef46aa0a839a3418fa8d7e

* Enabling lite interpreter in torch python API (#39181)

Summary:
Pull Request resolved: #39181

Create a Python binding class torch._C.LiteScriptModule for mobile::module; a Python class called LiteScriptModule is created which wraps torch._C.LiteScriptModule.
The Python class LiteScriptModule contains preliminary functions including forward, run_method and __call__.

Create a Python API "load_for_lite_interpreter" under torch.jit.mobile which takes a pre-saved mobile module in a file-like object as input and returns the Python class LiteScriptModule.

Add a Python binding method "_save_to_buffer_for_mobile" under ScriptModule, and a Python method "_save_to_buffer_for_lite_interpreter" under RecursiveScriptModule, which saves the mobile module into a buffer instead of a file.
ghstack-source-id: 105215736

Test Plan: buck test caffe2/test:mobile

Differential Revision: D21757474

fbshipit-source-id: 758b87497d65c4686459a567d41887c7a577aa4c

* [caffe2] compute r_correction only for radam to avoid sqrt(negative) (#39393)

Summary:
Pull Request resolved: #39393

Computing r_correction should be done only for radam; otherwise it can generate floating-point exceptions.

Test Plan:
buck test caffe2/caffe2/python/operator_test:adam_test -- test_sparse_adam
with --caffe2_operator_throw_if_fp_exceptions=1 gflags option

Differential Revision: D21834296

fbshipit-source-id: a9e6a93451423e76a99f6591d21cb65d4374b008

* Device and torch._C function cleanup (#38173)

Summary:
Pull Request resolved: #38173

- Introduce torch.types.Device representing all "device-like" types
- Stubbed torch.device.__reduce__
- Stubbed all torch._C functions comprehensively
- Deleted _safe_call which is unused throughout the codebase

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D21497399

Pulled By: ezyang

fbshipit-source-id: 1f534442b0ec9a70d556545d072f2c06a08b9d15

* move to.prim_dtype to lite interpreter (#39456)

Summary:
Pull Request resolved: #39456

Move aten::to.prim_dtype from full jit to lite interpreter

Test Plan: verify TTS model can be used

Reviewed By: iseeyuan

Differential Revision: D21856104

fbshipit-source-id: 774981a5c04798e3a87cf7d6e6682f35e604944e

* [ONNX]Enable tests in test_operators.py (#39431)

Summary:
Enable Dropout and SoftmaxCrossEntropy tests in test_operators.py
Pull Request resolved: #39431

Reviewed By: hl475

Differential Revision: D21877501

Pulled By: houseroad

fbshipit-source-id: 1e9b1e5cf80dc1843bdbde2662f3339e357c6654

* fix internal targets for layernorm (#39501)

Summary:
Pull Request resolved: #39501

fix internal targets, and disable the test until it is fixed

Test Plan:
built and ran the test, but venkat has to get access to nnpi before
fine tuning the last few pieces. Currently getting around 1e-5 relative error

Reviewed By: yinghai

Differential Revision: D21875657

fbshipit-source-id: 3ae762093084fa65b9aeedaef1b2ca1b1e13b587

* [ONNX] Update pytoch/onnx doc (#39480)

Summary:
Updated docs for operator_export_types and recently added op symbolics.
Pull Request resolved: #39480

Reviewed By: hl475

Differential Revision: D21877364

Pulled By: houseroad

fbshipit-source-id: 9831fe5776629da897db6d7943f830528cb916d2

* Implement rad2deg, deg2rad (#38852)

Summary:
Resolves #38372.

cc mruberry
Pull Request resolved: #38852

Differential Revision: D21868935

Pulled By: mruberry

fbshipit-source-id: ae6ded11b743c9d1cdc032984b4abe0a115290d6

* Add rpc.async_function decorator for TorchScript functions (#39267)

Summary:
Pull Request resolved: #39267

When combined with `torch.jit.script`, the order of decorators matter.
`rpc.functions.async_execution` must be the outermost one. The
`async_execution` decorator will store the TorchScript function in
attribute `_wrapped_async_rpc_function` on the wrapper function, and
pass this wrapped TorchScript function (i.e., an instance of
`torch.jit.ScriptFunction`) to RPC. The caller will mark the ScriptCall
with `isAsyncExecution=true`, and the callee will extract the returned
`Future` in C++ and install subsequent processing as a callback to
that `Future`.
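The ordering constraint can be sketched with plain Python decorators (everything here is a stand-in: `script` and `ScriptFunction` mimic `torch.jit.script`, and only the `_wrapped_async_rpc_function` attribute name comes from the description above):

```python
def script(fn):
    # Stand-in for torch.jit.script: wraps fn in a new callable object.
    class ScriptFunction:
        def __init__(self, f):
            self.f = f
        def __call__(self, *args):
            return self.f(*args)
    return ScriptFunction(fn)

def async_execution(fn):
    # Must be outermost: it stores the (already scripted) function on a
    # wrapper attribute that the RPC layer later inspects.
    def wrapper(*args):
        return fn(*args)
    wrapper._wrapped_async_rpc_function = fn
    return wrapper

@async_execution   # outermost, as required
@script
def f(x):
    return x + 1

# RPC can recover the scripted function from the wrapper attribute.
assert type(f._wrapped_async_rpc_function).__name__ == "ScriptFunction"
```

If the decorators were applied in the opposite order, `script` would receive the plain wrapper and the stored attribute would be lost.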

Test Plan: Imported from OSS

Differential Revision: D21792688

fbshipit-source-id: de095eb148d21e9114a478e9e6047c707d34fd07

* Revert D21652452: [pytorch][PR] Fix for num_threads==1 in OpenMP "parallel for"

Test Plan: revert-hammer

Differential Revision:
D21652452

Original commit changeset: 2cda7777c0ea

fbshipit-source-id: fdd9a0346ce32a962766f57e13357dd2bc60d8b8

* Optimize GroupNorm on CPU (#28203)

Summary:
Pull Request resolved: #28203

Optimize GroupNorm on CPU
ghstack-source-id: 105149765

Test Plan: buck test mode/dev-nosan caffe2/test:nn -- "GroupNorm"

Reviewed By: houseroad

Differential Revision: D17901506

fbshipit-source-id: 5eb22ad0e8a9ab2533282b967b2818f690e48865

* [pytorch] move tracing logic to a separate dispatch backend (#38467)

Summary:
This PR moves tracing logic out of the generated VariableType kernels, to associate it with a new dedicated dispatch key Tracer.
It also toggles the dispatch key set at various places to keep the semantics unchanged - see the inline [Tracing Mode Switches] note.

Sample generated code:
```
Tensor & __ilshift___Tensor(Tensor & self, const Tensor & other) {
  #if !defined(PYTORCH_DISABLE_TRACING)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::__ilshift__");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  #endif
  static auto op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::__ilshift__", "Tensor");
  c10::Dispatcher::singleton().redispatch<Tensor &, Tensor &, const Tensor &>(op, c10::DispatchKey::Tracer, self, other);
  #if !defined(PYTORCH_DISABLE_TRACING)
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, self);
  }
  #endif
  return self;
}
```

Pull Request resolved: #38467

ghstack-source-id: 105215150

Test Plan: CI

Differential Revision: D21570684

fbshipit-source-id: 1a96761830307f9a934f38bfb9fe8b5b1763e0e0

* Implement timeout support for RRefs (#38590)

Summary:
Pull Request resolved: #38590

This PR implements timeout semantics for RRef for parity with rpc_sync and rpc_async. How it works:

- Timeout parameter is added to rpc.remote. If the rpc.remote call times out, note that the error won't be raised to the user in that call, as it is not blocking (similar to rpc_async). Instead, the timeout error will be raised the next time the RRef is used (either by pickling or to_here call).
- Error handling semantics are added to RRef to deal with the timeout errors. Previously, if there was an error creating the OwnerRRef, the callback on the local user would throw an error, resulting in an `std::terminate`. Instead of this, the error is now caught and surfaced to the user the next time the RRef is used. As part of this, we have added an `RPCErrorType` enum and defined RRef error handlers to handle the `RPCErrorType`s (currently just timeout and unknown)
- A timeout parameter is added to `to_here()` which gives the user control over the max amount of time it can block for.
- `ctx.prepareChildForFork()` which is called when the RRef is pickled (i.e. used as an arg over RPC) checks if the `rpc.remote()` call had timed out, and if so, raises that error to the user.
- Tests are added, primarily via delay injection.
ghstack-source-id: 105232837

Test Plan: CI

Differential Revision: D21588165

fbshipit-source-id: c9f9e8aa3521012ea1de3e0f152a41afdf8b23f3

* Fix lint (#39517)

Summary:
Fixes lint.
Pull Request resolved: #39517

Reviewed By: lw

Differential Revision: D21881495

Pulled By: mruberry

fbshipit-source-id: 43b06466d9311d16b0d78d58ed124c1f01807443

* Clean up thrust::complex from rsqrt (#39294)

Summary: Pull Request resolved: #39294

Differential Revision: D21818288

Pulled By: anjali411

fbshipit-source-id: ee7758872700a93713ab66565e2a7a9e8a088a94

* Add SymbolicShape and replace all uses of VaryingShape<ShapeSymbol> with it (#38544)

Summary:
Adding a SymbolicShape class to represent a generic tensor shape with ShapeSymbols.

Its core data structure is c10::optional<std::vector<ShapeSymbol>>. If has_value() == false, it represents an unranked tensor shape. At any dimension ShapeSymbol can contain dynamic size, checkable with ShapeSymbol::IsStatic method.

SymbolicShape now replaces all uses of VaryingShape<ShapeSymbol>, i.e. c10::optional<std::vector<c10::optional<ShapeSymbol>>>. The inner c10::optional wrapper around ShapeSymbol was used to indicate a dynamic shape, which overlaps with part of ShapeSymbol's representation.
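A rough Python sketch of the data structure (class and method names mirror the C++ description above; the Python code itself is hypothetical):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ShapeSymbol:
    # A non-negative value is a known static size; a negative value
    # denotes a symbolic (dynamic) dimension, checkable via is_static().
    value: int
    def is_static(self) -> bool:
        return self.value >= 0

@dataclass
class SymbolicShape:
    # None means an unranked tensor shape; otherwise one ShapeSymbol
    # per dimension, so no second Optional wrapper is needed.
    dims: Optional[List[ShapeSymbol]]

unranked = SymbolicShape(None)
ranked = SymbolicShape([ShapeSymbol(4), ShapeSymbol(-1)])  # [4, dynamic]
assert ranked.dims is not None and not ranked.dims[1].is_static()
```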
Pull Request resolved: #38544

Reviewed By: ZolotukhinM

Differential Revision: D21693984

Pulled By: gmagogsfm

fbshipit-source-id: 6e633e4f36cf570d6fb34ac15d00ec1fb2054a09

* fix build table for ppc64le (#39475)

Summary:
This corrects the build info for ppc64le in the main README.

I am opening this PR before renaming the build job.  (So, the "live" master README has the correct "live" link and the PR does not.)
Immediately after submitting the PR, I will correct the name of the build job.  This will make the new PR link correct, and the current "master" link will briefly appear broken until this PR gets merged.
Pull Request resolved: #39475

Differential Revision: D21883184

Pulled By: malfet

fbshipit-source-id: 148353b632448c98e5aff560d31642328afe7963

* Remove cuda init patch (#39222)

Summary:
The below lines have been removed from `torch/cuda/__init__.py` anyway:
```
        _cudart = _load_cudart()
        _cudart.cudaGetErrorName.restype = ctypes.c_char_p
        _cudart.cudaGetErrorString.restype = ctypes.c_char_p
```
Pull Request resolved: #39222

Differential Revision: D21864397

Pulled By: yns88

fbshipit-source-id: 941b13f92192f930e1dfa4b385e1aec2e321e75f

* Fix lint (#39527)

Summary:
Pull Request resolved: #39527

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D21884798

Pulled By: ezyang

fbshipit-source-id: a130bfd4cc122ea1d45e7db7303bf44e04f08703

* Restore docs coverage test via sphinx (#39331)

Summary:
Pull Request resolved: #39331

Fixes gh-37590

Adds an extra `make coverage` to document building, which uses the built-in facility in sphinx to check docstring coverage. Also fixes a failure to import `torch/jit/supported_ops.py` which broke the [Torchscript Builtins](https://pytorch.org/docs/stable/jit_builtin_functions.html) page.

This also adds the required `SPHINXOPTS` to turn warnings into error, but this is commented out. Note that since documentation of `torchvision` is merged in here, failures there would cause failures here if this is made active. Some thought might be needed about pinning the torchvision version merged into documentation.

The first commit should fail, since the "ScriptModule" class is commented out. I did that in order to check that a CI failure is properly reported.
Pull Request resolved: #38244

Differential Revision: D21640589

Pulled By: ezyang

fbshipit-source-id: 1e240d81669b5f21404d596de4a27d192dc9fd8a

* Add arcosh, arcsinh and arctanh to unary ops (#38388)

Summary:
This PR aims to add `arcosh`, `arcsinh` and `arctanh` support. Please see issue #38349 for more details.

**TODOs:**

* [x] Add test cases for `arcosh`, `arcsinh` and `arctanh`. (need help)
* [x] Overload ops if `std::op` does not work with `thrust::complex` types (like for `sinh`, `cosh`).

Note: `std::acosh, std::asinh, std::atanh` do not support `thrust::complex` types. Added support for complex types for these 3 ops (`arccosh, arcsinh, arctanh`)
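For real inputs, the stdlib versions of these functions satisfy the usual log-form identities, which makes for a quick sanity check of any new implementation (stdlib-only illustration, separate from the complex-type support added in the PR):

```python
import math

x = 2.0
# Inverse hyperbolic functions expressed via logarithms:
assert math.isclose(math.asinh(x), math.log(x + math.sqrt(x * x + 1)))
assert math.isclose(math.acosh(x), math.log(x + math.sqrt(x * x - 1)))
t = 0.5
assert math.isclose(math.atanh(t), 0.5 * math.log((1 + t) / (1 - t)))
```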

cc: mruberry
Pull Request resolved: #38388

Differential Revision: D21882055

Pulled By: mruberry

fbshipit-source-id: d334590b47c5a89e491a002c3e41e6ffa89000e3

* Move some of the definitions in LegacyNNDefinitions.cpp closer to sites (#37531)

Summary:
Pull Request resolved: #37531

All of these definitions are no longer "legacy" as their CPU
implementations have been ported to ATen.  There are probably some
layers of indirection that could be reduced here, but for now just do a
minor but unlikely to break things cleanup.

The last thing in LegacyNNDefinitions truly is still in THCUNN and can't
be removed.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D21310913

Pulled By: ezyang

fbshipit-source-id: 1ff4ff16abddf13f8d583df990242ac4b0461915

* Selective enabling of xnnpack based max_pool2d in ceil_mode. (#39447)

Summary:
max_pool2d with ceil_mode calculates output size a little differently
than what we get with xnnpack max_pool2d. Thus when ceil_mode=True, we
disable this path. However if we get the same output size with ceil_mode
and without ceil_mode, we should use xnnpack based max_pool2d.
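The condition described above can be sketched as follows (a hypothetical helper; the real check lives in the xnnpack integration code):

```python
import math

def pool_out_size(in_size, kernel, stride, pad, ceil_mode):
    # Standard max_pool2d output-size formula; ceil_mode only changes
    # the rounding applied to the division.
    rounding = math.ceil if ceil_mode else math.floor
    return rounding((in_size + 2 * pad - kernel) / stride) + 1

def can_use_xnnpack(in_size, kernel, stride, pad):
    # Routing ceil_mode pooling to xnnpack is only safe when it yields
    # the same output size as floor mode.
    return pool_out_size(in_size, kernel, stride, pad, True) == \
           pool_out_size(in_size, kernel, stride, pad, False)

print(can_use_xnnpack(7, 3, 2, 0))  # True: (7-3)/2 is exact, sizes match
print(can_use_xnnpack(8, 3, 2, 0))  # False: ceil rounds up an extra window
```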
Pull Request resolved: #39447

Test Plan: CI

Differential Revision: D21873334

Pulled By: kimishpatel

fbshipit-source-id: b84abed1505e36e492cc87e7d40664ac63964909

* .circleci: Move binary builds into their own workflow (#39379)

Summary:
Pull Request resolved: #39379

Moves binary builds into their own workflow and adds the ability to
target specification on them. This allows you to run the binary build
workflow on a pull request without the need to modify any configuration
at all.

Some notes about this implementation:
* Upload jobs are still restricted to only the nightly branches and RC tags
* Parameters for circleci are currently defined in
  .circleci/verbatim-sources/header-section.yml
* Target specification configuration is currently located at
  .github/pytorch-circleci-labels.yml

Signed-off-by: Eli Uriegas <eliuriegas@fb.com>

Test Plan: Imported from OSS

Differential Revision: D21886341

Pulled By: seemethere

fbshipit-source-id: 146ef5df2fea208d33e97862d52c170bf001bc98

* [JIT] fix broadcasting lists of ints (#39481)

Summary:
Previously, on conversion from Python -> C++ it was cast to a double list through bad copy-paste. It's pretty unusual for someone to script a broadcasting list function directly since it's an internal API, so it was unlikely to affect anyone.

Fix for #39450
Pull Request resolved: #39481

Reviewed By: jamesr66a

Differential Revision: D21870557

Pulled By: eellison

fbshipit-source-id: e704e5e87d2702a270b7d65c4df444246a134480

* Fix overflow issues when unpacking large numbers (#39140)

Summary:
Resolves #33111.

Relaxes the overflow and precision-loss checks when unpacking doubles.
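The underlying limit is CPython's own: a Python int whose magnitude exceeds the double range cannot be converted at all (illustration only; the PR's actual checks live in the unpickling code):

```python
# A Python int larger than the IEEE-754 double range cannot be
# represented as a C double; converting raises OverflowError.
big = 2 ** 1024
try:
    float(big)
except OverflowError as e:
    print("overflow:", e)

# Just below the limit the conversion succeeds; large ints may only
# lose precision, which is the check being relaxed.
assert float(2 ** 1023) == 2.0 ** 1023
```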

Signed-off-by: Xiong Wei <xiongw.fnst@cn.fujitsu.com>
Pull Request resolved: #39140

Differential Revision: D21885217

Pulled By: ezyang

fbshipit-source-id: e2bbe90d719443ea2e1c6b7b2c637f9a943fa5c0

* Upgrade lint. (#39483)

Summary:
Pull Request resolved: #39483

I fixed all of the new errors that occurred because of the upgrade.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D21884575

Pulled By: ezyang

fbshipit-source-id: 45c8e1f1ecb410c8d7c46dd3922ad70e982a0685

* [ONNX] Fix type casting for reduce ops (#38829)

Summary:
Fix type casting for reduce ops in ONNX exporter. PyTorch promotes dtypes bool and all integer types to long for these ops.

This fix only covers traced modules where dtype is present
Pull Request resolved: #38829

Reviewed By: hl475

Differential Revision: D21833533

Pulled By: houseroad

fbshipit-source-id: 00d9ff692cc0b09d6ca169f6c63913f04b56f182

* deepcopy() of Objects should call __g/setstate__ (#39500)

Summary: Pull Request resolved: #39500

Test Plan: Imported from OSS

Reviewed By: suo

Differential Revision: D21875091

Pulled By: jamesr66a

fbshipit-source-id: 105875dd220a91bc4fcb8fcfb77fab8b626eb6cb

* Add/fix typing annotations to some functions (#39075)

Summary:
Add missing typing imports to some jit tests
Add typing annotations to `torch.testing._compare_scalars_internal` and `torch.testing._internal.assertTrue`
Pull Request resolved: #39075

Differential Revision: D21882468

Pulled By: malfet

fbshipit-source-id: dd9858eb8e11a38411544cc64daf36fced807d76

* Reduce time spent per guard by comparing TensorType with Tensor (#39098)

Summary:
It mainly reduces the time spent allocating a new TensorType object for each Tensor, by comparing them directly instead.
benchmark result before and after this PR: https://gist.github.com/ailzhang/db44d0a1911cae62e0bb794bff33f40a
Pull Request resolved: #39098

Differential Revision: D21786678

Pulled By: ailzhang

fbshipit-source-id: 2f61f0ac1dc8c529c45bef4e149be431ff1608b0

* Do not raise decorator (#39532)

Summary:
s/raise unittest.skip/raise unittest.SkipTest/
As `unittest.skip` is a decorator while `unittest.SkipTest` is an exception
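The difference is easy to demonstrate: `unittest.skip` is a decorator factory, while `unittest.SkipTest` is the exception the test runner recognizes:

```python
import unittest

# SkipTest is an exception type; raising it inside a test body marks
# the test as skipped.
assert issubclass(unittest.SkipTest, Exception)

# unittest.skip is a decorator factory: calling it returns a decorator,
# not an exception, so `raise unittest.skip(...)` is a bug.
decorator = unittest.skip("reason")
assert callable(decorator)
assert not isinstance(decorator, BaseException)
```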
Pull Request resolved: #39532

Differential Revision: D21889152

Pulled By: malfet

fbshipit-source-id: 27a03dbf065a1e2712a63c6c27e156bd13edbbdf

* misc updates to fake fp16 tests (#39405)

Summary:
Misc updates to the fake FP16 tests.
1. seeding numpy with a random seed
2. test base class changed from unittest.TestCase=>serial.SerializedTestCase
3. Removed the hypothesis_test_util import
Reviewer: Hector Yuen
Pull Request resolved: #39405

Test Plan: Fake FP16 test

Differential Revision: D21890212

Pulled By: hyuen

fbshipit-source-id: 25e7e17f118655f32cdd06ea9db3cdac5277e649

* Remove useless copy on zip file load (#36362)

Summary:
Instead of copying to a buffer, then setting a tensor's storage with that buffer, create a storage directly from the file

Pull Request resolved: #36362

Pulled By: driazati

Differential Revision: D21889537

fbshipit-source-id: edbd430073c2bbf52332fe7b3b2590e7d936dedf

* [quant][graphmode] Test for another type of ops in insert_observer for if (#39380)

Summary:
Pull Request resolved: #39380

Test for inserting observers for if statement for ops that propagate quantization parameters

Test Plan: Imported from OSS

Differential Revision: D21832477

fbshipit-source-id: 6e0b4ce4a89f847af161bb22338525802adb8b41

* try always inline

* fix

Co-authored-by: Lu Fang <lufang@fb.com>
Co-authored-by: Shen Li <shenli@fb.com>
Co-authored-by: Peter Bell <peterbell10@live.co.uk>
Co-authored-by: Luca Wehrstedt <lcw@fb.com>
Co-authored-by: Dylan Bespalko <dylanbespalko@me.com>
Co-authored-by: kshitij12345 <kshitijkalambarkar@gmail.com>
Co-authored-by: Vasiliy Kuznetsov <vasiliy@fb.com>
Co-authored-by: Shawn Zhong <github@shawnzhong.com>
Co-authored-by: Eli Uriegas <eliuriegas@fb.com>
Co-authored-by: Jongsoo Park <jongsoo@fb.com>
Co-authored-by: JackCaoG <jackcao@google.com>
Co-authored-by: Zino Benaissa <zinob@fb.com>
Co-authored-by: Christian Puhrsch <cpuhrsch@fb.com>
Co-authored-by: Nick Gibson <nickg@fb.com>
Co-authored-by: Sebastian Messmer <messmer@fb.com>
Co-authored-by: Mike Ruberry <mruberry@devfair044.maas>
Co-authored-by: Nikita Shulga <nikita.shulga@oculus.com>
Co-authored-by: Venkata Chintapalli <venkatc77@gmail.com>
Co-authored-by: Martin Yuan <myuan@fb.com>
Co-authored-by: anjali411 <chourdiaanjali123@gmail.com>
Co-authored-by: Xingying Cheng <xcheng16@fb.com>
Co-authored-by: Edward Yang <ezyang@fb.com>
Co-authored-by: Linbin Yu <linbin@fb.com>
Co-authored-by: Ksenija Stanojevic <ksenija.stanojevic@gmail.com>
Co-authored-by: Hector Yuen <hyz@fb.com>
Co-authored-by: neginraoof <neginmr@utexas.edu>
Co-authored-by: Aayush Naik <aayushnaik17@gmail.com>
Co-authored-by: Oguz Ulgen <oulgen@fb.com>
Co-authored-by: Xiaomeng Yang <yangxm@fb.com>
Co-authored-by: Jiakai Liu <liujiakai@fb.com>
Co-authored-by: Rohan Varma <rvarm1@fb.com>
Co-authored-by: Yanan Cao <gmagogsfm@gmail.com>
Co-authored-by: David Clissold <cliss@us.ibm.com>
Co-authored-by: Jithun Nair <jithun.nair@amd.com>
Co-authored-by: mattip <matti.picus@gmail.com>
Co-authored-by: krshrimali <kushashwaravishrimali@gmail.com>
Co-authored-by: Kimish Patel <kimishpatel@fb.com>
Co-authored-by: Elias Ellison <eellison@fb.com>
Co-authored-by: Xiong Wei <xiongw.fnst@cn.fujitsu.com>
Co-authored-by: James Reed <jamesreed@fb.com>
Co-authored-by: Nikita Shulga <nshulga@fb.com>
Co-authored-by: Ailing Zhang <ailzhang@fb.com>
Co-authored-by: davidriazati <davidriazati@fb.com>
Co-authored-by: Jerry Zhang <jerryzh@fb.com>
ljk53 added a commit that referenced this pull request Jul 2, 2020
This PR continues the work of #38467 - decoupling Autograd and Trace for manually registered ops.

Differential Revision: [D22354804](https://our.internmc.facebook.com/intern/diff/D22354804/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D22354804/)!

[ghstack-poisoned]
ljk53 added a commit that referenced this pull request Jul 2, 2020
This PR continues the work of #38467 - decoupling Autograd and Trace for manually registered ops.

Differential Revision: [D22354804](https://our.internmc.facebook.com/intern/diff/D22354804/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D22354804/)!

[ghstack-poisoned]
ljk53 added a commit that referenced this pull request Jul 2, 2020
Pull Request resolved: #40903

This PR continues the work of #38467 - decoupling Autograd and Trace for manually registered ops.
ghstack-source-id: 107042483

Differential Revision: [D22354804](https://our.internmc.facebook.com/intern/diff/D22354804/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D22354804/)!
ljk53 added a commit that referenced this pull request Jul 2, 2020
This PR continues the work of #38467 - decoupling Autograd and Trace for manually registered ops.

Differential Revision: [D22354804](https://our.internmc.facebook.com/intern/diff/D22354804/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D22354804/)!

[ghstack-poisoned]
ljk53 added a commit that referenced this pull request Jul 2, 2020
This PR continues the work of #38467 - decoupling Autograd and Trace for manually registered ops.

Differential Revision: [D22354804](https://our.internmc.facebook.com/intern/diff/D22354804/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D22354804/)!

[ghstack-poisoned]
ljk53 added a commit that referenced this pull request Jul 2, 2020
Pull Request resolved: #40903

This PR continues the work of #38467 - decoupling Autograd and Trace for manually registered ops.
ghstack-source-id: 107087630

Differential Revision: [D22354804](https://our.internmc.facebook.com/intern/diff/D22354804/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D22354804/)!
ljk53 added a commit that referenced this pull request Jul 3, 2020
This PR continues the work of #38467 - decoupling Autograd and Trace for manually registered ops.

Differential Revision: [D22354804](https://our.internmc.facebook.com/intern/diff/D22354804/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D22354804/)!

[ghstack-poisoned]
ljk53 added a commit that referenced this pull request Jul 3, 2020
Pull Request resolved: #40903

This PR continues the work of #38467 - decoupling Autograd and Trace for manually registered ops.
ghstack-source-id: 107106964

Differential Revision: [D22354804](https://our.internmc.facebook.com/intern/diff/D22354804/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D22354804/)!
@facebook-github-bot facebook-github-bot deleted the gh/ljk53/139/head branch July 4, 2020 14:17
ljk53 added a commit that referenced this pull request Jul 6, 2020
This PR continues the work of #38467 - decoupling Autograd and Trace for manually registered ops.

Differential Revision: [D22354804](https://our.internmc.facebook.com/intern/diff/D22354804/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D22354804/)!

[ghstack-poisoned]
facebook-github-bot pushed a commit that referenced this pull request Jul 6, 2020
Summary:
Pull Request resolved: #40903

This PR continues the work of #38467 - decoupling Autograd and Trace for manually registered ops.
ghstack-source-id: 107158638

Test Plan: CI

Differential Revision: D22354804

fbshipit-source-id: f5ea45ade2850296c62707a2a4449d7d67a9f5b5
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
…38467)

Summary:
This PR moves tracing logic out of the generated VariableType kernels and associates it with a new dedicated dispatch key, Tracer.
It also toggles the dispatch key set in various places to keep the semantics unchanged; see the inline [Tracing Mode Switches] note.

Sample generated code:
```
Tensor & __ilshift___Tensor(Tensor & self, const Tensor & other) {
  #if !defined(PYTORCH_DISABLE_TRACING)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::__ilshift__");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "other", other);
    tracer_state->graph->insertNode(node);

    jit::tracer::setTracingState(nullptr);
  }
  #endif
  static auto op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::__ilshift__", "Tensor");
  c10::Dispatcher::singleton().redispatch<Tensor &, Tensor &, const Tensor &>(op, c10::DispatchKey::Tracer, self, other);
  #if !defined(PYTORCH_DISABLE_TRACING)
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, self);
  }
  #endif
  return self;
}
```

Pull Request resolved: pytorch#38467

ghstack-source-id: 105215150

Test Plan: CI

Differential Revision: D21570684

fbshipit-source-id: 1a96761830307f9a934f38bfb9fe8b5b1763e0e0
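The core pattern in the sample generated code above is: record the traced node, suspend tracing while redispatching to the real kernel (so ops called internally are not recorded twice), then restore the tracing state and record the output. The following is a minimal, self-contained sketch of that save/suspend/restore pattern; `TracingState`, `traced_lshift`, and the free functions here are illustrative stand-ins, not the actual `jit::tracer` API.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Illustrative stand-in for jit::tracer::TracingState.
struct TracingState {
    std::vector<std::string> nodes;  // recorded op names
};

thread_local std::shared_ptr<TracingState> g_tracing_state;

bool isTracing() { return g_tracing_state != nullptr; }
std::shared_ptr<TracingState> getTracingState() { return g_tracing_state; }
void setTracingState(std::shared_ptr<TracingState> s) { g_tracing_state = std::move(s); }

// The underlying kernel: if tracing were still enabled here, any ops it
// called internally would also be recorded, duplicating graph nodes.
int inner_kernel(int x) { return x << 1; }

// Mirrors the shape of the generated wrapper: record the node, suspend
// tracing across the redispatch, then restore state and record the output.
int traced_lshift(int x) {
    std::shared_ptr<TracingState> tracer_state;
    if (isTracing()) {
        tracer_state = getTracingState();
        tracer_state->nodes.push_back("aten::__ilshift__");
        setTracingState(nullptr);  // nested ops will not be traced
    }
    int result = inner_kernel(x);
    if (tracer_state) {
        setTracingState(std::move(tracer_state));  // re-enable tracing
    }
    return result;
}
```

The key detail is that the saved `tracer_state` doubles as the flag for the restore branch, exactly as in the generated `__ilshift___Tensor` wrapper.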
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
Pull Request resolved: pytorch#40903

This PR continues the work of pytorch#38467 - decoupling Autograd and Trace for manually registered ops.
ghstack-source-id: 107158638

Test Plan: CI

Differential Revision: D22354804

fbshipit-source-id: f5ea45ade2850296c62707a2a4449d7d67a9f5b5
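"Decoupling Autograd and Trace" means the tracing behavior of a manually registered op lives in its own kernel under the Tracer key rather than being fused into the autograd kernel. A toy model of that idea, with a simplified dispatch table and redispatch loop (the enum, table, and function names are illustrative, not the c10 dispatcher API):

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Toy dispatch keys, ordered from highest dispatch priority to backend.
enum class Key { Tracer = 0, Autograd = 1, CPU = 2 };

std::map<Key, std::function<int(int)>> table;  // per-key kernels for one op
std::vector<std::string> trace_log;

// Redispatch: run the highest-priority kernel at or below `start`.
int redispatch(Key start, int x) {
    for (int k = static_cast<int>(start); k <= static_cast<int>(Key::CPU); ++k) {
        auto it = table.find(static_cast<Key>(k));
        if (it != table.end()) return it->second(x);
    }
    return x;
}

void register_kernels() {
    // Backend kernel does the actual computation.
    table[Key::CPU] = [](int x) { return x << 1; };
    // Tracing kernel only records the op, then redispatches past itself.
    table[Key::Tracer] = [](int x) {
        trace_log.push_back("aten::__ilshift__");
        return redispatch(Key::Autograd, x);
    };
}
```

Because the Tracer kernel is a separate table entry, it can be registered, skipped, or compiled out (cf. `PYTORCH_DISABLE_TRACING`) without touching the autograd or backend kernels.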