Skip to content

Commit 62ed7bd

Browse files
committed
Update on "Remove backward and requires_grad from Autograd backend key"
Just following a TODO in the code base... Differential Revision: [D25644597](https://our.internmc.facebook.com/intern/diff/D25644597/) [ghstack-poisoned]
2 parents 28ae970 + 56b6322 commit 62ed7bd

23 files changed

Lines changed: 362 additions & 180 deletions

aten/src/ATen/VmapTransforms.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,17 @@ struct VmapPhysicalToLogicalMap;
9696
// The levels bitset specifies which vmap levels correspond to the batch
9797
// dimensions at the front of the tensor. In particular, the number of set bits
9898
// corresponds to the number of batch dimensions on `tensor` and the rightmost
99-
// bit of `levels` specifies the minimum number of nested vmaps we are in at
99+
// bit of `levels` specifies the maximum number of nested vmaps we are in at
100100
// this point in time.
101+
// For example, given:
102+
// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
103+
//
104+
// Rightmost set bit of `levels` is 3, indicating that the number of nested
105+
// vmaps is less than or equal to 3.
106+
// bitset: 010100
107+
// ^
108+
// |
109+
// levels: 012345
101110
struct TORCH_API VmapPhysicalView {
102111
VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
103112
: levels_(levels), tensor_(tensor) {

aten/src/ATen/native/Distributions.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ DEFINE_DISPATCH(bernoulli_tensor_stub);
118118
DEFINE_DISPATCH(bernoulli_scalar_stub);
119119
DEFINE_DISPATCH(cauchy_stub);
120120
DEFINE_DISPATCH(exponential_stub);
121-
DEFINE_DISPATCH(multinomial_stub);
121+
DEFINE_DISPATCH(multinomial_with_replacement_stub);
122122
DEFINE_DISPATCH(geometric_stub);
123123
DEFINE_DISPATCH(log_normal_stub);
124124
DEFINE_DISPATCH(uniform_stub);
@@ -497,8 +497,10 @@ Tensor& multinomial_out(
497497
// Reference:
498498
// https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503
499499
// Half is not supported on CPU.
500-
if (!with_replacement &&
501-
!(self.device().is_cpu() && self.scalar_type() == ScalarType::Half)) {
500+
TORCH_CHECK(
501+
!(self.device().is_cpu() && self.scalar_type() == ScalarType::Half),
502+
"multinomial is not implemented for half on CPU");
503+
if (!with_replacement) {
502504
// Sanity checks on `self`.
503505
auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item();
504506
TORCH_CHECK(
@@ -537,13 +539,8 @@ Tensor& multinomial_out(
537539
return result;
538540
}
539541

540-
multinomial_stub(
541-
result.device().type(),
542-
result,
543-
self,
544-
n_sample,
545-
with_replacement,
546-
gen);
542+
multinomial_with_replacement_stub(
543+
result.device().type(), result, self, n_sample, gen);
547544
return result;
548545
}
549546

aten/src/ATen/native/TensorShape.cpp

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -98,23 +98,25 @@ static inline void check_cat_shape_except_dim(const Tensor & first, const Tensor
9898
if (dim == dimension) {
9999
continue;
100100
}
101-
int64_t first_dim_size = first.size(dim);
102-
int64_t second_dim_size = second.size(dim);
101+
int64_t first_dim_size = first.sizes()[dim];
102+
int64_t second_dim_size = second.sizes()[dim];
103103
TORCH_CHECK(first_dim_size == second_dim_size, "Sizes of tensors must match except in dimension ",
104104
dimension, ". Got ", first_dim_size, " and ", second_dim_size, " in dimension ", dim,
105105
" (The offending index is ", index, ")");
106106
}
107107
}
108108

109+
static bool should_skip(const Tensor& t) {
110+
return t.numel() == 0 && t.dim() == 1;
111+
}
112+
109113
Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
110114
// previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible
111115
// to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors
112116
// to be "skipped". We maintain this behavior for backwards compatibility, but only for this specific
113117
// size (i.e. other empty sizes are not skipped).
114-
// FIXME: warn if this is the case
115-
bool allSkipped = true;
118+
116119
bool allContiguous = true;
117-
Tensor notSkippedTensor;
118120

119121
// Inputs cannot alias the output tensor
120122
for (int64_t i = 0; i < tensors.size(); i++) {
@@ -126,19 +128,23 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
126128
}
127129
at::assert_no_internal_overlap(result);
128130

129-
auto should_skip = [](const Tensor& t) { return t.numel() == 0 && t.dim() == 1; };
130-
for (auto const &tensor : tensors) {
131-
if (should_skip(tensor)) {
132-
continue;
131+
const Tensor* pnotSkippedTensor = [](TensorList tensors) -> const Tensor* {
132+
for (auto const &tensor : tensors) {
133+
if (should_skip(tensor)) {
134+
continue;
135+
}
136+
// we've found a non-empty tensor
137+
return &tensor;
133138
}
134-
// we've found a non-empty tensor
135-
allSkipped = false;
136-
notSkippedTensor = tensor;
137-
break;
138-
}
139-
if (allSkipped) {
139+
return nullptr;
140+
}(tensors);
141+
142+
if (!pnotSkippedTensor) {
143+
// FIXME: warn if this is the case -- see comment about skipped
144+
// tensors at top of function.
140145
return result;
141146
}
147+
const Tensor& notSkippedTensor = *pnotSkippedTensor;
142148

143149
TORCH_CHECK(tensors.size() > 0, "expected a non-empty list of Tensors");
144150
TORCH_CHECK(dim <= notSkippedTensor.dim(), "dimension ", dim, "out of range");
@@ -161,7 +167,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
161167
continue;
162168
}
163169
check_cat_shape_except_dim(notSkippedTensor, tensor, dim, i);
164-
cat_dim_size += tensor.size(dim);
170+
cat_dim_size += tensor.sizes()[dim];
165171

166172
if (!tensor.is_contiguous(first_tensor_mem_format)) {
167173
allContiguous = false;
@@ -196,8 +202,8 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
196202
if (reuse_iterator &&
197203
result.is_contiguous(first_tensor_mem_format) &&
198204
no_type_promotion) {
199-
auto source_slice = notSkippedTensor;
200-
auto slice_dim_size = source_slice.size(dim);
205+
const auto& source_slice = notSkippedTensor;
206+
auto slice_dim_size = source_slice.sizes()[dim];
201207
auto result_slice = result.narrow(dim, 0, slice_dim_size);
202208
auto result_slice_data = result_slice.data_ptr();
203209
auto result_stride_bytes = result.stride(dim) * elementSize(result.scalar_type());
@@ -226,7 +232,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
226232
if (should_skip(tensor)) {
227233
continue;
228234
}
229-
auto slice_dim_size = tensor.size(dim);
235+
auto slice_dim_size = tensor.sizes()[dim];
230236
auto result_slice = result.narrow(dim, offset, slice_dim_size);
231237

232238
auto iter = TensorIteratorConfig()

aten/src/ATen/native/UnaryOps.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ DECLARE_DISPATCH(void(*)(TensorIterator&, c10::optional<Generator>), random_full
7777
DECLARE_DISPATCH(void(*)(TensorIterator&, c10::optional<Generator>), random_stub);
7878
DECLARE_DISPATCH(void(*)(TensorIterator&, const int64_t), polygamma_stub);
7979
DECLARE_DISPATCH(void(*)(TensorIterator&, Scalar a, Scalar b), clamp_stub);
80-
DECLARE_DISPATCH(void(*)(Tensor&, const Tensor&, int64_t, bool, c10::optional<Generator>), multinomial_stub);
80+
DECLARE_DISPATCH(
81+
void (*)(Tensor&, const Tensor&, int64_t, c10::optional<Generator>),
82+
multinomial_with_replacement_stub);
8183
DECLARE_DISPATCH(
8284
void (*)(
8385
TensorIterator&,

aten/src/ATen/native/cpu/CatKernel.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,20 @@ struct InputMeta {
1515

1616
InputMeta(const Tensor& t, int64_t dim, int64_t inner)
1717
: data_ptr(t.data_ptr())
18-
, inner_size(t.size(dim) * inner) {}
18+
, inner_size(t.sizes()[dim] * inner) {}
1919
};
2020

2121
template <typename scalar_t>
2222
void cat_serial_kernel_impl(Tensor& result, TensorList tensors, int64_t dim) {
23-
int64_t outer = result.numel() / (result.size(dim) * result.stride(dim));
23+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
24+
dim >= 0 && dim < result.dim(), "dim out of range in cat_serial_kernel_impl");
25+
int64_t outer = result.numel() / (result.sizes()[dim] * result.strides()[dim]);
2426
scalar_t* result_data = result.data_ptr<scalar_t>();
2527
int64_t ninputs = tensors.size();
2628
std::vector<InputMeta> inputs;
2729
inputs.reserve(ninputs);
2830
for (auto const &tensor : tensors) {
29-
inputs.emplace_back(tensor, dim, result.stride(dim));
31+
inputs.emplace_back(tensor, dim, result.strides()[dim]);
3032
}
3133

3234
using Vec = vec256::Vec256<scalar_t>;

aten/src/ATen/native/cpu/MultinomialKernel.cpp

Lines changed: 15 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@ namespace at {
1111
namespace native {
1212
namespace {
1313

14-
template<typename scalar_t>
15-
void multinomial_apply(Tensor& result, const Tensor& self, const int64_t n_sample, const bool with_replacement, c10::optional<Generator> generator) {
14+
template <typename scalar_t>
15+
void multinomial_with_replacement_apply(
16+
Tensor& result,
17+
const Tensor& self,
18+
const int64_t n_sample,
19+
c10::optional<Generator> generator) {
1620
auto gen = get_generator_or_default<CPUGeneratorImpl>(generator, detail::getDefaultCPUGenerator());
1721
// See Note [Acquire lock when using random generators]
1822
std::lock_guard<std::mutex> lock(gen->mutex_);
@@ -61,8 +65,6 @@ void multinomial_apply(Tensor& result, const Tensor& self, const int64_t n_sampl
6165
}
6266

6367
TORCH_CHECK(sum > 0, "invalid multinomial distribution (sum of probabilities <= 0)");
64-
TORCH_CHECK(with_replacement || (n_categories - n_zeros >= n_sample),
65-
"invalid multinomial distribution (with replacement=False, not enough non-negative category to sample)");
6668

6769
/* normalize cumulative probability distribution so that last val is 1
6870
i.e. doesn't assume original self row sums to one */
@@ -100,45 +102,23 @@ void multinomial_apply(Tensor& result, const Tensor& self, const int64_t n_sampl
100102

101103
/* store in result tensor (will be incremented for lua compat by wrapper) */
102104
result_ptr[i * result_dist_stride_0 + j * result_dist_stride_1] = sample_idx;
103-
104-
/* Once a sample is drawn, it cannot be drawn again. ie sample without replacement */
105-
if (!with_replacement && j < n_sample - 1) {
106-
/* update cumulative distribution so that sample cannot be drawn again */
107-
scalar_t diff;
108-
scalar_t new_val = 0;
109-
scalar_t sum;
110-
111-
if (sample_idx != 0) {
112-
new_val = cum_dist_ptr[(sample_idx - 1) * cum_dist_stride_0];
113-
}
114-
/* marginal cumulative mass (i.e. original probability) of sample */
115-
diff = cum_dist_ptr[sample_idx * cum_dist_stride_0] - new_val;
116-
/* new sum of marginals is not one anymore... */
117-
sum = 1.0 - diff;
118-
for (int64_t k = 0; k < n_categories; k++) {
119-
new_val = cum_dist_ptr[k * cum_dist_stride_0];
120-
if (k >= sample_idx) {
121-
/* remove sampled probability mass from later cumulative probabilities */
122-
new_val -= diff;
123-
}
124-
/* make total marginals sum to one */
125-
new_val /= sum;
126-
cum_dist_ptr[k * cum_dist_stride_0] = new_val;
127-
}
128-
}
129105
}
130106
}
131107
}
132108

133-
static void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n_sample, const bool with_replacement, c10::optional<Generator> gen) {
109+
static void multinomial_with_replacement_kernel_impl(
110+
Tensor& result,
111+
const Tensor& self,
112+
const int64_t n_sample,
113+
c10::optional<Generator> gen) {
134114
AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "multinomial", [&] {
135-
multinomial_apply<scalar_t>(result, self, n_sample, with_replacement, gen);
115+
multinomial_with_replacement_apply<scalar_t>(result, self, n_sample, gen);
136116
});
137117
}
138-
139118
}
140119

141-
REGISTER_DISPATCH(multinomial_stub, &multinomial_kernel_impl);
142-
120+
REGISTER_DISPATCH(
121+
multinomial_with_replacement_stub,
122+
&multinomial_with_replacement_kernel_impl);
143123
}
144124
}

aten/src/ATen/native/cuda/Dropout.cu

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<scalar_t, IndexType> a,
5757

5858
accscalar_t pinv = accscalar_t(1)/p;
5959

60+
// Helps align the total number of times curand_uniform4 is called by each thread for the same totalElements
61+
// in the vec=2 and vec=4 cases.
62+
bool gridxvec_loop_state = 0;
63+
64+
float4 rand;
65+
6066
// Note: Vectorized loads means we'll stride each thread by an additional VEC factor, as we'll load VEC elements at a time
6167
for (IndexType linearIndex = idx * VEC;
6268
linearIndex < totalElements;
@@ -69,12 +75,21 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<scalar_t, IndexType> a,
6975
//curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything
7076
// Note: need a new set of random values per 4 elements -- we'll handle VEC elements in this thread, so need ceil(VEC / 4)
7177
// sets of rand.
72-
float4 rand = curand_uniform4(&state);
78+
if ((VEC == 4) || (gridxvec_loop_state == 0)) {
79+
rand = curand_uniform4(&state);
80+
} else {
81+
// sets up the last two values we generated last iteration to be used this iteration.
82+
rand.x = rand.z;
83+
rand.y = rand.w;
84+
gridxvec_loop_state ^= 1;
85+
}
7386

7487
rand.x = rand.x < p;
7588
rand.y = rand.y < p;
76-
rand.z = rand.z < p;
77-
rand.w = rand.w < p;
89+
if (VEC == 4) {
90+
rand.z = rand.z < p;
91+
rand.w = rand.w < p;
92+
}
7893

7994
// Note: We explicitly check for is_contiguous() before launching the vectorized kernel
8095
// and replace IndexToOffset call with linearIndex to allow vectorization of NHWC (or other)

aten/src/ATen/native/cuda/MultinomialKernel.cu

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,11 @@ sampleMultinomialOnce(int64_t* dest,
300300
}
301301
}
302302
303-
void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n_sample, const bool with_replacement, c10::optional<Generator> generator) {
303+
void multinomial_with_replacement_kernel_impl(
304+
Tensor& result,
305+
const Tensor& self,
306+
const int64_t n_sample,
307+
c10::optional<Generator> generator) {
304308
auto gen = get_generator_or_default<CUDAGeneratorImpl>(generator, cuda::detail::getDefaultCUDAGenerator());
305309
306310
int inputSize = self.dim();
@@ -371,7 +375,6 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n
371375
372376
PhiloxCudaState rng_engine_inputs;
373377
374-
if (with_replacement) {
375378
// Binary search is warp divergent (so effectively we're running
376379
// with just a single thread), but for better utilization,
377380
// we need each block to have at least 4 warps.
@@ -402,7 +405,6 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n
402405
prefixSum.data_ptr<scalar_t>(),
403406
normDist.data_ptr<scalar_t>());
404407
C10_CUDA_KERNEL_LAUNCH_CHECK();
405-
}
406408
}
407409
});
408410
@@ -412,6 +414,7 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n
412414
}
413415
}
414416
415-
REGISTER_DISPATCH(multinomial_stub, &multinomial_kernel_impl);
416-
417+
REGISTER_DISPATCH(
418+
multinomial_with_replacement_stub,
419+
&multinomial_with_replacement_kernel_impl);
417420
}}

0 commit comments

Comments (0)