# This file keeps a list of PyTorch source files that are used for templated
# selective build.
# NB: as this is PyTorch Edge selective build, we assume only CPU targets are
# being built

load("@bazel_skylib//lib:paths.bzl", "paths")
load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode")
load(":build_variables.bzl", "aten_native_source_list")
load(
    ":ufunc_defs.bzl",
    "aten_ufunc_generated_cpu_kernel_sources",
    "aten_ufunc_generated_cpu_sources",
)

# Files in this list are supposed to be built separately for each app,
# for different operator allow lists.
TEMPLATE_SOURCE_LIST = [
    "torch/csrc/jit/runtime/register_prim_ops.cpp",
    "torch/csrc/jit/runtime/register_special_ops.cpp",
] + aten_native_source_list

# For selective build, we can lump the CPU and CPU kernel sources altogether
# because there is only ever one vectorization variant that is compiled
def aten_ufunc_generated_all_cpu_sources(gencode_pattern = "{}"):
    """Returns the ufunc-generated CPU sources followed by the CPU kernel sources.

    Args:
        gencode_pattern: forwarded unchanged to both
            aten_ufunc_generated_cpu_sources and
            aten_ufunc_generated_cpu_kernel_sources.

    Returns:
        The concatenation of the two helpers' results, CPU sources first.
    """
    cpu_sources = aten_ufunc_generated_cpu_sources(gencode_pattern)
    cpu_kernel_sources = aten_ufunc_generated_cpu_kernel_sources(gencode_pattern)
    return cpu_sources + cpu_kernel_sources

# Template registration sources that are only used when not building for OSS
# (see get_template_registration_files_outs / get_template_registration_file_rules).
TEMPLATE_MASKRCNN_SOURCE_LIST = [
    "register_maskrcnn_ops.cpp",
]

TEMPLATE_BATCH_BOX_COX_SOURCE_LIST = [
    "register_batch_box_cox_ops.cpp",
]

# Metal backend sources: a mix of Objective-C++ (.mm) and C++ (.cpp) files.
METAL_SOURCE_LIST = [
    "aten/src/ATen/native/metal/MetalAten.mm",
    "aten/src/ATen/native/metal/MetalGuardImpl.cpp",
    "aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp",
    "aten/src/ATen/native/metal/MetalCommandBuffer.mm",
    "aten/src/ATen/native/metal/MetalContext.mm",
    "aten/src/ATen/native/metal/MetalConvParams.mm",
    "aten/src/ATen/native/metal/MetalTensorImplStorage.mm",
    "aten/src/ATen/native/metal/MetalTensorUtils.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSImage+Tensor.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSImageUtils.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.mm",
    "aten/src/ATen/native/metal/ops/MetalAddmm.mm",
    "aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm",
    "aten/src/ATen/native/metal/ops/MetalChunk.mm",
    "aten/src/ATen/native/metal/ops/MetalClamp.mm",
    "aten/src/ATen/native/metal/ops/MetalConcat.mm",
    "aten/src/ATen/native/metal/ops/MetalConvolution.mm",
    "aten/src/ATen/native/metal/ops/MetalCopy.mm",
    "aten/src/ATen/native/metal/ops/MetalHardswish.mm",
    "aten/src/ATen/native/metal/ops/MetalHardshrink.mm",
    "aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm",
    "aten/src/ATen/native/metal/ops/MetalNeurons.mm",
    "aten/src/ATen/native/metal/ops/MetalPadding.mm",
    "aten/src/ATen/native/metal/ops/MetalPooling.mm",
    "aten/src/ATen/native/metal/ops/MetalReduce.mm",
    "aten/src/ATen/native/metal/ops/MetalReshape.mm",
    "aten/src/ATen/native/metal/ops/MetalSoftmax.mm",
    "aten/src/ATen/native/metal/ops/MetalTranspose.mm",
    "aten/src/ATen/native/metal/ops/MetalUpsamplingNearest.mm",
]

UNET_METAL_PREPACK_SOURCE_LIST = [
    "unet_metal_prepack.cpp",
    "unet_metal_prepack.mm",
]

METAL_MASKRCNN_SOURCE_LIST = [
    "maskrcnn/srcs/GenerateProposals.mm",
    "maskrcnn/srcs/RoIAlign.mm",
]

# get_template_source_dict() returns a dict mapping each path prefix to the
# list of .cpp source files under it that contain operator definitions and
# registrations that should get selected via templated selective build.
# The file selected_mobile_ops.h has the list of selected top level
# operators.
# NB: doesn't include generated files; copy_template_registration_files
# handles those specially
def get_template_source_dict():
    """Groups TEMPLATE_SOURCE_LIST by directory.

    Returns:
        A dict mapping each directory (paths.dirname of the file) to the list
        of TEMPLATE_SOURCE_LIST entries in that directory, in list order.
    """
    return _group_by_dirname(TEMPLATE_SOURCE_LIST)

def get_gen_oplist_outs():
    """Returns the outs dict for the gen_oplist step.

    Returns:
        A dict in which every output file name maps to a one-element list
        containing that same name.
    """
    return _self_mapped_outs([
        "SupportedMobileModelsRegistration.cpp",
        "selected_mobile_ops.h",
        "selected_operators.yaml",
    ])

def get_generate_code_bin_outs():
    """Returns the outs dict for the generated autograd files.

    The python_* binding files are only included when is_arvr_mode() is true.

    Returns:
        A dict in which every generated file path maps to a one-element list
        containing that same path.
    """

    # Each path is listed once and mapped to itself below, so the key and the
    # value of an entry cannot drift apart.
    generated_files = [
        "autograd/generated/ADInplaceOrViewTypeEverything.cpp",
        "autograd/generated/ADInplaceOrViewType_0.cpp",
        "autograd/generated/ADInplaceOrViewType_1.cpp",
        "autograd/generated/Functions.cpp",
        "autograd/generated/Functions.h",
        "autograd/generated/TraceTypeEverything.cpp",
        "autograd/generated/TraceType_0.cpp",
        "autograd/generated/TraceType_1.cpp",
        "autograd/generated/TraceType_2.cpp",
        "autograd/generated/TraceType_3.cpp",
        "autograd/generated/TraceType_4.cpp",
        "autograd/generated/VariableType.h",
        "autograd/generated/VariableTypeEverything.cpp",
        "autograd/generated/VariableType_0.cpp",
        "autograd/generated/VariableType_1.cpp",
        "autograd/generated/VariableType_2.cpp",
        "autograd/generated/VariableType_3.cpp",
        "autograd/generated/VariableType_4.cpp",
        "autograd/generated/variable_factories.h",
        "autograd/generated/ViewFuncs.cpp",
        "autograd/generated/ViewFuncs.h",
    ]

    if is_arvr_mode():
        generated_files = generated_files + [
            "autograd/generated/python_enum_tag.cpp",
            "autograd/generated/python_fft_functions.cpp",
            "autograd/generated/python_functions.h",
            "autograd/generated/python_functions_0.cpp",
            "autograd/generated/python_functions_1.cpp",
            "autograd/generated/python_functions_2.cpp",
            "autograd/generated/python_functions_3.cpp",
            "autograd/generated/python_functions_4.cpp",
            "autograd/generated/python_linalg_functions.cpp",
            "autograd/generated/python_nested_functions.cpp",
            "autograd/generated/python_nn_functions.cpp",
            "autograd/generated/python_return_types.h",
            "autograd/generated/python_return_types.cpp",
            "autograd/generated/python_sparse_functions.cpp",
            "autograd/generated/python_special_functions.cpp",
            "autograd/generated/python_torch_functions_0.cpp",
            "autograd/generated/python_torch_functions_1.cpp",
            "autograd/generated/python_torch_functions_2.cpp",
            "autograd/generated/python_variable_methods.cpp",
        ]
    return _self_mapped_outs(generated_files)

def get_template_registration_files_outs(is_oss = False):
    """Returns the outs dict for the copied template registration files.

    Args:
        is_oss: when true, the MaskRCNN and batch-box-cox registration files
            are left out.

    Returns:
        A dict in which every file path maps to a one-element list containing
        that same path. Covers (in order) the non-OSS extras,
        TEMPLATE_SOURCE_LIST, and the ufunc-generated CPU sources prefixed
        with "aten/src/ATen/".
    """
    non_oss_files = [] if is_oss else (
        TEMPLATE_MASKRCNN_SOURCE_LIST + TEMPLATE_BATCH_BOX_COX_SOURCE_LIST
    )
    ufunc_files = [
        "aten/src/ATen/{}".format(base_name)
        for base_name in aten_ufunc_generated_all_cpu_sources()
    ]
    return _self_mapped_outs(non_oss_files + TEMPLATE_SOURCE_LIST + ufunc_files)

def get_template_registration_file_rules(rule_name, is_oss = False):
    """Returns ":rule_name[file]" references for the template registration files.

    Args:
        rule_name: name of the rule whose named outputs are referenced.
        is_oss: when true, the MaskRCNN and batch-box-cox registration files
            are left out.

    Returns:
        A list of ":{rule_name}[{file_path}]" strings: TEMPLATE_SOURCE_LIST
        first, then the non-OSS extras, then the ufunc-generated CPU sources
        prefixed with "aten/src/ATen/".
    """
    file_paths = TEMPLATE_SOURCE_LIST
    if not is_oss:
        file_paths = file_paths + TEMPLATE_MASKRCNN_SOURCE_LIST + TEMPLATE_BATCH_BOX_COX_SOURCE_LIST

    rules = [":{}[{}]".format(rule_name, file_path) for file_path in file_paths]
    rules += [
        ":{}[aten/src/ATen/{}]".format(rule_name, base_name)
        for base_name in aten_ufunc_generated_all_cpu_sources()
    ]
    return rules

# ---------------------METAL RULES---------------------
def get_metal_source_dict():
    """Groups METAL_SOURCE_LIST by directory.

    Returns:
        A dict mapping each directory (paths.dirname of the file) to the list
        of METAL_SOURCE_LIST entries in that directory, in list order.
    """
    return _group_by_dirname(METAL_SOURCE_LIST)

def get_metal_registration_files_outs():
    """Returns the outs dict for all Metal registration files.

    Returns:
        A dict in which every path from METAL_SOURCE_LIST,
        UNET_METAL_PREPACK_SOURCE_LIST and METAL_MASKRCNN_SOURCE_LIST (in that
        order) maps to a one-element list containing that same path.
    """
    return _self_mapped_outs(
        METAL_SOURCE_LIST + UNET_METAL_PREPACK_SOURCE_LIST + METAL_MASKRCNN_SOURCE_LIST,
    )

# There is a really weird issue with the arvr windows builds where
# the custom op files are breaking them. See https://fburl.com/za87443c
# The hack is just to not build them for that platform and pray they aren't needed.
def get_metal_registration_files_outs_windows():
    """Windows variant of get_metal_registration_files_outs.

    Returns:
        The outs dict for METAL_SOURCE_LIST only; the UNet prepack and
        MaskRCNN files are deliberately left out (see the comment above).
    """
    return _self_mapped_outs(METAL_SOURCE_LIST)

def get_metal_registration_files_rules(rule_name):
    """Returns rule references for all Metal files, split by source language.

    Args:
        rule_name: name of the rule whose named outputs are referenced.

    Returns:
        A dict with keys "objc" and "cxx", each a list of
        ":{rule_name}[{file_path}]" strings. Files ending in ".cpp" go to
        "cxx"; everything else goes to "objc".
    """
    return _split_objc_and_cxx_rules(
        rule_name,
        METAL_SOURCE_LIST + METAL_MASKRCNN_SOURCE_LIST + UNET_METAL_PREPACK_SOURCE_LIST,
    )

def get_metal_registration_files_rules_windows(rule_name):
    """Windows variant of get_metal_registration_files_rules.

    Args:
        rule_name: name of the rule whose named outputs are referenced.

    Returns:
        The same "objc"/"cxx" dict, built from METAL_SOURCE_LIST only.
    """
    return _split_objc_and_cxx_rules(rule_name, METAL_SOURCE_LIST)

# ---------------------PRIVATE HELPERS---------------------
def _group_by_dirname(file_paths):
    """Maps each directory to the file_paths entries located in it, in order."""
    grouped = {}
    for file_path in file_paths:
        grouped.setdefault(paths.dirname(file_path), []).append(file_path)
    return grouped

def _self_mapped_outs(file_paths):
    """Maps every path to a fresh one-element list containing that same path."""
    return {file_path: [file_path] for file_path in file_paths}

def _split_objc_and_cxx_rules(rule_name, file_paths):
    """Builds ":rule[file]" references and splits them into "objc" and "cxx"."""
    objc_rules = []
    cxx_rules = []
    for file_path in file_paths:
        rule = ":{}[{}]".format(rule_name, file_path)

        # Match on the extension rather than on any ".cpp" substring, so that
        # a path such as "foo.cpp.mm" is still treated as Objective-C++.
        if file_path.endswith(".cpp"):
            cxx_rules.append(rule)
        else:
            objc_rules.append(rule)
    return {
        "objc": objc_rules,
        "cxx": cxx_rules,
    }