1 #include <ATen/Config.h>
2 #include <torch/csrc/inductor/aoti_torch/mkldnn_tensor.h>
3
4 #if AT_MKLDNN_ENABLED()
5 #include <ATen/native/mkldnn/MKLDNNCommon.h>
6 #include <ideep.hpp>
7 #endif
8
9 namespace torch::aot_inductor {
10
11 #if AT_MKLDNN_ENABLED()
12
data_ptr_from_mkldnn(at::Tensor * mkldnn_tensor)13 void* data_ptr_from_mkldnn(at::Tensor* mkldnn_tensor) {
14 return reinterpret_cast<void*>(
15 at::native::data_ptr_from_mkldnn(*mkldnn_tensor));
16 }
17
mkldnn_tensor_from_data_ptr(void * data_ptr,at::IntArrayRef dims,at::ScalarType dtype,at::Device device,const uint8_t * opaque_metadata,int64_t opaque_metadata_size)18 at::Tensor mkldnn_tensor_from_data_ptr(
19 void* data_ptr,
20 at::IntArrayRef dims,
21 at::ScalarType dtype,
22 at::Device device,
23 const uint8_t* opaque_metadata,
24 int64_t opaque_metadata_size) {
25 return at::native::mkldnn_tensor_from_data_ptr(
26 data_ptr, dims, dtype, device, opaque_metadata, opaque_metadata_size);
27 }
28
29 #else
30
31 void* data_ptr_from_mkldnn(at::Tensor* mkldnn_tensor) {
32 TORCH_CHECK(false, "MKL-DNN build is disabled");
33 }
34
35 at::Tensor mkldnn_tensor_from_data_ptr(
36 void* data_ptr,
37 at::IntArrayRef dims,
38 at::ScalarType dtype,
39 at::Device device,
40 const uint8_t* opaque_metadata,
41 int64_t opaque_metadata_size) {
42 TORCH_CHECK(false, "MKL-DNN build is disabled");
43 }
44
45 #endif
46
47 } // namespace torch::aot_inductor
48