Skip to content

Commit 234c288

Browse files
Dynamic link shim (#15663)
The symbols need to be explicitly exported on windows. It would also be convenient for the backend to not be dynamically linked so split the shim out. --------- Co-authored-by: roman-janik-nxp <roman.janik@nxp.com>
1 parent 6e2a46b commit 234c288

File tree

11 files changed

+325
-98
lines changed

11 files changed

+325
-98
lines changed

CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -591,8 +591,9 @@ endif()
591591
if(EXECUTORCH_BUILD_CUDA)
592592
# Build CUDA-specific AOTI functionality
593593
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cuda)
594-
# Add aoti_cuda to backends - it already depends on aoti_common
595-
list(APPEND _executorch_backends aoti_cuda)
594+
# Add aoti_cuda_backend to backends - it transitively includes aoti_cuda_shims
595+
# and cuda_platform
596+
list(APPEND _executorch_backends aoti_cuda_backend)
596597
endif()
597598

598599
if(EXECUTORCH_BUILD_METAL)

backends/aoti/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ target_compile_options(
3838
PUBLIC $<$<CXX_COMPILER_ID:MSVC>:/EHsc /GR>
3939
$<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-fexceptions -frtti -fPIC>
4040
)
41+
target_compile_definitions(
42+
aoti_common PRIVATE $<$<PLATFORM_ID:Windows>:EXPORT_AOTI_FUNCTIONS>
43+
)
4144
# Ensure symbols are exported properly
4245
if(APPLE)
4346
target_link_options(aoti_common PUBLIC -Wl,-export_dynamic)

backends/aoti/common_shims.cpp

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ namespace aoti {
1616

1717
namespace internal {
1818
// Global storage for tensor metadata
19-
std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_sizes;
20-
std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_strides;
19+
AOTI_SHIM_EXPORT std::unordered_map<Tensor*, std::vector<int64_t>>
20+
tensor_to_sizes;
21+
AOTI_SHIM_EXPORT std::unordered_map<Tensor*, std::vector<int64_t>>
22+
tensor_to_strides;
2123
} // namespace internal
2224

2325
extern "C" {
@@ -204,6 +206,69 @@ void cleanup_tensor_metadata() {
204206
internal::tensor_to_strides.clear();
205207
}
206208

209+
AOTI_SHIM_EXPORT void aoti_torch_warn(
210+
const char* func,
211+
const char* file,
212+
uint32_t line,
213+
const char* msg) {
214+
ET_LOG(Error, "[%s:%u] %s: %s", file, line, func, msg);
215+
}
216+
217+
AOTI_SHIM_EXPORT AOTITorchError
218+
aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size) {
219+
(void)tensor;
220+
(void)ret_size;
221+
throw std::runtime_error("Not implemented");
222+
return Error::Internal;
223+
}
224+
225+
AOTI_SHIM_EXPORT AOTITorchError
226+
aoti_torch_clone_preserve_strides(Tensor* self, Tensor** ret_new_tensor) {
227+
(void)self;
228+
(void)ret_new_tensor;
229+
throw std::runtime_error("Not implemented");
230+
return Error::Internal;
231+
}
232+
233+
AOTI_SHIM_EXPORT AOTITorchError
234+
aoti_torch_clone(Tensor* self, Tensor** ret_new_tensor) {
235+
(void)self;
236+
(void)ret_new_tensor;
237+
throw std::runtime_error("Not implemented");
238+
return Error::Internal;
239+
}
240+
241+
AOTI_SHIM_EXPORT AOTITorchError
242+
aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle) {
243+
(void)orig_handle;
244+
(void)new_handle;
245+
throw std::runtime_error("Not implemented");
246+
return Error::Internal;
247+
}
248+
249+
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
250+
void* data_ptr,
251+
int64_t ndim,
252+
const int64_t* sizes,
253+
const int64_t* strides,
254+
int64_t storage_offset,
255+
int32_t dtype,
256+
int32_t device_type,
257+
int32_t device_index,
258+
Tensor** ret_new_tensor) {
259+
(void)data_ptr;
260+
(void)ndim;
261+
(void)sizes;
262+
(void)strides;
263+
(void)storage_offset;
264+
(void)dtype;
265+
(void)device_type;
266+
(void)device_index;
267+
(void)ret_new_tensor;
268+
throw std::runtime_error("Not implemented");
269+
return Error::Internal;
270+
}
271+
207272
} // extern "C"
208273

209274
} // namespace aoti

backends/aoti/common_shims.h

Lines changed: 62 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#pragma once
1010

11+
#include <executorch/backends/aoti/export.h>
1112
#include <executorch/backends/aoti/utils.h>
1213
#include <executorch/runtime/core/error.h>
1314
#include <executorch/runtime/core/exec_aten/exec_aten.h>
@@ -23,57 +24,89 @@ namespace aoti {
2324
using executorch::runtime::Error;
2425
using executorch::runtime::etensor::Tensor;
2526

27+
// Global storage for tensor metadata
28+
extern std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_sizes;
29+
extern std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_strides;
30+
2631
extern "C" {
2732

2833
// Common AOTI type aliases
2934
using AOTIRuntimeError = Error;
3035
using AOTITorchError = Error;
3136

32-
// Global storage for tensor metadata
33-
extern std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_sizes;
34-
extern std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_strides;
35-
3637
// Attribute-related operations (memory-irrelevant)
37-
AOTITorchError aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr);
38+
AOTI_SHIM_EXPORT AOTITorchError
39+
aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr);
3840

39-
AOTITorchError aoti_torch_get_storage_offset(
40-
Tensor* tensor,
41-
int64_t* ret_storage_offset);
41+
AOTI_SHIM_EXPORT AOTITorchError
42+
aoti_torch_get_storage_offset(Tensor* tensor, int64_t* ret_storage_offset);
4243

43-
AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides);
44+
AOTI_SHIM_EXPORT AOTITorchError
45+
aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides);
4446

45-
AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype);
47+
AOTI_SHIM_EXPORT AOTITorchError
48+
aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype);
4649

47-
AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes);
50+
AOTI_SHIM_EXPORT AOTITorchError
51+
aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes);
4852

49-
AOTITorchError aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size);
53+
AOTI_SHIM_EXPORT AOTITorchError
54+
aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size);
5055

51-
AOTITorchError aoti_torch_get_device_index(
52-
Tensor* tensor,
53-
int32_t* ret_device_index);
56+
AOTI_SHIM_EXPORT AOTITorchError
57+
aoti_torch_get_device_index(Tensor* tensor, int32_t* ret_device_index);
5458

55-
AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);
59+
AOTI_SHIM_EXPORT AOTITorchError
60+
aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);
5661

5762
// Utility functions for device and layout information
58-
int32_t aoti_torch_device_type_cpu();
59-
int32_t aoti_torch_layout_strided();
60-
int32_t aoti_torch_dtype_float32();
61-
int32_t aoti_torch_dtype_bfloat16();
62-
int32_t aoti_torch_dtype_int8();
63-
int32_t aoti_torch_dtype_int16();
64-
int32_t aoti_torch_dtype_int32();
65-
int32_t aoti_torch_dtype_int64();
66-
int32_t aoti_torch_dtype_bool();
63+
AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu();
64+
AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided();
65+
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32();
66+
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16();
67+
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
68+
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
69+
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32();
70+
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int64();
6771

6872
// Dtype utility function needed by Metal backend
69-
size_t aoti_torch_dtype_element_size(int32_t dtype);
73+
AOTI_SHIM_EXPORT size_t aoti_torch_dtype_element_size(int32_t dtype);
7074

7175
// Autograd mode functions
72-
int32_t aoti_torch_grad_mode_is_enabled();
73-
void aoti_torch_grad_mode_set_enabled(bool enabled);
76+
AOTI_SHIM_EXPORT int32_t aoti_torch_grad_mode_is_enabled();
77+
AOTI_SHIM_EXPORT void aoti_torch_grad_mode_set_enabled(bool enabled);
7478

7579
// Cleanup functions for clearing global state
76-
void cleanup_tensor_metadata();
80+
AOTI_SHIM_EXPORT void cleanup_tensor_metadata();
81+
82+
AOTI_SHIM_EXPORT void aoti_torch_warn(
83+
const char* func,
84+
const char* file,
85+
uint32_t line,
86+
const char* msg);
87+
88+
AOTI_SHIM_EXPORT AOTITorchError
89+
aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size);
90+
91+
AOTI_SHIM_EXPORT AOTITorchError
92+
aoti_torch_clone_preserve_strides(Tensor* self, Tensor** ret_new_tensor);
93+
94+
AOTI_SHIM_EXPORT AOTITorchError
95+
aoti_torch_clone(Tensor* self, Tensor** ret_new_tensor);
96+
97+
AOTI_SHIM_EXPORT AOTITorchError
98+
aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle);
99+
100+
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
101+
void* data_ptr,
102+
int64_t ndim,
103+
const int64_t* sizes,
104+
const int64_t* strides,
105+
int64_t storage_offset,
106+
int32_t dtype,
107+
int32_t device_type,
108+
int32_t device_index,
109+
Tensor** ret_new_tensor);
77110

78111
} // extern "C"
79112

backends/aoti/export.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
// Define export macro for Windows DLL
12+
// When building the aoti_cuda library, EXPORT_AOTI_FUNCTIONS is defined by
13+
// CMake, which causes this macro to export symbols using __declspec(dllexport).
14+
// When consuming the library, the macro imports symbols using
15+
// __declspec(dllimport). On non-Windows platforms, the macro is empty and has
16+
// no effect.
17+
#ifdef _WIN32
18+
#ifdef EXPORT_AOTI_FUNCTIONS
19+
#define AOTI_SHIM_EXPORT __declspec(dllexport)
20+
#else
21+
#define AOTI_SHIM_EXPORT __declspec(dllimport)
22+
#endif
23+
#else
24+
#define AOTI_SHIM_EXPORT
25+
#endif

0 commit comments

Comments
 (0)