Skip to content

Commit cfe1c6e

Browse files
Ailing Zhangfacebook-github-bot
authored andcommitted
Update XLAPreAutograd keys. (#40265)
Summary: Pull Request resolved: #40265 Differential Revision: D22137998 Pulled By: ailzhang fbshipit-source-id: 41edac06f8aafa5d4c1dcefd5da81be6c9ac4a9c
1 parent 5c133eb commit cfe1c6e

4 files changed

Lines changed: 10 additions & 8 deletions

File tree

c10/core/DispatchKey.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -230,13 +230,6 @@ static_assert(
230230
C10_API const char* toString(DispatchKey);
231231
C10_API std::ostream& operator<<(std::ostream&, DispatchKey);
232232

233-
// For backwards compatibility with XLA repository
234-
// (I don't want to fix this in XLA right now because there might be
235-
// more renaming coming in the future.)
236-
static inline DispatchKey XLA() {
237-
return DispatchKey::XLA;
238-
}
239-
240233
// These are some convenience identifiers for dispatch keys which are
241234
// shorter to type than their long counterparts. Note that some of these
242235
// dispatch keys directly correspond to DeviceType; and most APIs that

c10/core/DispatchKeySet.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,4 +132,11 @@ static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
132132
return s.highestPriorityTypeId();
133133
}
134134

135+
// For backwards compatibility with XLA repository
136+
// (I don't want to fix this in XLA right now because there might be
137+
// more renaming coming in the future.)
138+
static inline DispatchKeySet XLA() {
139+
return DispatchKeySet{DispatchKey::XLA, DispatchKey::XLAPreAutograd};
140+
}
141+
135142
}

c10/core/TensorOptions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,8 @@ inline DeviceType computeDeviceType(DispatchKey tid) {
643643
return DeviceType::MSNPU;
644644
} else if (tid == DispatchKey::XLA) {
645645
return DeviceType::XLA;
646+
} else if (tid == DispatchKey::XLAPreAutograd) {
647+
return DeviceType::XLA;
646648
} else if (tid == DispatchKey::SparseCPU) {
647649
return DeviceType::CPU;
648650
} else if (tid == DispatchKey::SparseCUDA) {

torch/csrc/utils/tensor_new.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ void check_base_legacy_new(c10::DispatchKey dispatch_key, at::Layout expected_la
336336
TORCH_CHECK(dispatch_key == c10::DispatchKey::CPU
337337
|| dispatch_key == c10::DispatchKey::CUDA
338338
|| dispatch_key == c10::DispatchKey::HIP
339-
|| dispatch_key == c10::XLA(),
339+
|| c10::XLA().has(dispatch_key),
340340
"new(): expected DispatchKey: ", c10::DispatchKey::CPU,
341341
" or ", c10::DispatchKey::CUDA,
342342
" or ", c10::DispatchKey::HIP,

0 commit comments

Comments
 (0)