Skip to content

Commit 7082381

Browse files
committed
Update on "quantized hardsigmoid, take 2"
Summary: Adds a quantized implementation of hardsigmoid. The original PR was #34607 and had to be reverted for a test breakage; trying again. Test Plan: tests, benchmarks. Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D20514212](https://our.internmc.facebook.com/intern/diff/D20514212) [ghstack-poisoned]
2 parents 55f379b + 5d92a6c commit 7082381

102 files changed

Lines changed: 1722 additions & 910 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ jobs:
105105
run: |
106106
set -eux
107107
pip install flake8
108-
rm -rf .circleci
108+
rm -rf .circleci tools/clang_format_new.py
109109
flake8 --exit-zero > ${GITHUB_WORKSPACE}/flake8-output.txt
110110
cat ${GITHUB_WORKSPACE}/flake8-output.txt
111111
- name: Add annotations

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,3 +252,6 @@ TAGS
252252

253253
# clang-format storage location used by apply_clang_format.py
254254
.clang-format-bin
255+
256+
# clangd background index
257+
.clangd/

.jenkins/pytorch/test.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ test_aten() {
182182
}
183183

184184
test_torchvision() {
185-
pip_install --user git+https://github.com/pytorch/vision.git@44a5bae933655ed7ff798669a43452b833f9ce01
185+
pip_install --user git+https://github.com/pytorch/vision.git@43e94b39bcdda519c093ca11d99dfa2568aa7258
186186
}
187187

188188
test_libtorch() {
@@ -270,9 +270,9 @@ elif [[ "${BUILD_ENVIRONMENT}" == *libtorch* ]]; then
270270
# TODO: run some C++ tests
271271
echo "no-op at the moment"
272272
elif [[ "${BUILD_ENVIRONMENT}" == *-test1 || "${JOB_BASE_NAME}" == *-test1 ]]; then
273-
test_torchvision
274273
test_python_nn
275274
elif [[ "${BUILD_ENVIRONMENT}" == *-test2 || "${JOB_BASE_NAME}" == *-test2 ]]; then
275+
test_torchvision
276276
test_python_all_except_nn
277277
test_aten
278278
test_libtorch

android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ namespace {
2626
struct JITCallGuard {
2727
// AutoGrad is disabled for mobile by default.
2828
torch::autograd::AutoGradMode no_autograd_guard{false};
29+
// VariableType dispatch is not included in default mobile build. We need to set
30+
// this guard globally to avoid dispatch error (only for dynamic dispatch).
31+
// Thanks to the unification of Variable class and Tensor class it's no longer
32+
// required to toggle the NonVariableTypeMode per op - so it doesn't hurt to
33+
// always set NonVariableTypeMode for inference only use case.
34+
torch::AutoNonVariableTypeMode non_var_guard{true};
2935
// Disable graph optimizer to ensure list of unused ops are not changed for
3036
// custom mobile build.
3137
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
@@ -111,11 +117,11 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
111117
/* need_inputs */ false,
112118
/* sampled */ false);
113119
#endif
114-
JITCallGuard guard;
115120
}
116121

117122
PytorchJni(facebook::jni::alias_ref<jstring> modelPath) {
118123
preModuleLoadSetup();
124+
JITCallGuard guard;
119125
module_ = torch::jit::load(std::move(modelPath->toStdString()));
120126
module_.eval();
121127
}
@@ -147,6 +153,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
147153
"Could not get buffer for asset '%s'",
148154
assetName->toStdString().c_str());
149155
}
156+
JITCallGuard guard;
150157
module_ = torch::jit::load(torch::make_unique<MemoryReadAdapter>(
151158
assetBuffer, AAsset_getLength(asset)));
152159
AAsset_close(asset);

android/pytorch_android/src/main/cpp/pytorch_jni_lite.cpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,23 @@
1212

1313
#include "pytorch_jni_common.h"
1414

15-
using namespace pytorch_jni;
16-
1715
namespace pytorch_jni {
1816

17+
namespace {
18+
19+
struct LiteJITCallGuard {
20+
// VariableType dispatch is not included in default mobile build. We need to set
21+
// this guard globally to avoid dispatch error (only for dynamic dispatch).
22+
// Thanks to the unification of Variable class and Tensor class it's no longer
23+
// required to toggle the NonVariableTypeMode per op - so it doesn't hurt to
24+
// always set NonVariableTypeMode for inference only use case.
25+
// TODO: avoid having to set this guard for custom mobile build with mobile
26+
// interpreter.
27+
torch::AutoNonVariableTypeMode non_var_guard{true};
28+
};
29+
30+
} // namespace
31+
1932
class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
2033
private:
2134
friend HybridBase;
@@ -31,6 +44,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
3144
}
3245

3346
PytorchJni(facebook::jni::alias_ref<jstring> modelPath) {
47+
LiteJITCallGuard guard;
3448
module_ = torch::jit::_load_for_mobile(std::move(modelPath->toStdString()));
3549
}
3650

@@ -55,8 +69,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
5569
}
5670

5771
auto output = [&]() {
58-
torch::autograd::AutoGradMode guard(false);
59-
at::AutoNonVariableTypeMode non_var_type_mode(true);
72+
LiteJITCallGuard guard;
6073
return module_.forward(inputs);
6174
}();
6275
return JIValue::newJIValueFromAtIValue(output);
@@ -78,7 +91,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
7891
}
7992
if (auto method = module_.find_method(methodName)) {
8093
auto output = [&]() {
81-
at::AutoNonVariableTypeMode non_var_type_mode(true);
94+
LiteJITCallGuard guard;
8295
return module_.run_method(methodName, inputs);
8396
}();
8497
return JIValue::newJIValueFromAtIValue(output);

aten/src/ATen/core/List.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ class ListIterator final : public std::iterator<std::random_access_iterator_tag,
139139
ListElementReference<T, Iterator> operator*() const {
140140
return {iterator_};
141141
}
142+
143+
ListElementReference<T, Iterator> operator[](typename List<T>::size_type offset) const {
144+
return {iterator_ + offset};
145+
}
142146

143147
private:
144148
explicit ListIterator(Iterator iterator): iterator_(std::move(iterator)) {}

aten/src/ATen/native/BinaryOps.cpp

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,27 @@ Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) {
5959
}
6060

6161
Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
62+
if (isIntegralType(result.scalar_type(), /*includeBool=*/ true)) {
63+
TORCH_WARN_ONCE(
64+
"Integer division of tensors using div or / is deprecated, ",
65+
"and in a future release div will perform true division as in Python 3. ",
66+
"Use true_divide or floor_divide (// in Python) instead.");
67+
}
68+
6269
auto iter = TensorIterator::binary_op(result, self, other,
6370
/*check_mem_overlap=*/true);
6471
div_stub(iter.device_type(), iter);
6572
return result;
6673
}
6774

6875
Tensor div(const Tensor& self, const Tensor& other) {
76+
if (isIntegralType(self.scalar_type(), /*includeBool=*/ true)
77+
&& isIntegralType(other.scalar_type(), /*includeBool=*/ true)) {
78+
TORCH_WARN_ONCE(
79+
"Integer division of tensors using div or / is deprecated, ",
80+
"and in a future release div will perform true division as in Python 3. ",
81+
"Use true_divide or floor_divide (// in Python) instead.");
82+
}
6983
Tensor result;
7084
auto iter = TensorIterator::binary_op(result, self, other);
7185
div_stub(iter.device_type(), iter);
@@ -94,13 +108,6 @@ Tensor& remainder_(Tensor& self, const Tensor& other) {
94108
return native::remainder_out(self, self, other);
95109
}
96110

97-
Tensor truncate(const Tensor& tensor) {
98-
if (tensor.is_floating_point()) {
99-
return tensor.trunc();
100-
}
101-
return tensor;
102-
}
103-
104111
Tensor& true_divide_out(Tensor& result, const Tensor& self, const Tensor& divisor) {
105112
TORCH_CHECK(!isIntegralType(result.scalar_type(), /*includeBool=*/ true),
106113
"True division requires a floating output type, but got ",
@@ -131,14 +138,34 @@ Tensor true_divide(const Tensor& self, const Tensor& divisor) {
131138
return iter.output();
132139
}
133140

134-
Tensor floor_divide(const Tensor& input, const Tensor& other) {
135-
Tensor out = input / other;
136-
return truncate(out);
141+
Tensor& floor_divide_out(Tensor& result, const Tensor& self, const Tensor& other) {
142+
auto iter = TensorIterator::binary_op(result, self, other,
143+
/*check_mem_overlap=*/true);
144+
div_stub(iter.device_type(), iter);
145+
146+
if (result.is_floating_point()) {
147+
result.trunc_();
148+
}
149+
150+
return result;
151+
}
152+
153+
Tensor floor_divide(const Tensor& self, const Tensor& other) {
154+
Tensor result;
155+
auto iter = TensorIterator::binary_op(result, self, other);
156+
157+
div_stub(iter.device_type(), iter);
158+
159+
auto out = iter.output();
160+
if (out.is_floating_point()) {
161+
out.trunc_();
162+
}
163+
164+
return out;
137165
}
138166

139-
Tensor floor_divide(const Tensor& input, Scalar other) {
140-
Tensor out = input / other;
141-
return truncate(out);
167+
Tensor& floor_divide_(Tensor& self, const Tensor& other) {
168+
return native::floor_divide_out(self, self, other);
142169
}
143170

144171
Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) {
@@ -661,6 +688,14 @@ Tensor min(const Tensor& self, const Tensor& other) {
661688

662689
Tensor& min_(Tensor& self, const Tensor& other) { return at::min_out(self, self, other); }
663690

691+
Tensor floor_divide(const Tensor& self, Scalar other) {
692+
return at::floor_divide(self, wrapped_scalar_tensor(other));
693+
}
694+
695+
Tensor& floor_divide_(Tensor& self, Scalar other) {
696+
return at::floor_divide_out(self, self, wrapped_scalar_tensor(other));
697+
}
698+
664699
Tensor& fmod_out(Tensor & result, const Tensor& self, const Tensor& other) {
665700
auto iter = TensorIterator::binary_op(result, self, other,
666701
/*check_mem_overlap=*/true);

aten/src/ATen/native/PointwiseOps.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,17 @@ Tensor& addcdiv_out(
6969
const Tensor& tensor1,
7070
const Tensor& tensor2,
7171
Scalar value) {
72+
if (isIntegralType(tensor1.scalar_type(), /*includeBool=*/ true)
73+
&& isIntegralType(tensor2.scalar_type(), /*includeBool=*/ true)) {
74+
TORCH_WARN_ONCE(
75+
"Integer division with addcdiv is deprecated, and in a future ",
76+
"release addcdiv will perform a true division of tensor1 and tensor2. ",
77+
"The current addcdiv behavior can be replicated using floor_divide ",
78+
"for integral inputs (self + value * tensor1 // tensor2) and ",
79+
"division for float inputs (self + value * tensor1 / tensor2). ",
80+
"The new addcdiv behavior can be implemented with true_divide ",
81+
"(self + value * torch.true_divide(tensor1, tensor2).");
82+
}
7283
checkBackend("addcdiv_cpu", result, self.options().backend());
7384
auto iter = at::TensorIterator();
7485
iter.set_check_mem_overlap(true);

aten/src/ATen/native/TensorFactories.cpp

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ Tensor empty_cpu(IntArrayRef size, const TensorOptions& options_, c10::optional<
134134

135135
auto memory_format = options.memory_format_opt().value_or(MemoryFormat::Contiguous);
136136
tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
137+
137138
return tensor;
138139
}
139140

@@ -342,18 +343,47 @@ Tensor& eye_out_cpu(Tensor& result, int64_t n, int64_t m) {
342343

343344
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ full ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
344345

345-
Tensor full(IntArrayRef size, Scalar fill_value, const TensorOptions& options) {
346-
if (options.layout() == kSparse) {
347-
AT_ERROR("full(...) is not implemented for sparse layout");
346+
namespace {
347+
348+
// Performs dtype inference for full
349+
TensorOptions infer_full_options(
350+
Scalar fill_value,
351+
const TensorOptions& options) {
352+
353+
if (!options.has_dtype()) {
354+
if (fill_value.isIntegral(true)) {
355+
TORCH_WARN_ONCE(
356+
"Deprecation warning: In a future PyTorch release torch.full ",
357+
"will no longer return tensors of floating dtype by default. ",
358+
"Instead, a bool fill_value will return a tensor of torch.bool dtype, ",
359+
"and an integral fill_value will return a tensor of torch.long dtype. ",
360+
"Set the optional `dtype` or `out` arguments to suppress this warning."
361+
);
362+
} else if (fill_value.isComplex()) {
363+
auto scalar_type = (get_default_dtype() == ScalarType::Double) ?
364+
ScalarType::ComplexDouble :
365+
ScalarType::ComplexFloat;
366+
return options.dtype(scalar_type);
367+
}
348368
}
349-
auto result = at::empty(size, options);
369+
370+
return options;
371+
}
372+
373+
} // anonymous namespace
374+
375+
Tensor full(IntArrayRef size, Scalar fill_value, const TensorOptions& options) {
376+
TORCH_CHECK(options.layout() != kSparse,
377+
"full(...) is not implemented for sparse layout");
378+
379+
auto result = at::empty(size, infer_full_options(fill_value, options));
350380
return result.fill_(fill_value);
351381
}
352382

353383
Tensor& full_out(Tensor& result, IntArrayRef size, Scalar fill_value) {
354-
if (result.is_sparse()) {
355-
AT_ERROR("full(...) is not implemented for sparse layout");
356-
}
384+
TORCH_CHECK(!result.is_sparse(),
385+
"full(...) is not implemented for sparse layout");
386+
357387
result.resize_(size);
358388
return result.fill_(fill_value);
359389
}
@@ -404,19 +434,19 @@ Tensor logspace(
404434
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ones ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
405435

406436
Tensor ones(IntArrayRef size, const TensorOptions& options) {
407-
return native::full(size, /*fill_value=*/1, options);
437+
return native::full(size, /*fill_value=*/1., options);
408438
}
409439

410440
Tensor& ones_out(Tensor& result, IntArrayRef size) {
411-
return native::full_out(result, size, /*fill_value=*/1);
441+
return native::full_out(result, size, /*fill_value=*/1.);
412442
}
413443

414444
Tensor ones_like(
415445
const Tensor& self,
416446
const TensorOptions& options,
417447
c10::optional<c10::MemoryFormat> optional_memory_format) {
418448
auto result = at::empty_like(self, options, optional_memory_format);
419-
return result.fill_(1);
449+
return result.fill_(1.);
420450
}
421451

422452
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ scalar_tensor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -746,7 +776,7 @@ Tensor zeros(IntArrayRef size, const TensorOptions& options) {
746776

747777
Tensor& zeros_out(Tensor& result, IntArrayRef size) {
748778
if (result.is_sparse()) {
749-
result.sparse_resize_and_clear_(size, size.size(), 0);
779+
result.sparse_resize_and_clear_(size, size.size(), 0.);
750780
return result;
751781
} else {
752782
result.resize_(size);
@@ -960,22 +990,26 @@ Tensor full(
960990
Scalar fill_value,
961991
optional<DimnameList> names,
962992
const TensorOptions& options) {
963-
auto result = at::empty(size, names, options);
993+
994+
TORCH_CHECK(options.layout() != kSparse,
995+
"full(...) is not implemented for sparse layout");
996+
997+
auto result = at::empty(size, names, infer_full_options(fill_value, options));
964998
return result.fill_(fill_value);
965999
}
9661000

9671001
Tensor ones(
9681002
IntArrayRef size,
9691003
optional<DimnameList> names,
9701004
const TensorOptions& options) {
971-
return native::full(size, /*fill_value=*/1, names, options);
1005+
return native::full(size, /*fill_value=*/1., names, options);
9721006
}
9731007

9741008
Tensor zeros(
9751009
IntArrayRef size,
9761010
optional<DimnameList> names,
9771011
const TensorOptions& options) {
978-
return native::full(size, /*fill_value=*/0, names, options);
1012+
return native::full(size, /*fill_value=*/0., names, options);
9791013
}
9801014

9811015
Tensor randn(

0 commit comments

Comments (0)