Skip to content

Commit 621a37b

Browse files
Rebase to master on "Replacing assertEqual with assertEqualIgnoreType wherever types mismatch"
Differential Revision: [D21477060](https://our.internmc.facebook.com/intern/diff/D21477060) [ghstack-poisoned]
2 parents 1ee5bec + dfeb600 commit 621a37b

38 files changed

Lines changed: 883 additions & 657 deletions

aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h

Lines changed: 33 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,18 @@ using supported_primitive_arg_types = guts::typelist::typelist<
4242
at::Tensor,
4343
at::Scalar,
4444
c10::QScheme,
45-
c10::ScalarType>;
46-
47-
template <class T, bool AllowDeprecatedTypes, class Enable = void>
48-
struct assert_is_valid_input_type {
49-
assert_is_valid_input_type() {
50-
auto tmap = c10::getCustomClassTypeMap();
51-
TORCH_CHECK(
52-
c10::isCustomClassRegistered<T>(),
53-
"Tried to use undefined class as input argument");
54-
}
55-
};
56-
57-
template<class T, bool AllowDeprecatedTypes>
58-
struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
59-
// everything is ok, this is a primitive type
45+
c10::ScalarType
46+
>;
47+
48+
template<class T, bool AllowDeprecatedTypes, class Enable = void> struct assert_is_valid_input_type {
49+
assert_is_valid_input_type() {
50+
guts::if_constexpr<guts::typelist::contains<supported_primitive_arg_types, T>::value>([] {
51+
/* everything is ok, this is a primitive type */
52+
}, /* else */ [] {
53+
auto tmap = c10::getCustomClassTypeMap();
54+
TORCH_CHECK(c10::isCustomClassRegistered<T>(), "Tried to use undefined class as input argument");
55+
});
56+
}
6057
};
6158

6259
template<class T, bool AllowDeprecatedTypes>
@@ -139,16 +136,15 @@ struct assert_is_valid_input_type {
139136

140137
template<class T, bool AllowDeprecatedTypes, class Enable = void> struct assert_is_valid_output_type {
141138
assert_is_valid_output_type() {
142-
auto tmap = getCustomClassTypeMap();
143-
TORCH_CHECK(c10::isCustomClassRegistered<T>(), "Tried to use undefined class as output");
139+
guts::if_constexpr<guts::typelist::contains<supported_primitive_arg_types, T>::value>([] {
140+
/* everything is ok, this is a primitive type */
141+
}, /* else */ [] {
142+
auto tmap = getCustomClassTypeMap();
143+
TORCH_CHECK(c10::isCustomClassRegistered<T>(), "Tried to use undefined class as output");
144+
});
144145
}
145146
};
146147

147-
template<class T, bool AllowDeprecatedTypes>
148-
struct assert_is_valid_output_type<T, AllowDeprecatedTypes, std::enable_if_t<guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
149-
// everything is ok, this is a primitive type
150-
};
151-
152148
template<class T, bool AllowDeprecatedTypes>
153149
struct assert_is_valid_output_type<c10::optional<T>, AllowDeprecatedTypes>
154150
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {};
@@ -268,33 +264,30 @@ struct assert_is_valid_input_type {
268264
torch::jit::push(*stack, return_to_ivalue<OutputTypes, AllowDeprecatedTypes>(std::move(std::get<indices>(output)))...);
269265
}
270266
};
271-
272-
template<class KernelFunctor, bool AllowDeprecatedTypes, class Enable = void> struct make_boxed_from_unboxed_functor final {};
273-
274-
// SFINAE version for kernels that return an output
275-
template<class KernelFunctor, bool AllowDeprecatedTypes>
276-
struct make_boxed_from_unboxed_functor<KernelFunctor, AllowDeprecatedTypes, std::enable_if_t<!std::is_same<void, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value>> final {
277-
static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
278-
279-
static void call(OperatorKernel* functor, const OperatorHandle&, Stack* stack) {
280-
constexpr size_t num_inputs = guts::infer_function_traits_t<KernelFunctor>::number_of_parameters;
281-
KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
282-
auto output = call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor_, stack);
283-
torch::jit::drop(*stack, num_inputs);
284-
push_outputs<typename guts::infer_function_traits_t<KernelFunctor>::return_type, AllowDeprecatedTypes>::call(std::move(output), stack);
267+
template<bool AllowDeprecatedTypes>
268+
struct push_outputs<void, AllowDeprecatedTypes> final {
269+
static void call(int /*dummy*/, Stack* /*stack*/) {
285270
}
286271
};
287272

288-
// SFINAE version for kernels that don't return an output
289273
template<class KernelFunctor, bool AllowDeprecatedTypes>
290-
struct make_boxed_from_unboxed_functor<KernelFunctor, AllowDeprecatedTypes, std::enable_if_t<std::is_same<void, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value>> final {
274+
struct make_boxed_from_unboxed_functor final {
291275
static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
292276

293277
static void call(OperatorKernel* functor, const OperatorHandle&, Stack* stack) {
294278
constexpr size_t num_inputs = guts::infer_function_traits_t<KernelFunctor>::number_of_parameters;
295279
KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
296-
call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor_, stack);
297-
torch::jit::pop(*stack, num_inputs);
280+
281+
using ReturnType = typename guts::infer_function_traits_t<KernelFunctor>::return_type;
282+
constexpr bool has_outputs = !std::is_same<void, ReturnType>::value;
283+
guts::if_constexpr<has_outputs>([&] (auto _) {
284+
auto output = call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor_, _(stack));
285+
torch::jit::drop(*stack, num_inputs);
286+
push_outputs<ReturnType, AllowDeprecatedTypes>::call(std::move(output), stack);
287+
}, /* else */ [&] {
288+
call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor_, stack);
289+
torch::jit::drop(*stack, num_inputs);
290+
});
298291
}
299292
};
300293

aten/src/ATen/core/ivalue_inl.h

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -283,17 +283,20 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
283283

284284
void setError(FutureError&& error) {
285285
std::unique_lock<std::mutex> lock(mutex_);
286-
AT_ASSERT(!completed());
287-
completed_ = true;
288-
error_ = std::move(error);
289-
290-
std::vector<std::function<void(void)>> cbs;
291-
cbs.swap(callbacks_);
292-
lock.unlock();
286+
setErrorInternal(std::move(error), lock);
287+
}
293288

294-
finished_cv_.notify_all();
295-
for (auto& callback : cbs) {
296-
callback();
289+
void setErrorIfNeeded(std::string errorMsg) {
290+
std::unique_lock<std::mutex> lock(mutex_);
291+
if (completed_) {
292+
// This should be rare and shouldn't cause log spew. It's important to
293+
// log errors and that's why we have this log here.
294+
LOG(INFO) << "Skipping setting following error on the Future since " <<
295+
"it is already marked completed (this is not neccessarily an error): "
296+
<< errorMsg;
297+
return;
298+
} else {
299+
setErrorInternal(FutureError(std::move(errorMsg)), lock);
297300
}
298301
}
299302

@@ -307,6 +310,15 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
307310
return value_;
308311
}
309312

313+
const IValue& constValue() {
314+
std::unique_lock<std::mutex> lock(mutex_);
315+
AT_ASSERT(completed());
316+
if (error_) {
317+
throw *error_;
318+
}
319+
return value_;
320+
}
321+
310322
/**
311323
* Add a callback to the future.
312324
* The callbacks will be executed once the future completes.
@@ -347,6 +359,23 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
347359
}
348360

349361
private:
362+
void setErrorInternal(
363+
FutureError error,
364+
std::unique_lock<std::mutex>& lock) {
365+
AT_ASSERT(!completed());
366+
completed_ = true;
367+
error_ = std::move(error);
368+
369+
std::vector<std::function<void(void)>> cbs;
370+
cbs.swap(callbacks_);
371+
lock.unlock();
372+
373+
finished_cv_.notify_all();
374+
for (auto& callback : cbs) {
375+
callback();
376+
}
377+
}
378+
350379
mutable std::mutex mutex_;
351380
std::atomic_bool completed_ = {false}; // is this future complete
352381
std::condition_variable finished_cv_;

caffe2/python/onnx/tests/onnx_backend_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,16 @@
9696
'|test_.*negative_ind.*' # negative axis is not supported yet
9797
'|test_argmax_.*select_last_index.*' # unsupported case
9898
'|test_argmin_.*select_last_index_.*' # unsupported case
99+
'|test_celu.*' # unsupported case
100+
'|test_gathernd.*' # unsupported case
101+
'|test_greater_equal.*' # unsupported case
102+
'|test_inverse.*' # unsupported case
103+
'|test_less_equal.*' # unsupported case
104+
'|test_max_.*' # unsupported case
105+
'|test_min_.*' # unsupported case
106+
'|test_mean_square_distance_.*' # unsupported case
107+
'|test_softmax_cross_entropy.*' # unsupported case
108+
'|test_unfoldtodepth.*' # unsupported case
99109
'|test_.*gradient.*' # no support for gradient op in c2-onnx
100110
')')
101111

mypy.ini

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@ python_version = 3.6
2929
# Extension modules without stubs.
3030
#
3131

32-
[mypy-torch._C]
33-
ignore_missing_imports = True
34-
3532
[mypy-torch._C._jit_tree_views]
3633
ignore_missing_imports = True
3734

@@ -193,6 +190,9 @@ ignore_errors = True
193190
[mypy-torch.utils.bundled_inputs]
194191
ignore_errors = True
195192

193+
[mypy-torch.utils.mkldnn]
194+
ignore_errors = True
195+
196196
[mypy-torch.utils.tensorboard.*]
197197
ignore_errors = True
198198

@@ -253,6 +253,9 @@ ignore_errors = True
253253
[mypy-torch.utils.hipify.hipify_python]
254254
ignore_errors = True
255255

256+
[mypy-torch.autograd]
257+
ignore_errors = True
258+
256259
[mypy-torch.autograd._functions.tensor]
257260
ignore_errors = True
258261

0 commit comments

Comments
 (0)