xref: /aosp_15_r20/external/pytorch/pt_template_srcs.bzl (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# This file keeps the list of PyTorch source files used for templated selective build.
# NB: since this is the PyTorch Edge selective build, we assume only CPU targets
# are being built.
4
5load("@bazel_skylib//lib:paths.bzl", "paths")
6load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode")
7load(":build_variables.bzl", "aten_native_source_list")
8load(
9    ":ufunc_defs.bzl",
10    "aten_ufunc_generated_cpu_kernel_sources",
11    "aten_ufunc_generated_cpu_sources",
12)
13
# JIT runtime registration files whose operator registrations are templated.
_JIT_TEMPLATE_SOURCE_LIST = [
    "torch/csrc/jit/runtime/register_prim_ops.cpp",
    "torch/csrc/jit/runtime/register_special_ops.cpp",
]

# Every file in this list is compiled separately for each app, so that each
# build honors that app's own operator allow list.
TEMPLATE_SOURCE_LIST = _JIT_TEMPLATE_SOURCE_LIST + aten_native_source_list
20
21# For selective build, we can lump the CPU and CPU kernel sources altogether
22# because there is only ever one vectorization variant that is compiled
23def aten_ufunc_generated_all_cpu_sources(gencode_pattern = "{}"):
24    return (
25        aten_ufunc_generated_cpu_sources(gencode_pattern) +
26        aten_ufunc_generated_cpu_kernel_sources(gencode_pattern)
27    )
28
# Custom-op registration files. They are only copied/registered for non-OSS
# builds (see the `is_oss` checks in the template registration helpers).
TEMPLATE_MASKRCNN_SOURCE_LIST = ["register_maskrcnn_ops.cpp"]

TEMPLATE_BATCH_BOX_COX_SOURCE_LIST = ["register_batch_box_cox_ops.cpp"]
36
# Metal backend sources, grouped by subdirectory. Entry order is significant:
# it determines the order of the generated outs and rule labels.
METAL_SOURCE_LIST = [
    # Files directly under aten/src/ATen/native/metal/.
    "aten/src/ATen/native/metal/MetalAten.mm",
    "aten/src/ATen/native/metal/MetalGuardImpl.cpp",
    "aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp",
    "aten/src/ATen/native/metal/MetalCommandBuffer.mm",
    "aten/src/ATen/native/metal/MetalContext.mm",
    "aten/src/ATen/native/metal/MetalConvParams.mm",
    "aten/src/ATen/native/metal/MetalTensorImplStorage.mm",
    "aten/src/ATen/native/metal/MetalTensorUtils.mm",
    # mpscnn/ subdirectory.
    "aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSCNNUtils.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSImage+Tensor.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSImageUtils.mm",
    "aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.mm",
    # ops/ subdirectory.
    "aten/src/ATen/native/metal/ops/MetalAddmm.mm",
    "aten/src/ATen/native/metal/ops/MetalBinaryElementwise.mm",
    "aten/src/ATen/native/metal/ops/MetalChunk.mm",
    "aten/src/ATen/native/metal/ops/MetalClamp.mm",
    "aten/src/ATen/native/metal/ops/MetalConcat.mm",
    "aten/src/ATen/native/metal/ops/MetalConvolution.mm",
    "aten/src/ATen/native/metal/ops/MetalCopy.mm",
    "aten/src/ATen/native/metal/ops/MetalHardswish.mm",
    "aten/src/ATen/native/metal/ops/MetalHardshrink.mm",
    "aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm",
    "aten/src/ATen/native/metal/ops/MetalNeurons.mm",
    "aten/src/ATen/native/metal/ops/MetalPadding.mm",
    "aten/src/ATen/native/metal/ops/MetalPooling.mm",
    "aten/src/ATen/native/metal/ops/MetalReduce.mm",
    "aten/src/ATen/native/metal/ops/MetalReshape.mm",
    "aten/src/ATen/native/metal/ops/MetalSoftmax.mm",
    "aten/src/ATen/native/metal/ops/MetalTranspose.mm",
    "aten/src/ATen/native/metal/ops/MetalUpsamplingNearest.mm",
]
73
# UNet Metal prepack sources (a .cpp and a .mm file). Excluded from the
# Windows variants of the Metal registration helpers.
UNET_METAL_PREPACK_SOURCE_LIST = [
    "unet_metal_prepack.cpp",
    "unet_metal_prepack.mm",
]

# Mask-RCNN Metal sources. Also excluded from the Windows variants.
METAL_MASKRCNN_SOURCE_LIST = [
    "maskrcnn/srcs/GenerateProposals.mm",
    "maskrcnn/srcs/RoIAlign.mm",
]
83
# get_template_source_dict() groups TEMPLATE_SOURCE_LIST by directory. Each key
# is a path prefix (the directory part of a source path), and each value lists,
# in original order, the source files under that prefix. These files contain
# operator definitions and registrations that should be filtered by templated
# selective build; the selected top-level operators are listed in
# selected_mobile_ops.h.
# NB: generated files are not included; copy_template_registration_files
# handles those specially.
def get_template_source_dict():
    """Return {directory: [source paths in that directory]} for TEMPLATE_SOURCE_LIST."""
    grouped = {}
    for src in TEMPLATE_SOURCE_LIST:
        grouped.setdefault(paths.dirname(src), []).append(src)
    return grouped
99
def get_gen_oplist_outs():
    """Return the `outs` mapping for the operator-list codegen rule.

    Each generated file name maps to a single-element list containing itself.
    """
    return {
        out_name: [out_name]
        for out_name in (
            "SupportedMobileModelsRegistration.cpp",
            "selected_mobile_ops.h",
            "selected_operators.yaml",
        )
    }
112
def get_generate_code_bin_outs():
    """Return the `outs` mapping for the autograd code generator.

    Every entry maps a path under autograd/generated/ to a single-element list
    holding that same path. The python_* binding sources are only listed when
    is_arvr_mode() returns true.
    """
    generated_names = [
        "ADInplaceOrViewTypeEverything.cpp",
        "ADInplaceOrViewType_0.cpp",
        "ADInplaceOrViewType_1.cpp",
        "Functions.cpp",
        "Functions.h",
        "TraceTypeEverything.cpp",
        "TraceType_0.cpp",
        "TraceType_1.cpp",
        "TraceType_2.cpp",
        "TraceType_3.cpp",
        "TraceType_4.cpp",
        "VariableType.h",
        "VariableTypeEverything.cpp",
        "VariableType_0.cpp",
        "VariableType_1.cpp",
        "VariableType_2.cpp",
        "VariableType_3.cpp",
        "VariableType_4.cpp",
        "variable_factories.h",
        "ViewFuncs.cpp",
        "ViewFuncs.h",
    ]

    if is_arvr_mode():
        # Python binding outputs are only declared in arvr mode.
        generated_names.extend([
            "python_enum_tag.cpp",
            "python_fft_functions.cpp",
            "python_functions.h",
            "python_functions_0.cpp",
            "python_functions_1.cpp",
            "python_functions_2.cpp",
            "python_functions_3.cpp",
            "python_functions_4.cpp",
            "python_linalg_functions.cpp",
            "python_nested_functions.cpp",
            "python_nn_functions.cpp",
            "python_return_types.h",
            "python_return_types.cpp",
            "python_sparse_functions.cpp",
            "python_special_functions.cpp",
            "python_torch_functions_0.cpp",
            "python_torch_functions_1.cpp",
            "python_torch_functions_2.cpp",
            "python_variable_methods.cpp",
        ])

    outs = {}
    for name in generated_names:
        out_path = "autograd/generated/" + name
        outs[out_path] = [out_path]
    return outs
161
def get_template_registration_files_outs(is_oss = False):
    """Return the `outs` mapping for the templated registration files.

    Each output path maps to a single-element list containing itself. Entries
    appear in this order: the mask-rcnn and batch_box_cox registration files
    (only when is_oss is False), then TEMPLATE_SOURCE_LIST, then the generated
    ufunc CPU sources placed under aten/src/ATen/.

    Args:
      is_oss: when True, skip the internal-only custom-op registration files.
    """
    to_copy = []
    if not is_oss:
        to_copy += TEMPLATE_MASKRCNN_SOURCE_LIST
        to_copy += TEMPLATE_BATCH_BOX_COX_SOURCE_LIST
    to_copy += TEMPLATE_SOURCE_LIST
    to_copy += [
        "aten/src/ATen/{}".format(base_name)
        for base_name in aten_ufunc_generated_all_cpu_sources()
    ]

    outs = {}
    for file_path in to_copy:
        outs[file_path] = [file_path]
    return outs
179
def get_template_registration_file_rules(rule_name, is_oss = False):
    """Return `:rule_name[path]` labels for every templated registration output.

    The labels cover the same paths as get_template_registration_files_outs():
    TEMPLATE_SOURCE_LIST, plus the mask-rcnn and batch_box_cox registration
    files when is_oss is False, followed by the generated ufunc CPU sources
    addressed under aten/src/ATen/.

    Args:
      rule_name: name of the rule whose named outputs are referenced.
      is_oss: when True, skip the internal-only custom-op registration files.
    """
    sources = TEMPLATE_SOURCE_LIST
    if not is_oss:
        sources = sources + TEMPLATE_MASKRCNN_SOURCE_LIST + TEMPLATE_BATCH_BOX_COX_SOURCE_LIST

    labels = [":{}[{}]".format(rule_name, src) for src in sources]
    labels += [
        ":{}[aten/src/ATen/{}]".format(rule_name, src)
        for src in aten_ufunc_generated_all_cpu_sources()
    ]
    return labels
188
# ---------------------METAL RULES---------------------
def get_metal_source_dict():
    """Group METAL_SOURCE_LIST by directory.

    Returns a dict mapping each directory (as computed by paths.dirname) to
    the Metal source paths it contains, in their original order.
    """
    grouped = {}
    for src in METAL_SOURCE_LIST:
        grouped.setdefault(paths.dirname(src), []).append(src)
    return grouped
198
def get_metal_registration_files_outs():
    """Return the `outs` mapping for all Metal registration files.

    Covers METAL_SOURCE_LIST, then UNET_METAL_PREPACK_SOURCE_LIST, then
    METAL_MASKRCNN_SOURCE_LIST; each path maps to a single-element list
    containing itself.
    """
    all_metal_files = METAL_SOURCE_LIST + UNET_METAL_PREPACK_SOURCE_LIST + METAL_MASKRCNN_SOURCE_LIST
    outs = {}
    for metal_file in all_metal_files:
        outs[metal_file] = [metal_file]
    return outs
210
# Workaround: the custom op files (the UNet prepack and mask-rcnn sources that
# the non-Windows variant adds) break the arvr Windows builds for reasons that
# are not understood; see https://fburl.com/za87443c. The hack is simply to
# not build them on that platform and hope they aren't needed there.
def get_metal_registration_files_outs_windows():
    """Return the `outs` mapping for METAL_SOURCE_LIST only (Windows variant)."""
    return {metal_file: [metal_file] for metal_file in METAL_SOURCE_LIST}
219
def get_metal_registration_files_rules(rule_name):
    """Split the Metal registration outputs into Objective-C++ and C++ labels.

    Args:
      rule_name: name of the rule whose named outputs are referenced.

    Returns:
      A dict with two keys:
        "objc": `:rule_name[path]` labels for every non-.cpp file (the .mm files).
        "cxx":  `:rule_name[path]` labels for the .cpp files.
      Covers METAL_SOURCE_LIST, METAL_MASKRCNN_SOURCE_LIST and
      UNET_METAL_PREPACK_SOURCE_LIST, in that order.
    """
    objc_rules = []
    cxx_rules = []
    for file_path in METAL_SOURCE_LIST + METAL_MASKRCNN_SOURCE_LIST + UNET_METAL_PREPACK_SOURCE_LIST:
        label = ":{}[{}]".format(rule_name, file_path)

        # Classify by the actual file extension. A substring test
        # (`".cpp" in file_path`) would misroute any path that merely
        # contains ".cpp", e.g. "foo.cpp.mm" or a directory named "x.cpp/".
        if file_path.endswith(".cpp"):
            cxx_rules.append(label)
        else:
            objc_rules.append(label)

    return {
        "objc": objc_rules,
        "cxx": cxx_rules,
    }
233
def get_metal_registration_files_rules_windows(rule_name):
    """Windows variant of get_metal_registration_files_rules().

    Only METAL_SOURCE_LIST is used; the UNet prepack and mask-rcnn custom op
    sources are skipped on Windows (see get_metal_registration_files_outs_windows).

    Args:
      rule_name: name of the rule whose named outputs are referenced.

    Returns:
      A dict with two keys:
        "objc": `:rule_name[path]` labels for every non-.cpp file (the .mm files).
        "cxx":  `:rule_name[path]` labels for the .cpp files.
    """
    objc_rules = []
    cxx_rules = []
    for file_path in METAL_SOURCE_LIST:
        label = ":{}[{}]".format(rule_name, file_path)

        # Classify by the actual file extension rather than a substring match,
        # so a path that merely contains ".cpp" is not sent to the C++ toolchain.
        if file_path.endswith(".cpp"):
            cxx_rules.append(label)
        else:
            objc_rules.append(label)

    return {
        "objc": objc_rules,
        "cxx": cxx_rules,
    }
247