Skip to content

Commit 1c70eae

Browse files
committed
Merge branch 'main' into bf/noop-elimination
2 parents ee28fed + 1f29190 commit 1c70eae

64 files changed

Lines changed: 1540 additions & 460 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
4022ff142a5392aa5197e05f4dfe85d356f742bf
1+
047bbc720fda70cd5742c76b3c9e01d504577d65

aten/src/ATen/native/TensorShape.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3366,7 +3366,7 @@ static std::vector<Tensor> _pad_chunk(
33663366
std::vector<int64_t> view_sizes(
33673367
tensor_size.begin(), tensor_size.begin() + dim);
33683368
view_sizes.insert(view_sizes.end(), {num_chunks, -1});
3369-
padded_tensors.push_back(padded_tensor.view(view_sizes));
3369+
padded_tensors.push_back(padded_tensor.reshape(view_sizes));
33703370
}
33713371
return padded_tensors;
33723372
}

aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -612,28 +612,41 @@ struct check_binary_functor_types_for_specialization<
612612
};
613613

614614
// The following is a list of type specializations for vectorized_templated
615-
// elementwise kernel. It refers to the first and second runtime types of the
616-
// arguments of a binary functor.
617-
615+
// elementwise kernel. The three types refer to runtime types of the output
616+
// tensor, first tensor argument, and the second tensor argument used for a
617+
// binary functor.
618618
constexpr std::array rt_binary_specializations = {
619-
std::array<c10::ScalarType, 2>(
619+
std::array<c10::ScalarType, 3>(
620620
{c10::CppTypeToScalarType<float>::value,
621+
c10::CppTypeToScalarType<float>::value,
621622
c10::CppTypeToScalarType<BFloat16>::value}),
622-
std::array<c10::ScalarType, 2>(
623+
std::array<c10::ScalarType, 3>(
624+
{c10::CppTypeToScalarType<float>::value,
625+
c10::CppTypeToScalarType<BFloat16>::value,
626+
c10::CppTypeToScalarType<float>::value}),
627+
std::array<c10::ScalarType, 3>(
623628
{c10::CppTypeToScalarType<BFloat16>::value,
629+
c10::CppTypeToScalarType<BFloat16>::value,
624630
c10::CppTypeToScalarType<float>::value}),
625-
std::array<c10::ScalarType, 2>(
631+
std::array<c10::ScalarType, 3>(
626632
{c10::CppTypeToScalarType<float>::value,
633+
c10::CppTypeToScalarType<float>::value,
627634
c10::CppTypeToScalarType<Half>::value}),
628-
std::array<c10::ScalarType, 2>(
635+
std::array<c10::ScalarType, 3>(
636+
{c10::CppTypeToScalarType<float>::value,
637+
c10::CppTypeToScalarType<Half>::value,
638+
c10::CppTypeToScalarType<float>::value}),
639+
std::array<c10::ScalarType, 3>(
629640
{c10::CppTypeToScalarType<Half>::value,
641+
c10::CppTypeToScalarType<Half>::value,
630642
c10::CppTypeToScalarType<float>::value})};
631643

632644
bool check_binary_rt_types_for_specialization(TensorIteratorBase& iter) {
633645
if (iter.ninputs() != 2)
634646
return false;
635647
for (auto spec : rt_binary_specializations)
636-
if (iter.input_dtype(0) == spec[0] && iter.input_dtype(1) == spec[1])
648+
if (iter.dtype(0) == spec[0] && iter.input_dtype(0) == spec[1] &&
649+
iter.input_dtype(1) == spec[2])
637650
return true;
638651
return false;
639652
}
@@ -648,6 +661,7 @@ struct type_specialized_kernel_launcher {
648661
typename loader_t,
649662
typename storer_t>
650663
static void apply(
664+
ScalarType ret_t,
651665
ScalarType arg0_t,
652666
ScalarType arg1_t,
653667
int64_t numel,
@@ -657,22 +671,22 @@ struct type_specialized_kernel_launcher {
657671
out_calc_t output_offset_calculator,
658672
loader_t loader,
659673
storer_t storer) {
660-
using traits = function_traits<func_t>;
661-
using return_t = typename traits::result_type;
662-
if (arg0_t == rt_binary_specializations[arg_index][0] &&
663-
arg1_t == rt_binary_specializations[arg_index][1])
674+
if (ret_t == rt_binary_specializations[arg_index][0] &&
675+
arg0_t == rt_binary_specializations[arg_index][1] &&
676+
arg1_t == rt_binary_specializations[arg_index][2])
664677
launch_vectorized_templated_kernel<
665678
func_t,
666679
array_t,
667680
inp_calc_t,
668681
out_calc_t,
669682
loader_t,
670683
storer_t,
671-
return_t,
672684
decltype(c10::impl::ScalarTypeToCPPType<
673685
rt_binary_specializations[arg_index][0]>::t),
674686
decltype(c10::impl::ScalarTypeToCPPType<
675-
rt_binary_specializations[arg_index][1]>::t)>(
687+
rt_binary_specializations[arg_index][1]>::t),
688+
decltype(c10::impl::ScalarTypeToCPPType<
689+
rt_binary_specializations[arg_index][2]>::t)>(
676690
numel,
677691
f,
678692
data,
@@ -712,7 +726,6 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
712726
#ifdef USE_ROCM
713727
// Attempt to call specialized vectorized elementwise kernel
714728
// that enables interleaving.
715-
716729
if (check_binary_rt_types_for_specialization(iter) &&
717730
memory::can_vectorize_up_to<func_t>(data) > 1) {
718731
// constexpr to reduce the amount of kernels generated for
@@ -740,6 +753,7 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
740753
type_specialized_kernel_launcher,
741754
rt_binary_specializations.size()>::
742755
with_args(
756+
iter.dtype(0),
743757
iter.input_dtype(0),
744758
iter.input_dtype(1),
745759
numel,

aten/src/ATen/native/cuda/MemoryAccess.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,8 +407,8 @@ struct vectorized_templated {
407407
// float(float,bfloat16) and functor add on float(float,float).
408408
template <typename scalar_t>
409409
__device__ inline void store(scalar_t* from, int idx) {
410-
using vec_t = aligned_vector<scalar_t, vec_size>;
411-
scalar_t* to = reinterpret_cast<scalar_t*>(data[0]) + block_work_size * idx;
410+
using vec_t = aligned_vector<CastToT, vec_size>;
411+
CastToT* to = reinterpret_cast<CastToT*>(data[0]) + block_work_size * idx;
412412
vec_t* to_ = reinterpret_cast<vec_t*>(to);
413413
int thread_idx = threadIdx.x;
414414
#pragma unroll

aten/src/ATen/native/cuda/TensorShape.cu

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -422,11 +422,12 @@ static __global__ void chunk_cat_cuda_kernel(
422422
}
423423

424424
bool all_contiguous(TensorList tensors) {
425-
bool contiguous = true;
426425
for (const auto& t : tensors) {
427-
contiguous &= t.is_non_overlapping_and_dense();
426+
if (!t.is_contiguous()) {
427+
return false;
428+
}
428429
}
429-
return contiguous;
430+
return true;
430431
}
431432

432433
// Get leading dimensions before `dim`-th dimension.

aten/src/ATen/native/transformers/attention.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
449449
REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
450450
REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
451451
REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
452-
REGISTER_HPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_meta);
452+
REGISTER_HPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_meta)
453453

454454
int64_t _fused_sdp_choice_meta(
455455
const Tensor& query_,

test/cpp_extensions/open_registration_extension/README.md

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,37 @@
1+
# PyTorch OpenReg
2+
13
This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend from core.
24

35
## How to use
6+
47
Install as standalone with `python setup.py develop` (or install) from this folder.
5-
You can run test via `python test/test_openreg.py`.
8+
You can run the tests via `python {PYTORCH_ROOT_PATH}/test/test_openreg.py`.
69

710
## Design principles
11+
812
For simplicity anything that can be implemented from python is done so.
913
A real implementation will most likely want to call these different APIs from c++ directly.
1014

1115
The current version sends everything back to python and contains enough implementation to run basic model, transfer host/device and printing.
1216

1317
The codebase is split as follows:
14-
- `pytorch_openreg/__init__.py` imports torch to get core state initialized, imports `._aten_impl` to register our aten op implementations to torch, imports `.C` to load our c++ extension that registers more ops, allocator and hooks and finally renames the PrivateUse1 backend and register our python-side module.
15-
- `pytorch_openreg/_aten_impl.py` does two main things. Use the `_register_same_name()` function to register hooks from c++ (like getDevice, getStream, etc) and send them to our device daemon. Define a new `torch.Library` that registers a fallback that will be called whenever a backend kernel for PrivateUse1 is called. It contains the logic to handle all kind of native functions, computing the output metadata, allocating it and only calling into the device daemon to perform computation
16-
- `pytorch_openreg/_device_daemon.py` contains the Allocator (responsible for allocating memory on the device side, as int8 buffers, and recreating nice looking Tensors on the device side to be able to use aten ops to run code there), `run_op` that is the logic running on the device side to perform compute (for simplicity of coverage, we are re-building full blown Tensors here and calling aten ops on them). It also contains the Daemon responsible for the device worker process and sending data back and forth.
17-
- `pytorch_openreg/_meta_parser.py` mainly contain utilities to send objects over the wire from the user process to the device process. The main class there is `OpenRegTensorMeta` that contains all the metadata sent to the device which should be enough for it to populate the output Tensor.
18+
19+
- `pytorch_openreg/__init__.py`
20+
- imports torch to get core state initialized.
21+
- imports `._aten_impl` to register our aten op implementations to torch.
22+
- imports `.C` to load our c++ extension that registers more ops, allocator and hooks.
23+
- renames the PrivateUse1 backend and register our python-side module.
24+
- `pytorch_openreg/_aten_impl.py`
25+
- Defines a new `torch.Library` that registers a fallback that will be called whenever a backend kernel for PrivateUse1 is called. It contains the logic to handle all kinds of native functions: computing the output metadata, allocating it, and only calling into the device daemon to perform computation.
26+
- `pytorch_openreg/_device_daemon.py`
27+
- contains the Allocator (responsible for allocating memory on the device side and host side, as int8 buffers).
28+
- contains `Driver`, which acts as the user-process driver and handles the operations that must be performed on the driver (user-process) side.
29+
- contains `Executor`, which acts as the device-process executor and handles the device-side logic.
30+
- `pytorch_openreg/_meta_parser.py` mainly contains utilities to send objects over the wire from the user process to the device process.
31+
- The main class there is `OpenRegTensorMeta` that contains all the metadata sent to the device which should be enough for it to populate the output Tensor.
1832

1933
## Next steps
2034

21-
Currently, the autograd test is disabled because it's missing the getStream implementation.
2235
The main next step would be to:
23-
- Split the daemon into a proper user-process driver vs device-process executor. The main goal would be to better mimick which information is held on the user-process side and when we're actually communicating with the device. In particular current device or stream should be user-process informations.
24-
- Add Stream/Event system. Most likely by having multiple requests queue that go to the device from the driver.
25-
- Add RNG Generator.
2636

27-
Longer term:
2837
- Replace the current `open_registration_extension.cpp` test in PyTorch CI with this.
29-
- Build this module in the CI environment and enable Device-generic tests on this device.

test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,38 @@
44

55
namespace openreg {
66

7+
using openreg_ptr_t = uint64_t;
8+
79
void set_impl_factory(PyObject* factory);
810
py::function get_method(const char* name);
911

12+
static constexpr char kFreeMethod[] = "free";
13+
static constexpr char kHostFreeMethod[] = "hostFree";
14+
15+
template <const char* name>
16+
static void ReportAndDelete(void* ptr) {
17+
if (!ptr || !Py_IsInitialized()) {
18+
return;
19+
}
20+
21+
py::gil_scoped_acquire acquire;
22+
23+
PyObject *type = nullptr, *value = nullptr, *traceback = nullptr;
24+
// Always stash, this will be a no-op if there is no error
25+
PyErr_Fetch(&type, &value, &traceback);
26+
27+
TORCH_CHECK(
28+
get_method(name)(reinterpret_cast<openreg_ptr_t>(ptr)).cast<bool>(),
29+
"Failed to free memory pointer at ",
30+
ptr);
31+
32+
// If that user code raised an error, just print it without raising it
33+
if (PyErr_Occurred()) {
34+
PyErr_Print();
35+
}
36+
37+
// Restore the original error
38+
PyErr_Restore(type, value, traceback);
39+
}
40+
1041
} // namespace openreg

test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp

Lines changed: 27 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,17 @@
33
#include <ATen/CPUGeneratorImpl.h>
44
#include <ATen/core/GeneratorForPrivateuseone.h>
55
#include <ATen/detail/PrivateUse1HooksInterface.h>
6+
7+
#include <c10/core/Allocator.h>
68
#include <c10/core/Device.h>
79
#include <c10/core/impl/DeviceGuardImplInterface.h>
8-
#include <c10/util/CallOnce.h>
9-
10-
#include <iostream>
1110

1211
namespace openreg {
13-
1412
namespace {
13+
1514
// Python factory function where real implementations can be found
1615
PyObject* py_factory;
1716

18-
using host_ptr_t = uint64_t;
19-
2017
struct HostAllocator final : at::Allocator {
2118
HostAllocator() = default;
2219

@@ -25,35 +22,25 @@ struct HostAllocator final : at::Allocator {
2522
void* data = nullptr;
2623
if (nbytes > 0) {
2724
data = reinterpret_cast<void*>(
28-
get_method("hostMalloc")(nbytes).cast<host_ptr_t>());
25+
get_method("hostMalloc")(nbytes).cast<openreg_ptr_t>());
2926
TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on host.");
3027
}
31-
return {data, data, &ReportAndDelete, at::Device(at::kCPU)};
32-
}
33-
34-
static void ReportAndDelete(void* ptr) {
35-
if (!ptr) {
36-
return;
37-
}
38-
py::gil_scoped_acquire acquire;
39-
TORCH_CHECK(
40-
get_method("hostFree")(reinterpret_cast<host_ptr_t>(ptr)).cast<bool>(),
41-
"Failed to free memory pointer at ",
42-
ptr);
28+
return {data, data, &ReportAndDelete<kHostFreeMethod>, at::Device(at::kCPU)};
4329
}
4430

4531
at::DeleterFnPtr raw_deleter() const override {
46-
return &ReportAndDelete;
32+
return &ReportAndDelete<kHostFreeMethod>;
4733
}
4834

4935
void copy_data(void* dest, const void* src, std::size_t count) const final {
5036
py::gil_scoped_acquire acquire;
5137
get_method("hostCopyData")(
52-
reinterpret_cast<host_ptr_t>(dest),
53-
reinterpret_cast<host_ptr_t>(src),
38+
reinterpret_cast<openreg_ptr_t>(dest),
39+
reinterpret_cast<openreg_ptr_t>(src),
5440
count);
5541
}
5642
};
43+
5744
static HostAllocator global_host_alloc;
5845

5946
static c10::DeviceIndex device_count() {
@@ -82,20 +69,8 @@ static at::Generator make_openreg_generator(c10::DeviceIndex device_index) {
8269
// Default, global generators, one per device.
8370
static std::vector<at::Generator> default_generators;
8471

85-
static void initGenerators() {
86-
auto deivce_nums = device_count();
87-
default_generators.resize(deivce_nums);
88-
for (auto i = 0; i < deivce_nums; i++) {
89-
default_generators[i] = make_openreg_generator(i);
90-
default_generators[i].seed();
91-
}
92-
}
93-
94-
// C++ hooks implementation
95-
struct OpenRegHooksArgs : public at::PrivateUse1HooksArgs {};
96-
9772
struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
98-
OpenRegHooksInterface(OpenRegHooksArgs) {};
73+
OpenRegHooksInterface() {};
9974
~OpenRegHooksInterface() override = default;
10075

10176
bool hasPrimaryContext(c10::DeviceIndex device_index) const override {
@@ -109,14 +84,22 @@ struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
10984

11085
bool isPinnedPtr(const void* data) const override {
11186
py::gil_scoped_acquire acquire;
112-
return get_method("isPinnedPtr")(reinterpret_cast<host_ptr_t>(data))
87+
return get_method("isPinnedPtr")(reinterpret_cast<openreg_ptr_t>(data))
11388
.cast<bool>();
11489
}
11590

11691
const at::Generator& getDefaultGenerator(
11792
c10::DeviceIndex device_index) const override {
118-
static c10::once_flag generator_init_flag;
119-
c10::call_once(generator_init_flag, initGenerators);
93+
static bool flag [[maybe_unused]] = []() {
94+
auto deivce_nums = device_count();
95+
default_generators.resize(deivce_nums);
96+
for (auto i = 0; i < deivce_nums; i++) {
97+
default_generators[i] = make_openreg_generator(i);
98+
default_generators[i].seed();
99+
}
100+
return true;
101+
}();
102+
120103
c10::DeviceIndex idx = device_index;
121104
if (idx == -1) {
122105
idx = current_device_idx();
@@ -131,27 +114,11 @@ struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
131114
}
132115
};
133116

134-
int register_hook() {
135-
at::RegisterPrivateUse1HooksInterface(
136-
new OpenRegHooksInterface(OpenRegHooksArgs{}));
137-
return 0;
138-
}
139-
int temp_register_hook = register_hook();
140-
141-
TORCH_DECLARE_REGISTRY(
142-
PrivateUse1HooksRegistry,
143-
OpenRegHooksInterface,
144-
OpenRegHooksArgs);
145-
C10_DEFINE_REGISTRY(
146-
PrivateUse1HooksRegistry,
147-
OpenRegHooksInterface,
148-
OpenRegHooksArgs);
149-
// Using Create function to get PrivateUse1HooksInterface point from
150-
// PrivateUse1HooksRegistry class.
151-
C10_REGISTER_TYPED_CLASS(
152-
PrivateUse1HooksRegistry,
153-
"OpenRegHooks",
154-
OpenRegHooksInterface);
117+
static bool register_hook_flag [[maybe_unused]] = []() {
118+
at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface());
119+
120+
return true;
121+
}();
155122

156123
// Device guard registration
157124
struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
@@ -379,4 +346,5 @@ py::function get_method(const char* name) {
379346
auto factory = py::cast<py::function>(py_factory);
380347
return factory(name);
381348
}
349+
382350
} // namespace openreg

0 commit comments

Comments
 (0)