Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 69 additions & 29 deletions torch_xla/csrc/aten_autograd_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,6 @@ namespace aten_autograd_ops {
torch::Tensor EinsumAutogradFunction::forward(
torch::autograd::AutogradContext* ctx, const c10::string_view equation,
at::TensorList tensors) {
// We might get here after the autograd graph was created, but before
// python gets a chance to run. Give python a chance to run.
if (tensors[0].key_set().has(c10::DispatchKey::Python)) {
return at::redispatch::einsum(c10::DispatchKeySet(c10::DispatchKey::Python),
equation, tensors);
}
std::string eq_str = std::string(equation);
ctx->saved_data["equation"] = eq_str;

Expand Down Expand Up @@ -74,13 +68,6 @@ torch::Tensor MaxPool2dAutogradFunction::forward(
torch::autograd::AutogradContext* ctx, torch::Tensor self,
torch::IntArrayRef kernel_size, torch::IntArrayRef stride,
torch::IntArrayRef padding, torch::IntArrayRef dilation, bool ceil_mode) {
// We might get here after the autograd graph was created, but before
// python gets a chance to run. Give python a chance to run.
if (self.key_set().has(c10::DispatchKey::Python)) {
return at::redispatch::max_pool2d(
c10::DispatchKeySet(c10::DispatchKey::Python), self, kernel_size,
stride, padding, dilation, ceil_mode);
}
ctx->saved_data["kernel_size"] = kernel_size;
ctx->saved_data["stride"] = stride;
ctx->saved_data["padding"] = padding;
Expand All @@ -99,11 +86,25 @@ torch::Tensor MaxPool2dAutogradFunction::forward(
return std::get<0>(results);
}
ctx->save_for_backward({self});
auto outputs = tensor_methods::max_pool_nd(
bridge::GetXlaTensor(self), /*spatial_dim_count=*/2,
XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
XlaHelpers::I64List(padding), ceil_mode);
return bridge::AtenFromXlaTensor(std::get<0>(outputs));
auto self_keyset = self.key_set();
// This is a bit fragile: Ideally, we would figure out a way to plumb
// the DispatchKeySet from the autograd kernel directly here.
// Instead, I enumerated the list of dispatch keys below autograd
// that XLA could reasonably run into,
// and mask them with the current tensor's keyset.
auto mask = c10::DispatchKeySet({
c10::DispatchKey::XLA,
c10::DispatchKey::Python,
c10::DispatchKey::Functionalize,
});
auto ks = self_keyset & mask;
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("xla::max_pool2d_forward", "")
.typed<at::Tensor(at::Tensor, at::IntArrayRef, at::IntArrayRef,
at::IntArrayRef, at::IntArrayRef, bool)>();
return op.redispatch(ks, self, kernel_size, stride, padding, dilation,
ceil_mode);
}

torch::autograd::variable_list MaxPool2dAutogradFunction::backward(
Expand All @@ -127,10 +128,21 @@ torch::autograd::variable_list MaxPool2dAutogradFunction::backward(
padding, dilation,
ceil_mode, indices);
}
grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
bridge::GetXlaTensor(grad_output[0]), bridge::GetXlaTensor(self),
/*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size),
XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode));

static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("xla::max_pool2d_backward", "")
.typed<at::Tensor(at::Tensor, at::Tensor, at::IntArrayRef,
at::IntArrayRef, at::IntArrayRef, bool)>();
auto self_keyset = self.key_set();
auto mask = c10::DispatchKeySet({
c10::DispatchKey::XLA,
c10::DispatchKey::Python,
c10::DispatchKey::Functionalize,
});
auto ks = self_keyset & mask;
grad = op.redispatch(ks, grad_output[0], self, kernel_size, stride, padding,
ceil_mode);

torch::Tensor undef;
torch::autograd::variable_list grad_inputs = {grad, undef, undef,
Expand All @@ -142,13 +154,6 @@ torch::Tensor MaxPool3dAutogradFunction::forward(
torch::autograd::AutogradContext* ctx, torch::Tensor self,
torch::IntArrayRef kernel_size, torch::IntArrayRef stride,
torch::IntArrayRef padding, torch::IntArrayRef dilation, bool ceil_mode) {
// We might get here after the autograd graph was created, but before
// python gets a chance to run. Give python a chance to run.
if (self.key_set().has(c10::DispatchKey::Python)) {
return at::redispatch::max_pool3d(
c10::DispatchKeySet(c10::DispatchKey::Python), self, kernel_size,
stride, padding, dilation, ceil_mode);
}
ctx->saved_data["kernel_size"] = kernel_size;
ctx->saved_data["stride"] = stride;
ctx->saved_data["padding"] = padding;
Expand Down Expand Up @@ -206,5 +211,40 @@ torch::autograd::variable_list MaxPool3dAutogradFunction::backward(
return grad_inputs;
}

// Dispatcher-visible implementation of the forward pass for the custom
// "xla::max_pool2d_forward" operator. Lowers the pooling through
// tensor_methods::max_pool_nd with a spatial rank of 2 and converts the
// resulting XLA tensor back into an ATen tensor. max_pool_nd returns a
// (values, indices) tuple; only the pooled values are returned here.
torch::Tensor max_pool2d_forward(torch::Tensor self,
                                 torch::IntArrayRef kernel_size,
                                 torch::IntArrayRef stride,
                                 torch::IntArrayRef padding,
                                 torch::IntArrayRef dilation, bool ceil_mode) {
  // NOTE(review): `dilation` is accepted to mirror the aten signature but is
  // not forwarded to max_pool_nd — presumably the XLA lowering only handles
  // dilation == 1; confirm against the lowering before relying on it.
  const auto pooled = tensor_methods::max_pool_nd(
      bridge::GetXlaTensor(self), /*spatial_dim_count=*/2,
      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
      XlaHelpers::I64List(padding), ceil_mode);
  return bridge::AtenFromXlaTensor(std::get<0>(pooled));
}

// Dispatcher-visible implementation of the backward pass for the custom
// "xla::max_pool2d_backward" operator. Computes the gradient w.r.t. `self`
// via tensor_methods::max_pool_nd_backward (spatial rank 2) and converts it
// back into an ATen tensor.
// NOTE(review): unlike the forward op, this schema takes no `dilation`
// argument — presumably because the lowering does not use it; verify against
// the forward registration if the schema ever grows.
torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
                                  torch::IntArrayRef kernel_size,
                                  torch::IntArrayRef stride,
                                  torch::IntArrayRef padding, bool ceil_mode) {
  return bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
      bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
      /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size),
      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode));
}

// Registers the custom forward/backward pooling operators in the "xla"
// namespace so they are visible to the c10 dispatcher. The autograd
// functions above redispatch to these ops ("xla::max_pool2d_forward" /
// "xla::max_pool2d_backward") instead of calling the lowerings directly,
// which lets other dispatch keys (e.g. Python, Functionalize) intercept
// the calls. Both kernels are bound to the XLA dispatch key.
TORCH_LIBRARY(xla, m) {
  m.def(
      "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
      "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
      torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_forward)));

  // NOTE(review): the backward schema deliberately has no `dilation`
  // argument — keep the redispatch call sites in sync with this signature.
  m.def(
      "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
      "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
      "-> Tensor",
      torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(max_pool2d_backward)));
}
} // namespace aten_autograd_ops
} // namespace torch_xla