88
99#pragma once
1010
11+ #include < executorch/backends/aoti/export.h>
1112#include < executorch/backends/aoti/utils.h>
1213#include < executorch/runtime/core/error.h>
1314#include < executorch/runtime/core/exec_aten/exec_aten.h>
@@ -23,57 +24,89 @@ namespace aoti {
2324using executorch::runtime::Error;
2425using executorch::runtime::etensor::Tensor;
2526
27+ // Global storage for tensor metadata
28+ extern std::unordered_map<Tensor*, std::vector<int64_t >> tensor_to_sizes;
29+ extern std::unordered_map<Tensor*, std::vector<int64_t >> tensor_to_strides;
30+
2631extern " C" {
2732
2833// Common AOTI type aliases
2934using AOTIRuntimeError = Error;
3035using AOTITorchError = Error;
3136
32- // Global storage for tensor metadata
33- extern std::unordered_map<Tensor*, std::vector<int64_t >> tensor_to_sizes;
34- extern std::unordered_map<Tensor*, std::vector<int64_t >> tensor_to_strides;
35-
3637// Attribute-related operations (memory-irrelevant)
37- AOTITorchError aoti_torch_get_data_ptr (Tensor* tensor, void ** ret_data_ptr);
38+ AOTI_SHIM_EXPORT AOTITorchError
39+ aoti_torch_get_data_ptr (Tensor* tensor, void ** ret_data_ptr);
3840
39- AOTITorchError aoti_torch_get_storage_offset (
40- Tensor* tensor,
41- int64_t * ret_storage_offset);
41+ AOTI_SHIM_EXPORT AOTITorchError
42+ aoti_torch_get_storage_offset (Tensor* tensor, int64_t * ret_storage_offset);
4243
43- AOTITorchError aoti_torch_get_strides (Tensor* tensor, int64_t ** ret_strides);
44+ AOTI_SHIM_EXPORT AOTITorchError
45+ aoti_torch_get_strides (Tensor* tensor, int64_t ** ret_strides);
4446
45- AOTITorchError aoti_torch_get_dtype (Tensor* tensor, int32_t * ret_dtype);
47+ AOTI_SHIM_EXPORT AOTITorchError
48+ aoti_torch_get_dtype (Tensor* tensor, int32_t * ret_dtype);
4649
47- AOTITorchError aoti_torch_get_sizes (Tensor* tensor, int64_t ** ret_sizes);
50+ AOTI_SHIM_EXPORT AOTITorchError
51+ aoti_torch_get_sizes (Tensor* tensor, int64_t ** ret_sizes);
4852
49- AOTITorchError aoti_torch_get_storage_size (Tensor* tensor, int64_t * ret_size);
53+ AOTI_SHIM_EXPORT AOTITorchError
54+ aoti_torch_get_storage_size (Tensor* tensor, int64_t * ret_size);
5055
51- AOTITorchError aoti_torch_get_device_index (
52- Tensor* tensor,
53- int32_t * ret_device_index);
56+ AOTI_SHIM_EXPORT AOTITorchError
57+ aoti_torch_get_device_index (Tensor* tensor, int32_t * ret_device_index);
5458
55- AOTITorchError aoti_torch_get_dim (Tensor* tensor, int64_t * ret_dim);
59+ AOTI_SHIM_EXPORT AOTITorchError
60+ aoti_torch_get_dim (Tensor* tensor, int64_t * ret_dim);
5661
5762// Utility functions for device and layout information
58- int32_t aoti_torch_device_type_cpu ();
59- int32_t aoti_torch_layout_strided ();
60- int32_t aoti_torch_dtype_float32 ();
61- int32_t aoti_torch_dtype_bfloat16 ();
62- int32_t aoti_torch_dtype_int8 ();
63- int32_t aoti_torch_dtype_int16 ();
64- int32_t aoti_torch_dtype_int32 ();
65- int32_t aoti_torch_dtype_int64 ();
66- int32_t aoti_torch_dtype_bool ();
63+ AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu ();
64+ AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided ();
65+ AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32 ();
66+ AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16 ();
67+ AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8 ();
68+ AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16 ();
69+ AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32 ();
70+ AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int64 ();
6771
6872// Dtype utility function needed by Metal backend
69- size_t aoti_torch_dtype_element_size (int32_t dtype);
73+ AOTI_SHIM_EXPORT size_t aoti_torch_dtype_element_size (int32_t dtype);
7074
7175// Autograd mode functions
72- int32_t aoti_torch_grad_mode_is_enabled ();
73- void aoti_torch_grad_mode_set_enabled (bool enabled);
76+ AOTI_SHIM_EXPORT int32_t aoti_torch_grad_mode_is_enabled ();
77+ AOTI_SHIM_EXPORT void aoti_torch_grad_mode_set_enabled (bool enabled);
7478
7579// Cleanup functions for clearing global state
76- void cleanup_tensor_metadata ();
80+ AOTI_SHIM_EXPORT void cleanup_tensor_metadata ();
81+
82+ AOTI_SHIM_EXPORT void aoti_torch_warn (
83+ const char * func,
84+ const char * file,
85+ uint32_t line,
86+ const char * msg);
87+
88+ AOTI_SHIM_EXPORT AOTITorchError
89+ aoti_torch_get_storage_size (Tensor* tensor, int64_t * ret_size);
90+
91+ AOTI_SHIM_EXPORT AOTITorchError
92+ aoti_torch_clone_preserve_strides (Tensor* self, Tensor** ret_new_tensor);
93+
94+ AOTI_SHIM_EXPORT AOTITorchError
95+ aoti_torch_clone (Tensor* self, Tensor** ret_new_tensor);
96+
97+ AOTI_SHIM_EXPORT AOTITorchError
98+ aoti_torch_new_tensor_handle (Tensor* orig_handle, Tensor** new_handle);
99+
100+ AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob (
101+ void * data_ptr,
102+ int64_t ndim,
103+ const int64_t * sizes,
104+ const int64_t * strides,
105+ int64_t storage_offset,
106+ int32_t dtype,
107+ int32_t device_type,
108+ int32_t device_index,
109+ Tensor** ret_new_tensor);
77110
78111} // extern "C"
79112
0 commit comments