xref: /aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Config.h>
2 #include <torch/csrc/inductor/aoti_torch/mkldnn_tensor.h>
3 
4 #if AT_MKLDNN_ENABLED()
5 #include <ATen/native/mkldnn/MKLDNNCommon.h>
6 #include <ideep.hpp>
7 #endif
8 
9 namespace torch::aot_inductor {
10 
11 #if AT_MKLDNN_ENABLED()
12 
data_ptr_from_mkldnn(at::Tensor * mkldnn_tensor)13 void* data_ptr_from_mkldnn(at::Tensor* mkldnn_tensor) {
14   return reinterpret_cast<void*>(
15       at::native::data_ptr_from_mkldnn(*mkldnn_tensor));
16 }
17 
mkldnn_tensor_from_data_ptr(void * data_ptr,at::IntArrayRef dims,at::ScalarType dtype,at::Device device,const uint8_t * opaque_metadata,int64_t opaque_metadata_size)18 at::Tensor mkldnn_tensor_from_data_ptr(
19     void* data_ptr,
20     at::IntArrayRef dims,
21     at::ScalarType dtype,
22     at::Device device,
23     const uint8_t* opaque_metadata,
24     int64_t opaque_metadata_size) {
25   return at::native::mkldnn_tensor_from_data_ptr(
26       data_ptr, dims, dtype, device, opaque_metadata, opaque_metadata_size);
27 }
28 
29 #else
30 
31 void* data_ptr_from_mkldnn(at::Tensor* mkldnn_tensor) {
32   TORCH_CHECK(false, "MKL-DNN build is disabled");
33 }
34 
35 at::Tensor mkldnn_tensor_from_data_ptr(
36     void* data_ptr,
37     at::IntArrayRef dims,
38     at::ScalarType dtype,
39     at::Device device,
40     const uint8_t* opaque_metadata,
41     int64_t opaque_metadata_size) {
42   TORCH_CHECK(false, "MKL-DNN build is disabled");
43 }
44 
45 #endif
46 
47 } // namespace torch::aot_inductor
48