
Commit 491359f

Update on "Update backward formulas (Re #44444)"
Re #44444
Fixes #46144
Differential Revision: [D24285785](https://our.internmc.facebook.com/intern/diff/D24285785)
[ghstack-poisoned]
2 parents 0377256 + edbc84a commit 491359f

64 files changed

Lines changed: 1138 additions & 870 deletions


.circleci/config.yml

Lines changed: 1 addition & 0 deletions
@@ -645,6 +645,7 @@ jobs:
       export CIRCLE_PR_NUMBER="${CIRCLE_PR_NUMBER:-}"
       export CIRCLE_BRANCH="$CIRCLE_BRANCH"
       export CIRCLE_JOB="$CIRCLE_JOB"
+      export CIRCLE_WORKFLOW_ID="$CIRCLE_WORKFLOW_ID"
       cd workspace
       python test/print_test_stats.py test
     EOL

.circleci/scripts/binary_populate_env.sh

Lines changed: 1 addition & 0 deletions
@@ -167,6 +167,7 @@ export CIRCLE_TAG="${CIRCLE_TAG:-}"
 export CIRCLE_SHA1="$CIRCLE_SHA1"
 export CIRCLE_PR_NUMBER="${CIRCLE_PR_NUMBER:-}"
 export CIRCLE_BRANCH="$CIRCLE_BRANCH"
+export CIRCLE_WORKFLOW_ID="$CIRCLE_WORKFLOW_ID"
 # =================== The above code will be executed inside Docker container ===================
 EOL

.circleci/scripts/upload_binary_size_to_scuba.py

Lines changed: 2 additions & 0 deletions
@@ -41,6 +41,7 @@ def build_message(size):
         "build_num": os.environ.get("CIRCLE_BUILD_NUM"),
         "sha1": os.environ.get("CIRCLE_SHA1"),
         "branch": os.environ.get("CIRCLE_BRANCH"),
+        "workflow_id": os.environ.get("CIRCLE_WORKFLOW_ID"),
     },
     "int": {
         "time": int(time.time()),
@@ -115,6 +116,7 @@ def gen_messages():
         "build_num": os.environ.get("CIRCLE_BUILD_NUM"),
         "sha1": os.environ.get("CIRCLE_SHA1"),
         "branch": os.environ.get("CIRCLE_BRANCH"),
+        "workflow_id": os.environ.get("CIRCLE_WORKFLOW_ID"),
     },
     "int": {
         "time": int(time.time()),

.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml

Lines changed: 1 addition & 0 deletions
@@ -207,6 +207,7 @@ jobs:
       export CIRCLE_PR_NUMBER="${CIRCLE_PR_NUMBER:-}"
       export CIRCLE_BRANCH="$CIRCLE_BRANCH"
       export CIRCLE_JOB="$CIRCLE_JOB"
+      export CIRCLE_WORKFLOW_ID="$CIRCLE_WORKFLOW_ID"
       cd workspace
       python test/print_test_stats.py test
     EOL

aten/src/ATen/BatchedFallback.cpp

Lines changed: 32 additions & 1 deletion
@@ -195,6 +195,32 @@ void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::j
   torch::jit::push(stack, self);
 }

+static Tensor safeStack(TensorList tensors) {
+  auto is_defined = [](const Tensor& t) { return t.defined(); };
+  if (std::all_of(tensors.begin(), tensors.end(), is_defined)) {
+    return at::stack(tensors);
+  }
+  // NOTE [vmap through backward and undefined grad]
+  // While vmapping through backward functions (to compute batched grad), it
+  // is possible for the backward function to return an undefined grad for some
+  // grad_input for each example. In that case, we return an undefined grad.
+  //
+  // It is theoretically possible for *some* of the examples to produce an
+  // undefined grad (a kernel could peek at the gradient values and return an
+  // undefined tensor if it determines the gradient is full of zeros). We
+  // could handle this by treating the undefined grad as a zero-filled tensor
+  // of the correct shape while stacking the tensors together. However I expect
+  // this to happen very rarely (I have not been able to find an example in our
+  // codebase) so we just error out in this case.
+  if (std::none_of(tensors.begin(), tensors.end(), is_defined)) {
+    return Tensor();
+  }
+  TORCH_CHECK(false,
+      "vmap: slow fallback received a mix of undefined and defined tensors ",
+      "as the result of an operation. This is not supported, please file us ",
+      "an issue on github.");
+}
+
 // The general flow of the algorithm is as follows.
 // - First, we figure out which arguments are BatchedTensors and save them
 // to a vector. We also store a vector of which index of the arguments list
@@ -318,7 +344,12 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta
   auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_batches);
   for (int64_t return_idx = 0; return_idx < num_returns; ++return_idx) {
     auto shards = output_shards_chunks[return_idx];
-    auto flat_output = at::stack(shards);
+    auto flat_output = safeStack(shards);
+    // See NOTE [vmap through backward and undefined grad]
+    if (!flat_output.defined()) {
+      torch::jit::push(stack, flat_output);
+      continue;
+    }
     VmapDimVector output_sizes(batch_sizes);
     output_sizes.insert(
         output_sizes.end(),
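The new safeStack helper gives the slow fallback a three-way policy for per-example results: stack them when every shard is defined, propagate a single undefined tensor when none is defined, and error out on a mix. A minimal sketch of that contract against plain ATen follows; the function and variable names are illustrative, not part of the commit.

#include <ATen/ATen.h>
#include <algorithm>
#include <iostream>
#include <vector>

// Illustrative restatement of the safeStack contract:
//   all shards defined  -> one batched tensor via at::stack
//   no shard defined    -> a single undefined tensor
//   mixed               -> unsupported, raise an error
at::Tensor safe_stack_sketch(at::TensorList tensors) {
  auto is_defined = [](const at::Tensor& t) { return t.defined(); };
  if (std::all_of(tensors.begin(), tensors.end(), is_defined)) {
    return at::stack(tensors);
  }
  if (std::none_of(tensors.begin(), tensors.end(), is_defined)) {
    return at::Tensor();  // undefined grad for every example
  }
  TORCH_CHECK(false, "mix of defined and undefined shards is not supported");
}

int main() {
  std::vector<at::Tensor> all_defined = {at::ones({3}), at::zeros({3})};
  std::vector<at::Tensor> none_defined = {at::Tensor(), at::Tensor()};
  std::cout << safe_stack_sketch(all_defined).sizes() << "\n";     // [2, 3]
  std::cout << safe_stack_sketch(none_defined).defined() << "\n";  // 0
}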

aten/src/ATen/Context.h

Lines changed: 4 additions & 2 deletions
@@ -314,10 +314,12 @@ static inline void manual_seed(uint64_t seed) {
   }
   // NB: Sometimes we build with CUDA, but we don't have any GPUs
   // available. In that case, we must not seed CUDA; it will fail!
-  int num_gpus = detail::getCUDAHooks().getNumGPUs();
+  const auto num_gpus = detail::getCUDAHooks().getNumGPUs();
   if (hasCUDA() && num_gpus > 0) {
     for (int i = 0; i < num_gpus; i++) {
-      auto cuda_gen = globalContext().defaultGenerator(Device(at::kCUDA, i));
+      auto cuda_gen = globalContext().defaultGenerator(
+        Device(at::kCUDA, static_cast<c10::DeviceIndex>(i))
+      );
       {
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(cuda_gen.mutex());
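The Device constructor takes a c10::DeviceIndex, which is a narrower integer type than the int loop counter, so this change spells the conversion out rather than relying on implicit narrowing. A small sketch of the same explicit cast in isolation (the helper name is made up for illustration):

#include <c10/core/Device.h>

// Build the i-th CUDA device; c10::DeviceIndex is narrower than int, so the
// cast makes the otherwise implicit narrowing visible to readers and tools.
inline c10::Device nth_cuda_device(int i) {
  return c10::Device(c10::kCUDA, static_cast<c10::DeviceIndex>(i));
}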

aten/src/ATen/core/Dict_inl.h

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ namespace detail {

 inline size_t DictKeyHash::operator()(const IValue& ivalue) const {
   if (ivalue.isInt()) {
-    return std::hash<int>()(ivalue.toInt());
+    return std::hash<int64_t>()(ivalue.toInt());
   } else if (ivalue.isString()) {
     return std::hash<std::string>()(ivalue.toStringRef());
   } else if (ivalue.isDouble()) {
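IValue::toInt() returns an int64_t, so hashing it through std::hash<int> first narrows the key and lets keys that differ only in the high bits collide; std::hash<int64_t> hashes the full value. A small standalone illustration of the difference (the values are chosen arbitrarily):

#include <cstdint>
#include <functional>
#include <iostream>

int main() {
  // Two 64-bit keys that differ only above bit 31.
  int64_t a = 1;
  int64_t b = 1 + (int64_t(1) << 32);

  // Narrowing to int first discards the high bits, so the two hashes coincide.
  bool collide_as_int = std::hash<int>()(static_cast<int>(a)) ==
                        std::hash<int>()(static_cast<int>(b));
  // Hashing the full 64-bit value keeps the keys distinguishable.
  bool collide_as_int64 = std::hash<int64_t>()(a) == std::hash<int64_t>()(b);

  std::cout << collide_as_int << " " << collide_as_int64 << "\n";  // typically: 1 0
}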

aten/src/ATen/core/boxing/impl/boxing.h

Lines changed: 7 additions & 1 deletion
@@ -97,7 +97,13 @@ using can_unbox =
 //
 template <class FuncType, class Enable = void>
 struct BoxedKernelWrapper {
-  static_assert(sizeof(FuncType) == -1,
+  // The reason we're not just doing straight up static_assert(false, ...) here:
+  // Basically, the way to make sure a static_assert only fires if a template
+  // is actually instantiated (rather than every time the file is parsed) is to use
+  // template parameters in the expression, e.g. FuncType here. However, since
+  // `sizeof(FuncType) != sizeof(FuncType)` is always false, this has the same
+  // effect.
+  static_assert(sizeof(FuncType) != sizeof(FuncType),
     "Function signature contains one or more unsupported parameter and/or return types. "
     "Look for a nearby error like "
     "\"'call' is not a member of 'c10::impl::BoxedKernelWrapper<(your function type), void>'\" "

aten/src/ATen/core/builtin_function.h

Lines changed: 10 additions & 2 deletions
@@ -10,13 +10,19 @@ struct BuiltinOpFunction : public Function {
   BuiltinOpFunction(
       c10::QualifiedName qualname,
       c10::FunctionSchema schema,
-      std::function<void(Stack&)> callable)
+      std::function<void(Stack&)> callable,
+      std::string doc_string = "")
       : name_(std::move(qualname)),
         callable_(std::move(callable)),
-        schema_(std::move(schema)) {
+        schema_(std::move(schema)),
+        doc_string_(std::move(doc_string)) {
     TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1);
   }

+  const std::string& doc_string() const override {
+    return doc_string_;
+  }
+
   bool isGraphFunction() const override {
     return false;
   }
@@ -110,6 +116,8 @@ struct BuiltinOpFunction : public Function {
   std::function<void(Stack&)> callable_;

   c10::FunctionSchema schema_;
+
+  std::string doc_string_;
 };

 } // namespace jit

aten/src/ATen/core/function.h

Lines changed: 5 additions & 0 deletions
@@ -25,6 +25,11 @@ TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);
 // execution of the function. Method is a wrapper around an
 // underlying Function that also provides a `self` object.
 struct TORCH_API Function {
+  virtual const std::string& doc_string() const {
+    static const std::string no_doc_string = "";
+    return no_doc_string;
+  }
+
   virtual bool isGraphFunction() const = 0;

   virtual void run(Stack& stack) = 0;
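Taken together, the builtin_function.h and function.h changes attach an optional docstring to torch::jit functions: the Function base class answers doc_string() with a shared empty string, and BuiltinOpFunction overrides it with whatever string it was constructed with. A stripped-down sketch of the same pattern (the class names here are invented for illustration):

#include <string>
#include <utility>

// Base class: default to a function-local static empty string, so callers can
// always bind a const reference and the undocumented case allocates nothing.
struct FunctionLike {
  virtual const std::string& doc_string() const {
    static const std::string no_doc_string = "";
    return no_doc_string;
  }
  virtual ~FunctionLike() = default;
};

// Derived class: stores its own docstring and returns a reference to it.
struct DocumentedFunction : FunctionLike {
  explicit DocumentedFunction(std::string doc) : doc_string_(std::move(doc)) {}
  const std::string& doc_string() const override { return doc_string_; }

 private:
  std::string doc_string_;
};

Returning a const reference rather than a std::string by value is what makes the function-local static in the base class necessary: there has to be a long-lived object to refer to when no docstring was provided.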
