
Commit de9d986

Update on "Add base forward grad logic"
RFC: pytorch/rfcs#11

This PR adds the basic logic to handle forward grads as dual Tensors. It contains the following:

- A mechanism to save dual state on a Tensor and clear it when the dual level ends
- C++ and Python user-facing APIs
- An updated view system that is able to track both forward and backward views

The current PR has the following limitations:

- Extensive tests are in the next PR in the stack, as formulas are needed to write full tests.
- Only the manual formulas have been audited; no other formula is actually implemented here (they are in the next PR in the stack).
- Only level 0 is allowed for now. This was discussed, and it was agreed that more levels are not needed for the first version of this PR.
- We could save one ViewInfo creation when both the forward and backward views have the same base, by adding a boolean flag to the DifferentiableViewMeta and extra logic in the `as_view` method. This is left out to keep this PR concise.
- We could skip tracking forward views if the base has a forward grad, by adding extra logic in the `as_view` method. This is left out to keep this PR concise.

Reading guide:

- Updated view handling in [gen_variable_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-f6553cec68caeaea36f6c8b14ff76a6d39dfd774e0ea9ef2f76e8d81fd9af5df), [VariableTypeUtils.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-ec71cfa45954dece1236c661d170e6341879c5be637f4abf52e826d61b40695a), [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285) (skip the code below "[Forward Grad View]" for now), [variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-1604bcd0e4350ed99ec45e437cee7ac9ebe337392c9ea16a236247aeeb35b02bR266-R542) and [custom_function.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-dd85f452082b5bb6612bbc12adb496f8827defa228509f7b493de1d517522d5d). This introduces the new ViewInfo that holds the view information shared by forward and backward, updates the differentiable view meta to use it, and updates the `as_view` function to handle both forward and backward views.
- New forward grad class that handles storing gradients and tracking at each level: [forward_grad.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c6c5b9ab2d7e5dde4102495faa1b6bbbfc23aa3e47deb7359c0bfe1eb004c0cb), [forward_grad.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-de2ab54ade7312701850d71a119a4f4ee4b9fc5a9c42a467cdd4e73c033531dd) and [build_variables.bzl](https://github.com/pytorch/pytorch/pull/49097/files#diff-dfdfa2efb17beddfd9094524f95351fd197db6c8857e96b436fb599870359325). EDIT: These files also contain the new flag to globally disable forward AD, which allows us to reduce performance issues while this is in development.
- Lowest-level API and binding between Tensor and AutogradMeta in [TensorBody.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-7554853205392fa743357bf845ecc350a974ec049383248c12daaf2f4de04911), [TensorImpl.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-052bd9150ef8e09289ddf644b5a6830ede49207201cd41728f6d7cc6d9cead94), [TensorImpl.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-a15aae4cf23da44970db7cece62ff981265575c798c62f7b52d87c8809dfe2e1) and the rest of [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285R557-R677)
- API to access the forward primal, which needs to be a differentiable function (and so lives in native_functions.yaml): [native_functions.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991), [NamedRegistrations.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-69bd3bea510c9b64e1633fa18c3ea63d4b8348dbad3a78ad9de844ab3e43dc1d), [VariableMethodsStub.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-23f5fcb737a2b289811fe0f4b65aef775e7c824b2e629ecd343df51405cd434f), [derivatives.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_python_functions.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_trace_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-54e0b976027bf8debefb959ff360b89ae93466970c843365b1b3a03806d868ce), [TraceTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-f34636741ad4a23d018e0c289bc750c3bad887b45660e1d6eaf440d234a78fbf) and [part of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R198-R243)
- C++ API: [autograd.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-349028fbe8291a965a7a263c323b208fe071c35c66179ee997ef84fa81aa4b1e), [autograd.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-a3fe908d67dfec16a1fcde300de68b0701bf68b88db7451f29f2bee255cf30c9)
- Python binding: [init.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-c58a67c85191c22c9b3bb439117d8053edfd9dea839fa010cf967d404c3c630d)
- Python API: [forward_ad.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a4efad4ba18fffdfb264c21e5475997a24a743089a899f8ec1a5ff962c6738d9), [autograd/__init__.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-743abcafd32ad0e69f39ac5a91df4197b7e1921c135cacee7ef6dc829a8a7af8) (a usage sketch follows below)
- C++ and Python printing: [Formatting.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-881dba501e71662e2e4818b4b016f739b344c8aed2f5edc6b871eda47a2aced0), [_tensor_str.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a7911f8d5e73adbff914d99fd7818ace2a7030b6a3748abe06ec6fc6e3df9cc3)
- Utility for formulas and updated manual functions to respect the new view system as well as forward grads: [FunctionsManual.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-6378bb6dc81a64dab676d61731341fa5d1088418f32a1473a33a0ccfc2357dc1), [FunctionsManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-4adbd88239afcd60e8198aab65d4f5e43b62314e34b80551e997a1ea503adea5), [rest of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R264-R433)
- Ensure SavedVariable saves the forward grad properly: [saved_variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c1b8039d776241abe177d5aa99b79dd9489a9b3e529da8ab24c2e386c1238ae2), [saved_variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-cc9fba479b5beae06b2eea2e390d17796e0341c5b037a20b5bcaccbb0c341030)

Differential Revision: [D25607503](https://our.internmc.facebook.com/intern/diff/D25607503)

[ghstack-poisoned]
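For orientation, here is a minimal sketch of how the Python API introduced in forward_ad.py is meant to be used. This is a sketch under assumptions, not a confirmed API: the entry-point names (`dual_level`, `make_dual`, `unpack_dual`) are taken from this PR's intent rather than verified signatures, and since most derivative formulas only land in the next PR in the stack, tangent propagation through ops is still limited here.

```python
import torch
import torch.autograd.forward_ad as fwAD  # module added by this PR; name assumed

primal = torch.randn(3)
tangent = torch.randn(3)  # direction for the directional derivative

# Dual Tensors only exist inside a dual level; only level 0 is allowed for now.
with fwAD.dual_level():
    # Attach the tangent to the primal as its forward grad.
    dual = fwAD.make_dual(primal, tangent)
    # Recover both parts; the tangent is the forward grad stored on the Tensor.
    p, t = fwAD.unpack_dual(dual)
    assert torch.equal(t, tangent)
# When the `with` block exits, the dual level ends and the dual state saved on
# Tensors is cleared, as described above.
```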
2 parents a55e021 + ccd6466 commit de9d986

160 files changed

Lines changed: 3931 additions & 1141 deletions


.clang-tidy

Lines changed: 1 addition & 0 deletions
```diff
@@ -22,6 +22,7 @@ cppcoreguidelines-*,
 hicpp-exception-baseclass,
 hicpp-avoid-goto,
 modernize-*,
+-modernize-concat-nested-namespaces,
 -modernize-return-braced-init-list,
 -modernize-use-auto,
 -modernize-use-default-member-init,
```

.github/workflows/lint.yml

Lines changed: 10 additions & 0 deletions
```diff
@@ -17,6 +17,16 @@ jobs:
           architecture: x64
       - name: Checkout PyTorch
         uses: actions/checkout@v1
+      - name: Checkout PR tip
+        run: |
+          set -eux
+          if [[ "${{ github.event_name }}" == "pull_request" ]]; then
+            # We are on a PR, so actions/checkout leaves us on a merge commit.
+            # Check out the actual tip of the branch.
+            git checkout ${{ github.event.pull_request.head.sha }}
+          fi
+          echo ::set-output name=commit_sha::$(git rev-parse HEAD)
+        id: get_pr_tip
       - name: Ensure consistent CircleCI YAML config
         run: |
           pip install -r requirements.txt
```
Lines changed: 22 additions & 0 deletions
```diff
@@ -0,0 +1,22 @@
+name: Update S3 HTML indices for download.pytorch.org
+on:
+  schedule:
+    # Update the indices every 30 minutes
+    - cron: "*/30 * * * *"
+  # Have the ability to trigger this job manually using the API as well
+  workflow_dispatch:
+
+jobs:
+  update-html:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        prefix: ["whl", "whl/test", "whl/nightly"]
+    steps:
+      - name: Run updater image
+        env:
+          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_UPDATE_ACCESS_KEY_ID }}
+          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_UPDATE_SECRET_ACCESS_KEY }}
+        uses: docker://pytorch/manage_s3_html
+        with:
+          args: ${{ matrix.prefix }}
```

.jenkins/pytorch/win-build.sh

Lines changed: 15 additions & 0 deletions
```diff
@@ -38,6 +38,21 @@ fi
 
 export SCRIPT_HELPERS_DIR=$SCRIPT_PARENT_DIR/win-test-helpers
 
+set +ex
+grep -E -R 'PyLong_(From|As)(Unsigned|)Long\(' --exclude=python_numbers.h torch/
+PYLONG_API_CHECK=$?
+if [[ $PYLONG_API_CHECK == 0 ]]; then
+  echo "Usage of PyLong_{From,As}{Unsigned}Long API may lead to overflow errors on Windows"
+  echo "because \`sizeof(long) == 4\` and \`sizeof(unsigned long) == 4\`."
+  echo "Please include \"torch/csrc/python_numbers.h\" and use the corresponding APIs instead."
+  echo "PyLong_FromLong -> THPUtils_packInt32 / THPUtils_packInt64"
+  echo "PyLong_AsLong -> THPUtils_unpackInt (32-bit) / THPUtils_unpackLong (64-bit)"
+  echo "PyLong_FromUnsignedLong -> THPUtils_packUInt32 / THPUtils_packUInt64"
+  echo "PyLong_AsUnsignedLong -> THPUtils_unpackUInt32 / THPUtils_unpackUInt64"
+  exit 1
+fi
+set -ex
+
 $SCRIPT_HELPERS_DIR/build_pytorch.bat
 
 assert_git_not_dirty
```

aten/src/ATen/BatchedFallback.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -361,7 +361,7 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta
       flat_output.sizes().end());
   torch::jit::push(
       stack,
-      input_physical_views.front().newLogicalFromPhysical(flat_output.view(output_sizes)));
+      input_physical_views.front().getPhysicalToLogicalMap().apply(flat_output.view(output_sizes)));
 }
 }
 
```
aten/src/ATen/BatchingRegistrations.cpp

Lines changed: 50 additions & 50 deletions
Large diffs are not rendered by default.

aten/src/ATen/OpaqueTensorImpl.h

Lines changed: 3 additions & 1 deletion
```diff
@@ -24,11 +24,13 @@ struct CAFFE2_API OpaqueTensorImpl : public TensorImpl {
       const caffe2::TypeMeta data_type,
       c10::Device device,
       OpaqueHandle opaque_handle,
-      c10::IntArrayRef sizes)
+      c10::IntArrayRef sizes,
+      bool is_non_overlapping_and_dense = true)
       : TensorImpl(key_set, data_type, device),
         opaque_handle_(std::move(opaque_handle)) {
     sizes_ = sizes.vec();
     refresh_numel();
+    is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
   }
 
   void release_resources() override {
```
aten/src/ATen/ParallelOpenMP.cpp

Lines changed: 8 additions & 0 deletions
```diff
@@ -8,6 +8,8 @@
 #include <mkl.h>
 #endif
 
+#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
+
 namespace at {
 
 namespace {
@@ -49,6 +51,12 @@ void set_num_threads(int nthreads) {
   // See https://github.com/pytorch/pytorch/issues/13757
   mkl_set_dynamic(false);
 #endif
+#ifdef USE_PTHREADPOOL
+  // because PyTorch uses caffe2::pthreadpool() in QNNPACK
+  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
+  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
+  pool->set_thread_count(nthreads);
+#endif
 }
 
 // Explicitly calling omp_get_max_threads() as the size of the parallel
```
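For context, a short sketch of the user-visible effect of this change, assuming the usual routing of `torch.set_num_threads` to `at::set_num_threads` in OpenMP builds: the intra-op thread count should now also propagate to the caffe2 pthreadpool that QNNPACK kernels run on, keeping the two pools in sync.

```python
import torch

# Set the intra-op thread count. With the change above (USE_PTHREADPOOL
# builds), the OpenMP backend additionally propagates this count to
# caffe2::pthreadpool(), which QNNPACK uses for quantized kernels.
torch.set_num_threads(4)
print(torch.get_num_threads())  # 4
```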

aten/src/ATen/VmapTransforms.cpp

Lines changed: 14 additions & 10 deletions
```diff
@@ -91,16 +91,6 @@ static BatchDims computeFrontBatchDimsFromLevels(std::bitset<kVmapNumLevels> lev
   return bdims;
 }
 
-Tensor VmapPhysicalView::newLogicalFromPhysical(const Tensor& physical) const {
-  return makeBatched(physical, computeFrontBatchDimsFromLevels(levels_));
-}
-
-void VmapPhysicalView::makeLogicalFromPhysicalListInplace(std::vector<Tensor>& physical_tensors) const {
-  for (int64_t idx = 0; idx < physical_tensors.size(); ++idx) {
-    physical_tensors[idx] = newLogicalFromPhysical(physical_tensors[idx]);
-  }
-}
-
 // Given a Tensor or a BatchedTensor, returns the underlying physical tensor
 // with all vmapped dimensions permuted to the front, if they exist, and a
 // bitset of vmap levels that were present in the tensor.
@@ -281,4 +271,18 @@ VmapPhysicalViewVec BroadcastingVmapTransform::logicalToPhysical(TensorList logi
   return result;
 }
 
+VmapPhysicalToLogicalMap VmapPhysicalView::getPhysicalToLogicalMap() const {
+  return VmapPhysicalToLogicalMap(levels_);
+}
+
+Tensor VmapPhysicalToLogicalMap::apply(const Tensor& physical_tensor) const {
+  return makeBatched(physical_tensor, computeFrontBatchDimsFromLevels(levels_));
+}
+
+void VmapPhysicalToLogicalMap::applyInplace(std::vector<Tensor>& physical_tensors) const {
+  for (int64_t idx = 0; idx < physical_tensors.size(); ++idx) {
+    physical_tensors[idx] = apply(physical_tensors[idx]);
+  }
+}
+
 } // namespace at
```

aten/src/ATen/VmapTransforms.h

Lines changed: 34 additions & 14 deletions
```diff
@@ -79,6 +79,10 @@ struct TORCH_API BroadcastingVmapTransform {
   static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
 };
 
+// Forward declared, if you're reading this file head to toe, don't worry about
+// it yet.
+struct VmapPhysicalToLogicalMap;
+
 // NOTE: [What is a VmapPhysicalView?]
 // VmapPhysicalView represents a physical view on a Tensor.
 //
@@ -115,24 +119,14 @@ struct TORCH_API VmapPhysicalView {
   VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const;
   int64_t getPhysicalDim(int64_t logical_dim) const;
 
+  // Returns a VmapPhysicalToLogicalMap object. This can be used for
+  // mapping a physical tensor to a new logical tensor (BatchedTensor)
+  VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;
+
   // Maps a logical shape to a physical shape by pre-pending the batch
   // sizes to the logical shape.
   VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
 
-  // Maps a physical tensor to a new logical tensor (BatchedTensor),
-  // using the mapping info stored in this VmapPhysicalView.
-  // Assumes that all of the "batch dimensions" are at the front
-  // of the physical tensor.
-  Tensor newLogicalFromPhysical(const Tensor& physical) const;
-
-  // Given a vector of physical tensors,
-  // 1. maps each tensor to a new logical tensor using the mapping info stored
-  //    in this VmapPhysicalView. Assumes that all of the "batch dimensions"
-  //    are at the front of the physical tensors.
-  // 2. stores the new logical tensors back into the passed-in vector. This is
-  //    to avoid additional dynamic allocations.
-  void makeLogicalFromPhysicalListInplace(std::vector<Tensor>& physical_tensors) const;
-
   int64_t numBatchDims() const;
 
  private:
@@ -142,5 +136,31 @@ struct TORCH_API VmapPhysicalView {
   Tensor tensor_;
 };
 
+// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
+// to a logical one (BatchedTensor). It holds some levels that are used to do the
+// mapping and assumes that the batch dimensions in the physical tensor all
+// occur at the front of the tensor.
+struct TORCH_API VmapPhysicalToLogicalMap {
+  VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels): levels_(levels) {}
+
+  // Maps a physical tensor to a new logical tensor (BatchedTensor).
+  // Assumes that all of the "batch dimensions" are at the front
+  // of the physical tensor. For example, given:
+  // - x = rank-4 Tensor with size 2, 3, 5, 7
+  // - levels = (2, 4)
+  // Returns:
+  // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
+  Tensor apply(const Tensor& physical_tensor) const;
+
+  // Given a vector of physical tensors,
+  // 1. maps each tensor to a new logical tensor. Assumes that all of the
+  //    "batch dimensions" are at the front of the physical tensors.
+  // 2. stores the new logical tensors back into the passed-in vector. This is
+  //    to avoid additional dynamic allocations.
+  void applyInplace(std::vector<Tensor>& physical_tensors) const;
+
+  std::bitset<kVmapNumLevels> levels_;
+};
+
 
 } // namespace at
```
