
Commit 2bdc99d

Update on "split out functionalization codegen to use view_copy operators"
This PR splits the functionalization codegen into 2 pieces:

1. Vanilla functionalization now always turns view ops into "view_copy" ops.
2. For functorch to "reapply views underneath the pass", I added a new dispatch key, "FunctionalizeAddBackViews". For every view_copy operator, I codegen a kernel to that key that just calls back into the view op. All other ops get a fallthrough kernel.

Also, the codegen now unconditionally registers CompositeImplicitAutograd kernels directly to the functionalization keys, so we "always decompose" before hitting functionalization. Otherwise, if we decompose an op into a view underneath the functionalization pass, we might break things by accidentally sending "view" calls to the backend.

The important changes are in `gen.py` and `gen_functionalization_type.py`; most of the other changes are just plumbing `{view}_copy` everywhere. I also updated `test_functionalization.py` and added expecttests for the "add back views" case.

One thing about the `AddBackViews` key: right now, I add it into the TLS include set. The other option would be to try to add it directly to the tensors, but that's kind of hard: putting it on the `FunctionalTensorWrapper` doesn't help, because the functionalization pass will unwrap when it calls back into the dispatcher and run on the "inner tensor" (maybe we could modify the inner tensor's keyset and add the `AddBackViews` key when functionalization happens, instead?).

I also have an accompanying functorch change here: pytorch/functorch#678

Differential Revision: [D35419652](https://our.internmc.facebook.com/intern/diff/D35419652)

[ghstack-poisoned]
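To make piece (2) concrete, the generated per-operator kernels amount to something like the following hand-written sketch (illustrative only; the real kernels come out of `gen_functionalization_type.py`, and `expand_copy` here stands in for any `{view}_copy` operator):

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

// At the FunctionalizeAddBackViews key, each {view}_copy operator
// redispatches to the original (aliasing) view op, "adding back" the view.
at::Tensor expand_copy_add_back_view(
    const at::Tensor& self, at::IntArrayRef size, bool implicit) {
  return self.expand(size, implicit);
}

TORCH_LIBRARY_IMPL(aten, FunctionalizeAddBackViews, m) {
  m.impl("expand_copy", TORCH_FN(expand_copy_add_back_view));
}

// Every other operator gets a fallthrough at this key, so the key is a
// no-op for anything that isn't a {view}_copy op.
TORCH_LIBRARY_IMPL(_, FunctionalizeAddBackViews, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}
```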
2 parents 999223d + 3f820ee commit 2bdc99d

39 files changed

Lines changed: 1280 additions & 278 deletions

.github/actions/setup-ssh/action.yml

Lines changed: 1 addition & 0 deletions
@@ -14,3 +14,4 @@ runs:
       uses: seemethere/add-github-ssh-key@v1
       with:
         GITHUB_TOKEN: ${{ inputs.github-secret }}
+        activate-with-label: false

.lintrunner.toml

Lines changed: 10 additions & 0 deletions
@@ -1,3 +1,6 @@
+[init_config]
+last_hash = "7d8b366223b22aaab788b0a7efec064dfef38b24"
+
 [[linter]]
 code = 'FLAKE8'
 include_patterns = ['**/*.py']
@@ -269,7 +272,10 @@ exclude_patterns=[
     'third_party/**',
     '**/*.expect',
     '**/*.ipynb',
+    '**/*.ptl',
     'tools/clang_format_hash/**',
+    'test/cpp/jit/upgrader_models/*.ptl',
+    'test/cpp/jit/upgrader_models/*.ptl.ff',
 ]
 command = [
     'python3',
@@ -285,6 +291,8 @@ exclude_patterns = [
     '**/contrib/**',
     '**/*.diff',
     'third_party/**',
+    'test/cpp/jit/upgrader_models/*.ptl',
+    'test/cpp/jit/upgrader_models/*.ptl.ff',
 ]
 command = [
     'python3',
@@ -310,6 +318,8 @@ exclude_patterns = [
     'third_party/**',
     '**/.gitattributes',
     '**/.gitmodules',
+    'test/cpp/jit/upgrader_models/*.ptl',
+    'test/cpp/jit/upgrader_models/*.ptl.ff',
 ]
 command = [
     'python3',

aten/src/ATen/native/Activation.cpp

Lines changed: 82 additions & 214 deletions
@@ -202,6 +202,8 @@ DEFINE_DISPATCH(silu_stub);
 DEFINE_DISPATCH(silu_backward_stub);
 DEFINE_DISPATCH(mish_stub);
 DEFINE_DISPATCH(mish_backward_stub);
+DEFINE_DISPATCH(prelu_cpu_stub);
+DEFINE_DISPATCH(prelu_backward_cpu_stub);
 
 TORCH_IMPL_FUNC(elu_out) (
   const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, const Tensor& result
@@ -595,253 +597,119 @@ TORCH_IMPL_FUNC(threshold_backward_out)(const Tensor& grad, const Tensor& self,
   threshold_stub(device_type(), *this, threshold, 0);
 }
 
-// -----------------------------------
-// prelu forward
-// -----------------------------------
-template <typename scalar_t>
-void inline prelu_cpu_kernel_share_weights(
-  Tensor& result,
-  const Tensor& input,
-  const Tensor& weight) {
-
-  int64_t input_numel = input.numel();
-  auto result_data = result.data_ptr<scalar_t>();
-  auto input_data = input.data_ptr<scalar_t>();
-  auto weight_val = weight.data_ptr<scalar_t>()[0];
-
-  at::parallel_for(0, input_numel, 1000, [&](int64_t start, int64_t end) {
-    for (const auto i : c10::irange(start, end)) {
-      scalar_t input_data_val = input_data[i];
-      // to allow for compiler optimization, here splitting into two lines:
-      scalar_t r = (input_data_val > 0) ? scalar_t(1) : weight_val;
-      result_data[i] = r * input_data_val;
-    }
-  });
-}
-
-template <typename scalar_t>
-void inline prelu_cpu_kernel_multi_weights(
-  Tensor& result,
-  const Tensor& input,
-  const Tensor& weight,
-  int64_t input_dim0_size,
-  int64_t channel_size,
-  int64_t input_stride0,
-  int64_t input_stride1) {
-
-  scalar_t* result_data = result.data_ptr<scalar_t>();
-  scalar_t* input_data = input.data_ptr<scalar_t>();
-  scalar_t* weight_data = weight.data_ptr<scalar_t>();
-
-  auto loop = [&](int64_t start, int64_t end) {
-    for (const auto i : c10::irange(start, end)) {
-      int64_t offset = i * channel_size * input_stride1;
-      scalar_t* n_input_data = input_data + offset;
-      scalar_t* n_result_data = result_data + offset;
-      for (const auto j : c10::irange(channel_size)) {
-        for (const auto k : c10::irange(input_stride1)) {
-          // to allow for compiler optimization, here splitting into two lines:
-          scalar_t w = (n_input_data[k] > 0) ? scalar_t(1) : weight_data[j];
-          n_result_data[k] = w * n_input_data[k];
-        }
-        n_input_data += input_stride1;
-        n_result_data += input_stride1;
-      }
-    }
-  };
-  if (input.numel() > 1000) {
-    at::parallel_for(0, input_dim0_size, 0, loop);
-  } else {
-    loop(0, input_dim0_size);
-  }
-}
-
 Tensor prelu_cpu(const Tensor& self, const Tensor& weight_) {
-  auto input = self.contiguous();
-  auto weight = weight_.contiguous();
-
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(weight.is_contiguous());
-
-  int64_t weight_num = weight.numel();
-  Tensor result = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  auto strides = input.strides();
-
-  // case1: shared weight for all channels
-  if (weight_num == 1) {
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_cpu", [&] {
-      prelu_cpu_kernel_share_weights<scalar_t>(result, input, weight);
-    });
-  }
-  else { // case2: multiple weights, one for each channel
-    int64_t input_ndim = input.dim();
+  int64_t weight_num = weight_.numel();
+  Tensor result = at::empty_like(self, self.suggest_memory_format());
+
+  if (weight_num != 1) {
+    int64_t input_ndim = self.dim();
     TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor.");
 
     int64_t channel_size = 1; // channel_size default to 1
-    int64_t input_dim0_size = 1, input_stride0 = 1, input_stride1 = 1;
-
     if (input_ndim > 1) {
-      channel_size = input.size(1); // channel is the 2nd dim of input
-      input_dim0_size = input.size(0);
-      input_stride0 = strides[0];
-      input_stride1 = strides[1];
+      channel_size = self.size(1); // channel is the 2nd dim of input
     }
     TORCH_CHECK(channel_size == weight_num,
       "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
       " and channel size = ", channel_size, ".");
-
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_cpu", [&] {
-      prelu_cpu_kernel_multi_weights<scalar_t>(
-        result,
-        input,
-        weight,
-        input_dim0_size,
-        channel_size,
-        input_stride0,
-        input_stride1);
-    });
   }
+
+  const int64_t ndim = self.dim();
+  // Helper to convert 1d tensors or scalar tensor to an nd tensor that broadcasts with input
+  // All elements go into the channel dimension
+  DimVector sizes(ndim, 1), strides(ndim, 0);
+  auto as_nd = [&](const Tensor& t) {
+    TORCH_INTERNAL_ASSERT(t.defined() && (t.dim() == 1 || t.dim() == 0));
+    if (ndim >= 2) {
+      sizes[1] = t.dim() == 1 ? t.sizes()[0] : 1;
+      strides[1] = t.dim() == 1 ? t.strides()[0] : 0;
+      return t.as_strided(sizes, strides);
+    }
+    return t.as_strided(sizes, strides);
+  };
+  Tensor w;
+  if (self.scalar_type() == ScalarType::BFloat16) {
+    auto w_bf16 = at::empty(weight_.sizes(), weight_.options().dtype(ScalarType::BFloat16));
+    w_bf16.copy_(weight_);
+    w = weight_.defined() ? as_nd(w_bf16) :
+        at::detail::scalar_tensor_static(1, self.scalar_type(), kCPU);
+  } else {
+    w = weight_.defined() ? as_nd(weight_) :
+        at::detail::scalar_tensor_static(1, self.scalar_type(), kCPU);
+  }
+
+  auto iter = TensorIteratorConfig()
+      .add_output(result)
+      .add_input(self)
+      .add_input(w)
+      .build();
+  prelu_cpu_stub(iter.device_type(), iter);
   return result;
 }
 
-// -----------------------------------
-// prelu backward
-// -----------------------------------
-template <typename scalar_t>
-void inline prelu_cpu_backward_kernel_share_weights(
-  const Tensor& input,
-  const Tensor& weight,
-  const Tensor& grad_out,
-  Tensor& input_grad,
-  Tensor& weight_grad) {
-
-  int64_t input_numel = input.numel();
-  auto input_data = input.data_ptr<scalar_t>();
-  auto weight_val = weight.data_ptr<scalar_t>()[0];
-  auto grad_out_data = grad_out.data_ptr<scalar_t>();
-  auto input_grad_data = input_grad.data_ptr<scalar_t>();
-  auto weight_grad_data = weight_grad.data_ptr<scalar_t>();
-
-  scalar_t sum = at::parallel_reduce(0, input_numel, 1000, scalar_t(0),
-    [&](int64_t start, int64_t end, scalar_t ident) -> scalar_t {
-      scalar_t partial_sum = ident;
-      for (const auto i : c10::irange(start, end)) {
-        scalar_t input_data_val = input_data[i];
-        scalar_t grad_out_data_val = grad_out_data[i];
-        // to allow for compiler optimization, here splitting into two lines:
-        scalar_t w = (input_data_val > 0) ? scalar_t(1) : weight_val;
-        input_grad_data[i] = w * grad_out_data_val;
-        // to allow for compiler optimization, here splitting into two lines:
-        scalar_t mask = (input_data_val > 0) ? scalar_t(0) : scalar_t(1);
-        partial_sum += mask * input_data_val * grad_out_data_val;
-      }
-      return partial_sum;
-    }, std::plus<scalar_t>());
-  weight_grad_data[0] = sum;
-}
-
-template <typename scalar_t>
-void inline prelu_cpu_backward_kernel_multi_weights(
-  const Tensor& input,
-  const Tensor& weight,
-  const Tensor& grad_out,
-  Tensor& input_grad,
-  Tensor& weight_grad_collector,
-  int64_t input_dim0_size,
-  int64_t channel_size,
-  int64_t input_stride0,
-  int64_t input_stride1) {
-
-  auto input_data = input.data_ptr<scalar_t>();
-  auto weight_data = weight.data_ptr<scalar_t>();
-  auto grad_out_data = grad_out.data_ptr<scalar_t>();
-  auto input_grad_data = input_grad.data_ptr<scalar_t>();
-  auto weight_grad_collector_data = weight_grad_collector.data_ptr<scalar_t>();
-
-  auto loop = [&](int64_t start, int64_t end) {
-    for (const auto i : c10::irange(start, end)) {
-      for (const auto j : c10::irange(channel_size)) {
-        for (const auto k : c10::irange(input_stride1)) {
-          int64_t pos = i * input_stride0 + j * input_stride1 + k;
-          scalar_t weight_data_val = weight_data[j];
-          scalar_t input_data_val = input_data[pos];
-          scalar_t grad_out_data_val = grad_out_data[pos];
-          // to allow for compiler optimization, here splitting into two lines:
-          scalar_t w = (input_data_val > 0) ? scalar_t(1) : weight_data_val;
-          input_grad_data[pos] = w * grad_out_data_val;
-          // to allow for compiler optimization, here splitting into two lines:
-          scalar_t mask = (input_data_val > 0) ? scalar_t(0) : scalar_t(1);
-          weight_grad_collector_data[pos] = mask * input_data_val * grad_out_data_val;
-        }
-      }
-    }
-  };
-  if (input.numel() > 1000) {
-    at::parallel_for(0, input_dim0_size, 0, loop);
-  } else {
-    loop(0, input_dim0_size);
-  }
-}
-
 std::tuple<Tensor, Tensor> prelu_backward_cpu(const Tensor& grad_out_, const Tensor& self, const Tensor& weight_) {
-  auto input = self.contiguous();
-  auto grad_out = grad_out_.contiguous();
-  auto weight = weight_.contiguous();
-
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(grad_out.is_contiguous());
-  TORCH_CHECK(weight.is_contiguous());
-
-  int64_t weight_num = weight.numel();
-  auto strides = input.strides();
-  auto dims = input.dim();
-
-  Tensor input_grad = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  Tensor weight_grad = at::empty_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  Tensor weight_grad_collector = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-
-  // case1: shared parameter for all channels
-  if (weight_num == 1) {
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_backward_cpu", [&] {
-      prelu_cpu_backward_kernel_share_weights<scalar_t>(input, weight, grad_out, input_grad, weight_grad);
-    });
-  }
-  else { // case2: multiple parameters, one for each channel
-    int64_t input_ndim = input.dim();
+  int64_t weight_num = weight_.numel();
+
+  Tensor input_grad = at::empty_like(self, self.suggest_memory_format());
+  Tensor weight_grad = at::empty_like(weight_, at::MemoryFormat::Contiguous);
+  Tensor weight_grad_collector = at::empty_like(self, at::MemoryFormat::Contiguous);
+
+  if (weight_num != 1) {
+    int64_t input_ndim = self.dim();
     TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor.");
 
     int64_t channel_size = 1; // channel_size default to 1
-    int64_t input_dim0_size = 1, input_stride0 = 1, input_stride1 = 1;
-
     if (input_ndim > 1) {
-      channel_size = input.size(1); // channel is the 2nd dim of input
-      input_dim0_size = input.size(0);
-      input_stride0 = strides[0];
-      input_stride1 = strides[1];
+      channel_size = self.size(1); // channel is the 2nd dim of input
     }
     TORCH_CHECK(channel_size == weight_num,
      "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
      " and channel size = ", channel_size, ".");
+  }
 
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_backward_cpu", [&] {
-      prelu_cpu_backward_kernel_multi_weights<scalar_t>(
-        input,
-        weight,
-        grad_out,
-        input_grad,
-        weight_grad_collector,
-        input_dim0_size,
-        channel_size,
-        input_stride0,
-        input_stride1);
-    });
+  const int64_t ndim = self.dim();
+  // Helper to convert 1d tensor or scalar tensor to an nd tensor that broadcasts with input
+  // All elements go into the channel dimension
+  DimVector sizes(ndim, 1), strides(ndim, 0);
+  auto as_nd = [&](const Tensor& t) {
+    TORCH_INTERNAL_ASSERT(t.defined() && (t.dim() == 1 || t.dim() == 0));
+    if (ndim >= 2) {
+      sizes[1] = t.dim() == 1 ? t.sizes()[0] : 1;
+      strides[1] = t.dim() == 1 ? t.strides()[0] : 0;
+      return t.as_strided(sizes, strides);
+    }
+    return t.as_strided(sizes, strides);
+  };
+  Tensor w;
+  if (self.scalar_type() == ScalarType::BFloat16) {
+    auto w_bf16 = at::empty(weight_.sizes(), weight_.options().dtype(ScalarType::BFloat16));
+    w_bf16.copy_(weight_);
+    w = weight_.defined() ? as_nd(w_bf16) :
+        at::detail::scalar_tensor_static(1, self.scalar_type(), kCPU);
+  } else {
+    w = weight_.defined() ? as_nd(weight_) :
+        at::detail::scalar_tensor_static(1, self.scalar_type(), kCPU);
+  }
+
+  auto iter = TensorIteratorConfig()
+      .add_output(input_grad)
+      .add_output(weight_grad_collector)
+      .add_input(self)
+      .add_input(grad_out_)
+      .add_input(w)
+      .build();
+
+  prelu_backward_cpu_stub(iter.device_type(), iter);
+
+  if (weight_num == 1) {
+    weight_grad.fill_(weight_grad_collector.sum());
+  } else {
     // update weight_grad
     std::vector<int64_t> reduce_dims;
+    int64_t input_ndim = self.dim();
    reduce_dims.push_back(0);
-    if (dims > 2) {
-      for (const auto i : c10::irange(2, dims)) {
-        reduce_dims.push_back(i);
-      }
+    if (input_ndim > 2) {
+      for(int64_t i = 2; i < input_ndim; i++) reduce_dims.push_back(i);
    }
     weight_grad = weight_grad_collector.sum(reduce_dims);
  }
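A note on the `as_nd` helper introduced above: it turns the 1-d per-channel weight into a zero-stride view that broadcasts against the input along dim 1, which is what lets a single TensorIterator loop replace both the shared-weight and per-channel hand-written kernels. A minimal standalone sketch of the same trick (the function name here is ours, not from the diff):

```cpp
#include <ATen/ATen.h>
#include <vector>

// Broadcast a 1-d per-channel tensor against an n-d input along dim 1,
// mirroring what the as_nd lambda in the diff above does.
at::Tensor broadcast_along_channels(const at::Tensor& t, int64_t ndim) {
  std::vector<int64_t> sizes(ndim, 1), strides(ndim, 0);
  if (ndim >= 2 && t.dim() == 1) {
    sizes[1] = t.sizes()[0];     // channel count lives in dim 1
    strides[1] = t.strides()[0]; // step through one weight per channel
  }
  // Every other dim has size 1 and stride 0, so the same element is
  // reused; this is broadcasting expressed directly via as_strided.
  return t.as_strided(sizes, strides);
}

// For input [N, C, H, W] and weight [C]:
//   broadcast_along_channels(weight, 4) has sizes [1, C, 1, 1] and
//   strides [0, 1, 0, 0], so TensorIterator can run one fused elementwise
//   loop over input and weight without materializing any copies.
```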

aten/src/ATen/native/Activation.h

Lines changed: 2 additions & 0 deletions
@@ -80,6 +80,8 @@ DECLARE_DISPATCH(structured_activation_fn, silu_stub);
 DECLARE_DISPATCH(structured_activation_backward_fn, silu_backward_stub);
 DECLARE_DISPATCH(structured_activation_fn, mish_stub);
 DECLARE_DISPATCH(activation_backward_fn, mish_backward_stub);
+DECLARE_DISPATCH(activation_fn, prelu_cpu_stub);
+DECLARE_DISPATCH(activation_backward_fn, prelu_backward_cpu_stub);
 
 } // namespace native
 
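The two `DECLARE_DISPATCH` lines above only declare the stubs; a CPU kernel file has to register implementations for them. A rough sketch of the usual DispatchStub wiring, under the assumption that the registration lives in a file like `aten/src/ATen/native/cpu/Activation.cpp` (not shown in this excerpt; the kernel body below is ours, not the PR's):

```cpp
#include <ATen/Dispatch.h>
#include <ATen/native/Activation.h>
#include <ATen/native/cpu/Loops.h>

namespace at { namespace native {

static void prelu_cpu_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "prelu_cpu", [&]() {
    // iter is built with one output and two inputs: self and the
    // broadcast weight (see prelu_cpu in Activation.cpp above).
    cpu_kernel(iter, [](scalar_t input, scalar_t weight) -> scalar_t {
      return input > scalar_t(0) ? input : weight * input;
    });
  });
}

REGISTER_DISPATCH(prelu_cpu_stub, &prelu_cpu_kernel);

}} // namespace at::native
```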