Skip to content

Commit e0408b9

Browse files
mikaylagawarecki authored and pytorchmergebot committed
Merge torch_from_blob and torch_from_blob_v2 into a single shim (#177048)
Mirror the changes in #176440. Mirror the changes on the release/2.11 branch. Keep just the _v2 signature under the name torch_from_blob, and have the C++ wrapper in ops.h adapt simple function-pointer deleters via a trampoline using if constexpr, avoiding heap allocation for that case. Authored with Claude. Pull Request resolved: #177048. Approved by: https://github.com/malfet
1 parent a12a78b commit e0408b9

8 files changed

Lines changed: 113 additions & 240 deletions

File tree

test/cpp_extensions/libtorch_agn_2_12_extension/csrc/my_from_blob_with_lambda_deleter.cpp renamed to test/cpp_extensions/libtorch_agn_2_11_extension/csrc/my_from_blob_with_lambda_deleter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ static int64_t g_lambda_deleter_call_count = 0;
1414

1515
// Wrapper for from_blob with a capturing-lambda deleter.
1616
// The lambda captures a pointer to the global counter and increments it,
17-
// which exercises the torch_from_blob_v2 code path (deleter + context).
17+
// which exercises the capturing-lambda code path in torch_from_blob.
1818
Tensor my_from_blob_with_lambda_deleter(
1919
int64_t data_ptr,
2020
torch::headeronly::HeaderOnlyArrayRef<int64_t> sizes,
@@ -60,7 +60,7 @@ STABLE_TORCH_LIBRARY_IMPL(
6060
#ifdef LAE_USE_CUDA
6161

6262
// Same as my_from_blob_with_cuda_deleter (from 2.11) but uses a non-capturing
63-
// lambda deleter, exercising the from_blob_v2 code path.
63+
// lambda deleter.
6464
Tensor my_from_blob_with_cuda_lambda_deleter(
6565
int64_t numel,
6666
torch::stable::Device device) {

test/cpp_extensions/libtorch_agn_2_11_extension/libtorch_agn_2_11/ops.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,59 @@ def my_from_blob_with_cuda_deleter(numel: int, device) -> Tensor:
5757
)
5858

5959

60+
def my_from_blob_with_lambda_deleter(data_ptr, sizes, strides, device, dtype) -> Tensor:
61+
"""
62+
Creates a Tensor from existing memory with a capturing-lambda deleter.
63+
64+
The deleter is a capturing lambda that updates a global call count,
65+
exercising the capturing-lambda code path in torch_from_blob.
66+
67+
Args:
68+
data_ptr: int - pointer to the data buffer
69+
sizes: tuple[int] - size of the tensor
70+
strides: tuple[int] - strides of the tensor
71+
device: Device - device on which the tensor resides
72+
dtype: ScalarType - data type of the tensor
73+
74+
Returns: Tensor - tensor wrapping the existing memory
75+
"""
76+
return torch.ops.libtorch_agn_2_11.my_from_blob_with_lambda_deleter.default(
77+
data_ptr, sizes, strides, device, dtype
78+
)
79+
80+
81+
def get_lambda_deleter_call_count() -> int:
82+
"""
83+
Returns the number of times the lambda test deleter has been called.
84+
"""
85+
return torch.ops.libtorch_agn_2_11.get_lambda_deleter_call_count.default()
86+
87+
88+
def reset_lambda_deleter_call_count() -> None:
89+
"""
90+
Resets the lambda deleter call counter to zero.
91+
"""
92+
torch.ops.libtorch_agn_2_11.reset_lambda_deleter_call_count.default()
93+
94+
95+
def my_from_blob_with_cuda_lambda_deleter(numel: int, device) -> Tensor:
96+
"""
97+
Creates a CUDA tensor that owns its memory via cudaMalloc, using a lambda deleter.
98+
99+
Similar to my_from_blob_with_cuda_deleter but uses the capturing-lambda
100+
code path in torch_from_blob.
101+
102+
Args:
103+
numel: int - number of elements in the tensor
104+
device: Device - CUDA device
105+
106+
Returns: Tensor - a 1D float32 tensor of zeros
107+
"""
108+
return torch.ops.libtorch_agn_2_11.my_from_blob_with_cuda_lambda_deleter.default(
109+
numel, device
110+
)
111+
112+
60113
# =============================================================================
61114
# Proxy for inherited ops (from libtorch_agn_2_9 and libtorch_agn_2_10 csrc/)
62115
#

test/cpp_extensions/libtorch_agn_2_12_extension/libtorch_agn_2_12/ops.py

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,4 @@
11
import torch
2-
from torch import Tensor
3-
4-
5-
def my_from_blob_with_lambda_deleter(data_ptr, sizes, strides, device, dtype) -> Tensor:
6-
"""
7-
Creates a Tensor from existing memory with a capturing-lambda deleter.
8-
9-
The lambda deleter captures a pointer to a global counter and increments it,
10-
exercising the torch_from_blob_v2 code path (deleter + context).
11-
12-
Args:
13-
data_ptr: int - pointer to the data buffer
14-
sizes: tuple[int] - size of the tensor
15-
strides: tuple[int] - strides of the tensor
16-
device: Device - device on which the tensor resides
17-
dtype: ScalarType - data type of the tensor
18-
19-
Returns: Tensor - tensor wrapping the existing memory
20-
"""
21-
return torch.ops.libtorch_agn_2_12.my_from_blob_with_lambda_deleter.default(
22-
data_ptr, sizes, strides, device, dtype
23-
)
24-
25-
26-
def get_lambda_deleter_call_count() -> int:
27-
"""
28-
Returns the number of times the lambda deleter has been called.
29-
"""
30-
return torch.ops.libtorch_agn_2_12.get_lambda_deleter_call_count.default()
31-
32-
33-
def reset_lambda_deleter_call_count() -> None:
34-
"""
35-
Resets the lambda deleter call counter to zero.
36-
"""
37-
torch.ops.libtorch_agn_2_12.reset_lambda_deleter_call_count.default()
38-
39-
40-
def my_from_blob_with_cuda_lambda_deleter(numel: int, device) -> Tensor:
41-
"""
42-
Creates a CUDA tensor that owns its memory via cudaMalloc with a lambda deleter.
43-
44-
The tensor's memory is allocated with cudaMalloc and will be freed
45-
with cudaFree when the tensor is destroyed (via from_blob's lambda deleter).
46-
47-
Args:
48-
numel: int - number of elements in the tensor
49-
device: Device - CUDA device
50-
51-
Returns: Tensor - a 1D float32 tensor of zeros
52-
"""
53-
return torch.ops.libtorch_agn_2_12.my_from_blob_with_cuda_lambda_deleter.default(
54-
numel, device
55-
)
562

573

584
# =============================================================================

test/cpp_extensions/test_libtorch_agnostic.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1798,11 +1798,11 @@ def inner():
17981798
curr_mem = torch.cuda.memory_allocated(device)
17991799
self.assertEqual(curr_mem, init_mem)
18001800

1801-
@skipIfTorchVersionLessThan(2, 12)
1801+
@skipIfTorchVersionLessThan(2, 11)
18021802
@skipIfTorchDynamo("no data pointer defined for FakeTensor, FunctionalTensor")
18031803
def test_my_from_blob_with_lambda_deleter(self, device):
1804-
"""Test for from_blob with capturing-lambda deleter (2.12 feature)."""
1805-
import libtorch_agn_2_12 as libtorch_agnostic
1804+
"""Test for from_blob with capturing-lambda deleter (2.11 feature)."""
1805+
import libtorch_agn_2_11 as libtorch_agnostic
18061806

18071807
from_blob_fn = libtorch_agnostic.ops.my_from_blob_with_lambda_deleter
18081808
get_count = libtorch_agnostic.ops.get_lambda_deleter_call_count
@@ -1872,10 +1872,10 @@ def test_my_from_blob_with_cuda_deleter_no_leak(self, device):
18721872
self.assertEqual(curr_mem, init_mem)
18731873

18741874
@onlyCUDA
1875-
@skipIfTorchVersionLessThan(2, 12)
1875+
@skipIfTorchVersionLessThan(2, 11)
18761876
def test_my_from_blob_with_cuda_lambda_deleter_no_leak(self, device):
18771877
"""Test that from_blob lambda deleter properly frees cudaMalloc'd memory."""
1878-
import libtorch_agn_2_12 as libtorch_agnostic
1878+
import libtorch_agn_2_11 as libtorch_agnostic
18791879

18801880
from_blob_fn = libtorch_agnostic.ops.my_from_blob_with_cuda_lambda_deleter
18811881

torch/csrc/shim_common.cpp

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -655,49 +655,6 @@ TORCH_DTYPE_IMPL(float4_e2m1fn_x2, Float4_e2m1fn_x2)
655655
#undef TORCH_DTYPE_IMPL
656656

657657
AOTI_TORCH_EXPORT AOTITorchError torch_from_blob(
658-
void* data,
659-
int64_t ndim,
660-
const int64_t* sizes_ptr,
661-
const int64_t* strides_ptr,
662-
int64_t storage_offset,
663-
int32_t dtype,
664-
int32_t device_type,
665-
int32_t device_index,
666-
AtenTensorHandle* ret_new_tensor,
667-
int32_t layout,
668-
const uint8_t* opaque_metadata,
669-
int64_t opaque_metadata_size,
670-
void (*deleter)(void*)) {
671-
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
672-
c10::IntArrayRef sizes(sizes_ptr, ndim);
673-
c10::IntArrayRef strides(strides_ptr, ndim);
674-
c10::Device device(static_cast<c10::DeviceType>(device_type), device_index);
675-
c10::TensorOptions options = c10::TensorOptions().device(device).dtype(
676-
static_cast<c10::ScalarType>(dtype));
677-
at::Tensor tensor;
678-
if (data != nullptr) {
679-
if (deleter != nullptr) {
680-
tensor = at::for_blob(data, sizes)
681-
.strides(strides)
682-
.storage_offset(storage_offset)
683-
.deleter(deleter)
684-
.options(options)
685-
.make_tensor();
686-
} else {
687-
tensor = at::for_blob(data, sizes)
688-
.strides(strides)
689-
.storage_offset(storage_offset)
690-
.options(options)
691-
.make_tensor();
692-
}
693-
} else {
694-
tensor = at::empty_strided(sizes, strides, options);
695-
}
696-
*ret_new_tensor = torch::aot_inductor::new_tensor_handle(std::move(tensor));
697-
});
698-
}
699-
700-
AOTI_TORCH_EXPORT AOTITorchError torch_from_blob_v2(
701658
void* data,
702659
int64_t ndim,
703660
const int64_t* sizes_ptr,
@@ -721,8 +678,6 @@ AOTI_TORCH_EXPORT AOTITorchError torch_from_blob_v2(
721678
at::Tensor tensor;
722679
if (data != nullptr) {
723680
if (deleter_callback != nullptr) {
724-
// Combine the two-arg C callback and its context into a single-arg
725-
// C++ callable that at::for_blob().deleter() expects.
726681
auto wrapped_deleter = [deleter_callback, deleter_ctx](void* data) {
727682
deleter_callback(data, deleter_ctx);
728683
};

torch/csrc/stable/c/shim.h

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,9 @@ AOTI_TORCH_EXPORT int32_t torch_dtype_float8_e8m0fnu();
165165
AOTI_TORCH_EXPORT int32_t torch_dtype_float4_e2m1fn_x2();
166166

167167
// Creates a tensor from an existing data blob with an optional deleter.
168-
// The deleter is called with the data pointer when the tensor's storage
169-
// is deallocated.
168+
// The deleter receives both the data pointer and a caller-supplied context
169+
// pointer, which allows passing capturing lambdas across the C ABI boundary
170+
// by heap-allocating the callable and passing it as deleter_ctx.
170171
AOTI_TORCH_EXPORT AOTITorchError torch_from_blob(
171172
void* data,
172173
int64_t ndim,
@@ -180,35 +181,10 @@ AOTI_TORCH_EXPORT AOTITorchError torch_from_blob(
180181
int32_t layout,
181182
const uint8_t* opaque_metadata,
182183
int64_t opaque_metadata_size,
183-
void (*deleter)(void*));
184-
185-
#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_11_0
186-
187-
/**
188-
* The beginning of all shims added in 2.12.0 onwards.
189-
*/
190-
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_12_0
191-
192-
// Like torch_from_blob, but accepts a deleter with a context pointer.
193-
// This allows passing capturing lambdas across the C ABI boundary by
194-
// heap-allocating the callable and passing it as deleter_ctx.
195-
AOTI_TORCH_EXPORT AOTITorchError torch_from_blob_v2(
196-
void* data,
197-
int64_t ndim,
198-
const int64_t* sizes_ptr,
199-
const int64_t* strides_ptr,
200-
int64_t storage_offset,
201-
int32_t dtype,
202-
int32_t device_type,
203-
int32_t device_index,
204-
AtenTensorHandle* ret,
205-
int32_t layout,
206-
const uint8_t* opaque_metadata,
207-
int64_t opaque_metadata_size,
208184
void (*deleter)(void* data, void* ctx),
209185
void* deleter_ctx);
210186

211-
#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_12_0
187+
#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_11_0
212188

213189
#ifdef __cplusplus
214190
} // extern "C"

torch/csrc/stable/c/shim_function_versions.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,3 @@ torch_parse_device_string: TORCH_VERSION_2_10_0
2626
torch_dtype_float4_e2m1fn_x2: TORCH_VERSION_2_11_0
2727
torch_dtype_float8_e8m0fnu: TORCH_VERSION_2_11_0
2828
torch_from_blob: TORCH_VERSION_2_11_0
29-
torch_from_blob_v2: TORCH_VERSION_2_12_0

0 commit comments

Comments (0)